feat(vla): align transformer training stack and rollout validation

This commit is contained in:
Logic
2026-03-31 15:39:20 +08:00
parent 424c265823
commit d84bc6876e
25 changed files with 4043 additions and 706 deletions

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,88 @@
import pickle
import tempfile
import unittest
from pathlib import Path
import h5py
import numpy as np
from roboimi.vla.scripts import calculate_stats
class CalculateStatsCliTest(unittest.TestCase):
    """CLI-level tests for the `calculate_stats` dataset-statistics script."""

    def test_default_dataset_dir_is_absolute_and_package_relative(self):
        # The default directory must resolve relative to the installed
        # package location, never relative to the process CWD.
        expected = (
            Path(calculate_stats.__file__).resolve().parents[2]
            / "demos"
            / "dataset"
            / "sim_transfer"
        )
        self.assertEqual(Path(calculate_stats.DEFAULT_DATASET_DIR), expected)
        self.assertTrue(Path(calculate_stats.DEFAULT_DATASET_DIR).is_absolute())

    def test_main_writes_dataset_stats_pkl_to_dataset_dir(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            episode_path = dataset_dir / "episode_0.hdf5"
            # Minimal single-episode fixture: 2 steps of action and qpos.
            with h5py.File(episode_path, "w") as root:
                root.create_dataset(
                    "action",
                    data=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
                )
                observations = root.create_group("observations")
                observations.create_dataset(
                    "qpos",
                    data=np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float32),
                )
            calculate_stats.main(["--dataset_dir", str(dataset_dir)])
            # The stats pickle must land inside the dataset directory itself.
            stats_path = dataset_dir / "dataset_stats.pkl"
            self.assertTrue(stats_path.exists())
            with stats_path.open("rb") as f:
                stats = pickle.load(f)
            # Exactly the eight summary statistics, no extras.
            self.assertEqual(
                set(stats),
                {
                    "action_mean",
                    "action_std",
                    "action_min",
                    "action_max",
                    "qpos_mean",
                    "qpos_std",
                    "qpos_min",
                    "qpos_max",
                },
            )
            # Column-wise means of the fixture data written above.
            np.testing.assert_allclose(stats["action_mean"], np.array([2.0, 3.0]))
            np.testing.assert_allclose(stats["qpos_mean"], np.array([6.0, 7.0]))

    def test_main_raises_clear_error_for_empty_dataset_dir(self):
        # A directory with no episode files should fail loudly and name the dir.
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            with self.assertRaisesRegex(
                ValueError, r"No episode_\*\.hdf5 files found"
            ) as ctx:
                calculate_stats.main(["--dataset_dir", str(dataset_dir)])
            self.assertIn(str(dataset_dir), str(ctx.exception))

    def test_main_raises_clear_error_for_missing_dataset_dir(self):
        # A non-existent directory gets the same clear error as an empty one.
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir) / "missing"
            with self.assertRaisesRegex(
                ValueError, r"No episode_\*\.hdf5 files found"
            ) as ctx:
                calculate_stats.main(["--dataset_dir", str(dataset_dir)])
            self.assertIn(str(dataset_dir), str(ctx.exception))
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,28 @@
import unittest
from roboimi.vla.eval_utils import execute_policy_action
class _FakeEnv:
def __init__(self):
self.calls = []
def step(self, action):
self.calls.append(("step", action))
def step_jnt(self, action):
self.calls.append(("step_jnt", action))
class EvalVLAExecutionTest(unittest.TestCase):
    """Tests for the action-execution helper used during VLA rollouts."""

    def test_execute_policy_action_uses_ee_step(self):
        # The helper must drive the env through the end-effector `step`
        # method, not the joint-space `step_jnt` fallback.
        env = _FakeEnv()
        action = [1, 2, 3]
        execute_policy_action(env, action)
        self.assertEqual(env.calls, [("step", action)])
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,259 @@
import unittest
from pathlib import Path
from unittest import mock
import numpy as np
import torch
from omegaconf import OmegaConf
from roboimi.demos.vla_scripts import eval_vla
from roboimi.envs.double_base import DualDianaMed
from roboimi.envs.double_pos_ctrl_env import make_sim_env
class _FakeAgent:
def __init__(self):
self.reset_calls = 0
self.last_observation = None
def eval(self):
return self
def to(self, _device):
return self
def reset(self):
self.reset_calls += 1
def select_action(self, observation):
self.last_observation = observation
return torch.zeros(16)
class _FakeEnv:
def __init__(self):
self.image_obs_calls = 0
self.render_calls = 0
self.reset_calls = []
def reset(self, box_pos):
self.reset_calls.append(np.array(box_pos))
def _get_image_obs(self):
self.image_obs_calls += 1
return {
"images": {
"front": np.zeros((8, 8, 3), dtype=np.uint8),
}
}
def _get_qpos_obs(self):
return {"qpos": np.zeros(16, dtype=np.float32)}
def render(self):
self.render_calls += 1
raise AssertionError("env.render() should be skipped when eval.headless=true")
class _RewardTrackingEnv(_FakeEnv):
    """_FakeEnv variant that tracks episode/step indices for scripted rewards."""

    def __init__(self, reward_sequences):
        super().__init__()
        # Per-episode reward traces, consumed externally via episode/step indices.
        self.reward_sequences = reward_sequences
        self.episode_index = -1
        self.step_index = 0
        self.rew = 0.0

    def reset(self, box_pos):
        super().reset(box_pos)
        # Each reset starts the next scripted episode from step zero.
        self.episode_index = self.episode_index + 1
        self.step_index = 0
class _FakeRenderer:
def __init__(self, env):
self._env = env
self._frames = [
np.full((4, 4, 3), fill_value=index, dtype=np.uint8)
for index in range(5)
]
self._index = 0
def update_scene(self, _mj_data, camera=None):
self._camera = camera
def render(self):
frame = self._frames[self._index]
self._index += 1
if self._index >= len(self._frames):
self._env.exit_flag = True
return frame
class EvalVLAHeadlessTest(unittest.TestCase):
    """Checks that `eval.headless=true` suppresses every GUI/render code path."""

    def test_eval_config_exposes_headless_default(self):
        # The packaged eval config must declare `headless` and default it off.
        eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))
        self.assertIn("headless", eval_cfg)
        self.assertFalse(eval_cfg.headless)

    def test_make_sim_env_accepts_headless_and_disables_render(self):
        fake_env = object()
        with mock.patch(
            "roboimi.assets.robots.diana_med.BiDianaMed",
            return_value="robot",
        ), mock.patch(
            "roboimi.envs.double_pos_ctrl_env.DualDianaMed_Pos_Ctrl",
            return_value=fake_env,
        ) as env_cls:
            env = make_sim_env("sim_transfer", headless=True)
        self.assertIs(env, fake_env)
        # headless=True must translate into is_render=False on the env class.
        env_cls.assert_called_once_with(
            robot="robot",
            is_render=False,
            control_freq=30,
            is_interpolate=True,
            cam_view="angle",
        )

    def test_camera_viewer_headless_updates_images_without_gui_calls(self):
        # Bare instance: __new__ skips __init__ (which would load MuJoCo assets).
        env = DualDianaMed.__new__(DualDianaMed)
        env.mj_model = object()
        env.mj_data = object()
        env.exit_flag = False
        env.is_render = False
        env.cam = "angle"
        env.r_vis = None
        env.l_vis = None
        env.top = None
        env.angle = None
        env.front = None
        with mock.patch(
            "roboimi.envs.double_base.mj.Renderer",
            side_effect=lambda *args, **kwargs: _FakeRenderer(env),
        ), mock.patch("roboimi.envs.double_base.cv2.namedWindow") as named_window, mock.patch(
            "roboimi.envs.double_base.cv2.imshow"
        ) as imshow, mock.patch("roboimi.envs.double_base.cv2.waitKey") as wait_key:
            env.camera_viewer()
        # No OpenCV window interaction may happen while headless...
        named_window.assert_not_called()
        imshow.assert_not_called()
        wait_key.assert_not_called()
        # ...yet all five camera images must still be captured.
        self.assertIsNotNone(env.r_vis)
        self.assertIsNotNone(env.l_vis)
        self.assertIsNotNone(env.top)
        self.assertIsNotNone(env.angle)
        self.assertIsNotNone(env.front)

    def test_eval_main_headless_skips_render_and_still_executes_policy(self):
        fake_env = _FakeEnv()
        fake_agent = _FakeAgent()
        cfg = OmegaConf.create(
            {
                "agent": {},
                "eval": {
                    "ckpt_path": "checkpoints/vla_model_best.pt",
                    "num_episodes": 1,
                    "max_timesteps": 1,
                    "device": "cpu",
                    "task_name": "sim_transfer",
                    "camera_names": ["front"],
                    "use_smoothing": False,
                    "smooth_alpha": 0.3,
                    "verbose_action": False,
                    "headless": True,
                },
            }
        )
        with mock.patch.object(
            eval_vla,
            "load_checkpoint",
            return_value=(fake_agent, None),
        ), mock.patch.object(
            eval_vla,
            "make_sim_env",
            return_value=fake_env,
        ) as make_env, mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            return_value=np.array([0.1, 0.2, 0.3]),
        ), mock.patch.object(
            eval_vla,
            "execute_policy_action",
        ) as execute_policy_action, mock.patch.object(
            eval_vla,
            "tqdm",
            side_effect=lambda iterable, **kwargs: iterable,
        ):
            # Call the undecorated function to bypass Hydra's CLI entry point.
            eval_vla.main.__wrapped__(cfg)
        make_env.assert_called_once_with("sim_transfer", headless=True)
        execute_policy_action.assert_called_once()
        self.assertEqual(fake_env.image_obs_calls, 1)
        # _FakeEnv.render() raises, so zero calls proves rendering was skipped.
        self.assertEqual(fake_env.render_calls, 0)
        self.assertIsNotNone(fake_agent.last_observation)
        self.assertIn("front", fake_agent.last_observation["images"])

    def test_run_eval_returns_average_reward_summary(self):
        # Scripted per-episode reward traces replayed by the fake executor below.
        reward_sequences = [
            [1.0, 2.0],
            [0.5, 4.0],
        ]
        fake_env = _RewardTrackingEnv(reward_sequences)
        fake_agent = _FakeAgent()
        cfg = OmegaConf.create(
            {
                "agent": {},
                "eval": {
                    "ckpt_path": "checkpoints/vla_model_best.pt",
                    "num_episodes": 2,
                    "max_timesteps": 2,
                    "device": "cpu",
                    "task_name": "sim_transfer",
                    "camera_names": ["front"],
                    "use_smoothing": False,
                    "smooth_alpha": 0.3,
                    "verbose_action": False,
                    "headless": True,
                },
            }
        )

        def fake_execute_policy_action(env, action):
            # Replay the scripted reward for the current episode/step.
            del action
            env.rew = env.reward_sequences[env.episode_index][env.step_index]
            env.step_index += 1

        with mock.patch.object(
            eval_vla,
            "load_checkpoint",
            return_value=(fake_agent, None),
        ), mock.patch.object(
            eval_vla,
            "make_sim_env",
            return_value=fake_env,
        ), mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            return_value=np.array([0.1, 0.2, 0.3]),
        ), mock.patch.object(
            eval_vla,
            "execute_policy_action",
            side_effect=fake_execute_policy_action,
        ), mock.patch.object(
            eval_vla,
            "tqdm",
            side_effect=lambda iterable, **kwargs: iterable,
        ):
            summary = eval_vla._run_eval(cfg)
        # Episode sums: 1+2=3.0 and 0.5+4=4.5; average 3.75.
        self.assertEqual(summary["episode_rewards"], [3.0, 4.5])
        self.assertAlmostEqual(summary["avg_reward"], 3.75)
        self.assertEqual(summary["num_episodes"], 2)
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

View File

@@ -0,0 +1,387 @@
import contextlib
import sys
import types
import unittest
from pathlib import Path
import torch
from hydra import compose, initialize_config_dir
from hydra.errors import InstantiationException
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from omegaconf import OmegaConf
# Repository root, derived from this test file's location.
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Absolute Hydra config dir (initialize_config_dir requires an absolute path).
_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve())
# Camera order the project config is expected to expose everywhere.
_EXPECTED_CAMERA_NAMES = ['r_vis', 'top', 'front']
# Sentinel distinguishing "module absent" from "module mapped to None" in sys.modules.
_MISSING = object()
class _FakeScheduler:
def __init__(self, num_train_timesteps=100, **kwargs):
self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps)
self.timesteps = []
def add_noise(self, sample, noise, timestep):
return sample + noise
def set_timesteps(self, num_inference_steps):
self.timesteps = list(range(num_inference_steps - 1, -1, -1))
def step(self, noise_pred, timestep, sample):
return types.SimpleNamespace(prev_sample=sample)
class _IdentityCrop:
def __init__(self, size):
self.size = size
def __call__(self, x):
return x
class _FakeResNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
self.relu1 = torch.nn.ReLU()
self.conv2 = torch.nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2)
self.relu2 = torch.nn.ReLU()
self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
self.fc = torch.nn.Linear(16, 16)
def forward(self, x):
x = self.relu1(self.conv1(x))
x = self.relu2(self.conv2(x))
x = self.avgpool(x)
x = torch.flatten(x, start_dim=1)
return self.fc(x)
class _FakeRearrange(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, x):
return x
class _CondCapturingHead(torch.nn.Module):
def __init__(self):
super().__init__()
self.last_cond = None
def forward(self, sample, timestep, cond):
self.last_cond = cond.detach().clone()
return torch.zeros_like(sample)
@contextlib.contextmanager
def _stub_optional_modules():
    """Temporarily install fakes for diffusers/torchvision/einops in sys.modules.

    Lets the agent be instantiated without the heavy optional dependencies;
    the previous sys.modules state is restored on exit.
    """
    previous_modules = {}
    def inject(name, module):
        # Record only the first pre-existing value so repeated injects of the
        # same name cannot clobber the true original.
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)
        sys.modules[name] = module
    # diffusers: both scheduler classes resolve to the same fake.
    diffusers_module = types.ModuleType('diffusers')
    schedulers_module = types.ModuleType('diffusers.schedulers')
    ddpm_module = types.ModuleType('diffusers.schedulers.scheduling_ddpm')
    ddim_module = types.ModuleType('diffusers.schedulers.scheduling_ddim')
    ddpm_module.DDPMScheduler = _FakeScheduler
    ddim_module.DDIMScheduler = _FakeScheduler
    diffusers_module.DDPMScheduler = _FakeScheduler
    diffusers_module.DDIMScheduler = _FakeScheduler
    diffusers_module.schedulers = schedulers_module
    schedulers_module.scheduling_ddpm = ddpm_module
    schedulers_module.scheduling_ddim = ddim_module
    # torchvision: tiny CNN backbone and identity crop transforms.
    torchvision_module = types.ModuleType('torchvision')
    models_module = types.ModuleType('torchvision.models')
    transforms_module = types.ModuleType('torchvision.transforms')
    models_module.resnet18 = lambda weights=None: _FakeResNet()
    transforms_module.CenterCrop = _IdentityCrop
    transforms_module.RandomCrop = _IdentityCrop
    torchvision_module.models = models_module
    torchvision_module.transforms = transforms_module
    # einops: rearrange becomes a no-op; Rearrange layer passes through.
    einops_module = types.ModuleType('einops')
    einops_module.rearrange = lambda x, *args, **kwargs: x
    einops_layers_module = types.ModuleType('einops.layers')
    einops_layers_torch_module = types.ModuleType('einops.layers.torch')
    einops_layers_torch_module.Rearrange = _FakeRearrange
    einops_module.layers = einops_layers_module
    einops_layers_module.torch = einops_layers_torch_module
    try:
        inject('diffusers', diffusers_module)
        inject('diffusers.schedulers', schedulers_module)
        inject('diffusers.schedulers.scheduling_ddpm', ddpm_module)
        inject('diffusers.schedulers.scheduling_ddim', ddim_module)
        inject('torchvision', torchvision_module)
        inject('torchvision.models', models_module)
        inject('torchvision.transforms', transforms_module)
        inject('einops', einops_module)
        inject('einops.layers', einops_layers_module)
        inject('einops.layers.torch', einops_layers_torch_module)
        yield
    finally:
        # Restore in reverse injection order; drop names that did not exist.
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
def _compose_cfg(overrides=None):
    """Compose the project Hydra config with optional CLI-style overrides."""
    # Register the `len` resolver at most once; OmegaConf raises on duplicates.
    if not OmegaConf.has_resolver('len'):
        OmegaConf.register_new_resolver('len', lambda x: len(x))
    # Clear global Hydra state so repeated composition within one run works.
    GlobalHydra.instance().clear()
    with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR):
        return compose(config_name='config', overrides=list(overrides or []))
def _make_images(batch_size, obs_horizon, image_shape, per_camera_fill=None):
channels, height, width = image_shape
per_camera_fill = per_camera_fill or {
'front': 30.0,
'top': 20.0,
'r_vis': 10.0,
}
return {
name: torch.full(
(batch_size, obs_horizon, channels, height, width),
fill_value=fill_value,
dtype=torch.float32,
)
for name, fill_value in per_camera_fill.items()
}
def _patch_backbone_for_order_tracking(backbone):
feature_dim = backbone.output_dim
def encode_mean(image_batch):
mean_feature = image_batch.mean(dim=(1, 2, 3)).unsqueeze(-1)
return mean_feature.repeat(1, feature_dim)
if backbone.use_separate_rgb_encoder_per_camera:
for encoder in backbone.rgb_encoder:
encoder.forward_single_image = encode_mean
else:
backbone.rgb_encoder.forward_single_image = encode_mean
def _extract_camera_markers(cond, feature_dim, num_cams):
camera_block = cond[0, 0, : feature_dim * num_cams].view(num_cams, feature_dim)
return camera_block[:, 0]
class ResNetTransformerAgentWiringTest(unittest.TestCase):
    """Hydra-wiring tests for the ResNet+transformer agent's camera conditioning."""

    def test_hydra_wiring_uses_required_three_camera_transformer_conditioning_in_agent_order_and_ignores_extra_keys(self):
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.inference_steps=1',
                'agent.head.n_layer=1',
                'agent.head.n_cond_layers=0',
                'agent.head.n_emb=32',
                'agent.head.n_head=4',
            ]
        )
        # Camera order must agree across data, eval, agent and backbone configs.
        self.assertEqual(list(cfg.data.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(list(cfg.eval.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(list(cfg.agent.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(list(cfg.agent.vision_backbone.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(cfg.agent.head_type, 'transformer')
        self.assertEqual(cfg.agent.num_cams, 3)
        self.assertTrue(cfg.agent.head.obs_as_cond)
        self.assertFalse(cfg.agent.head.causal_attn)
        with _stub_optional_modules():
            agent = instantiate(cfg.agent)
        # Conditioning width = per-camera features * cameras + proprioception.
        expected_cond_dim = agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim
        self.assertEqual(cfg.agent.head.cond_dim, expected_cond_dim)
        self.assertEqual(agent.per_step_cond_dim, expected_cond_dim)
        self.assertEqual(agent.noise_pred_net.cond_obs_emb.in_features, expected_cond_dim)
        batch_size = 2
        image_shape = tuple(cfg.agent.vision_backbone.input_shape)
        # Includes an extra 'left_wrist' camera the agent must ignore.
        images = _make_images(
            batch_size,
            cfg.agent.obs_horizon,
            image_shape,
            per_camera_fill={
                'front': 30.0,
                'top': 20.0,
                'r_vis': 10.0,
                'left_wrist': 99.0,
            },
        )
        proprioception = torch.randn(batch_size, cfg.agent.obs_horizon, cfg.agent.obs_dim)
        _patch_backbone_for_order_tracking(agent.vision_encoder)
        capturing_head = _CondCapturingHead()
        agent.noise_pred_net = capturing_head
        predicted_actions = agent.predict_action(images, proprioception)
        self.assertEqual(
            predicted_actions.shape,
            (batch_size, cfg.agent.pred_horizon, cfg.agent.action_dim),
        )
        self.assertIsNotNone(capturing_head.last_cond)
        self.assertEqual(capturing_head.last_cond.shape[-1], expected_cond_dim)
        # Marker values prove conditioning follows agent camera order:
        # r_vis=10, top=20, front=30 (the 99.0 extra key must not appear).
        camera_markers = _extract_camera_markers(
            capturing_head.last_cond,
            agent.vision_encoder.output_dim,
            agent.num_cams,
        )
        self.assertTrue(torch.allclose(camera_markers, torch.tensor([10.0, 20.0, 30.0])))
        # Removing a required camera must raise and name the missing key.
        missing_images = dict(images)
        missing_images.pop('top')
        with self.assertRaisesRegex(ValueError, 'missing=.*top'):
            agent.predict_action(missing_images, proprioception)

    def test_agent_rejects_conflicting_explicit_backbone_camera_names(self):
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        # Backbone order deliberately contradicts cfg.agent.camera_names.
        cfg.agent.vision_backbone.camera_names = ['front', 'top', 'r_vis']
        with _stub_optional_modules():
            with self.assertRaisesRegex(InstantiationException, 'camera_names'):
                instantiate(cfg.agent)

    def test_backbone_uses_sorted_fallback_order_when_camera_names_unset(self):
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.vision_backbone.camera_names = None
        with _stub_optional_modules():
            backbone = instantiate(cfg.agent.vision_backbone)
        _patch_backbone_for_order_tracking(backbone)
        images = _make_images(
            batch_size=1,
            obs_horizon=cfg.agent.obs_horizon,
            image_shape=tuple(cfg.agent.vision_backbone.input_shape),
            per_camera_fill={
                'top': 20.0,
                'front': 30.0,
                'r_vis': 10.0,
            },
        )
        ordered_features = backbone(images)
        camera_markers = _extract_camera_markers(
            ordered_features,
            backbone.output_dim,
            len(images),
        )
        # Alphabetical fallback: front(30), r_vis(10), top(20).
        self.assertTrue(torch.allclose(camera_markers, torch.tensor([30.0, 10.0, 20.0])))

    def test_agent_queue_fallback_order_is_deterministic_when_camera_names_unset(self):
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None
        with _stub_optional_modules():
            agent = instantiate(cfg.agent)
        observation = {
            'qpos': torch.randn(cfg.agent.obs_dim),
            'images': {
                'top': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 20.0),
                'front': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 30.0),
                'r_vis': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 10.0),
            },
        }
        agent._populate_queues(observation)
        batch = agent._prepare_observation_batch()
        # Sorted keys regardless of the observation dict's insertion order.
        self.assertEqual(list(batch['images'].keys()), ['front', 'r_vis', 'top'])

    def test_backbone_rejects_camera_count_mismatch_when_camera_names_unset(self):
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.vision_backbone.camera_names = None
        with _stub_optional_modules():
            backbone = instantiate(cfg.agent.vision_backbone)
        # Only two cameras supplied where the config promises three.
        images = _make_images(
            batch_size=1,
            obs_horizon=cfg.agent.obs_horizon,
            image_shape=tuple(cfg.agent.vision_backbone.input_shape),
            per_camera_fill={
                'front': 30.0,
                'r_vis': 10.0,
            },
        )
        with self.assertRaisesRegex(ValueError, 'num_cameras'):
            backbone(images)

    def test_agent_rejects_camera_count_mismatch_when_camera_names_unset(self):
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.inference_steps=1',
                'agent.head.n_layer=1',
                'agent.head.n_cond_layers=0',
                'agent.head.n_emb=32',
                'agent.head.n_head=4',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None
        with _stub_optional_modules():
            agent = instantiate(cfg.agent)
        # Two cameras supplied; the agent expects num_cams=3.
        images = _make_images(
            batch_size=1,
            obs_horizon=cfg.agent.obs_horizon,
            image_shape=tuple(cfg.agent.vision_backbone.input_shape),
            per_camera_fill={
                'front': 30.0,
                'r_vis': 10.0,
            },
        )
        proprioception = torch.randn(1, cfg.agent.obs_horizon, cfg.agent.obs_dim)
        with self.assertRaisesRegex(ValueError, 'num_cams'):
            agent.predict_action(images, proprioception)

    def test_agent_rejects_num_cams_mismatch_with_backbone_when_camera_names_unset(self):
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None
        # Agent and backbone disagree on the camera count.
        cfg.agent.num_cams = 2
        cfg.agent.vision_backbone.num_cameras = 3
        with _stub_optional_modules():
            with self.assertRaisesRegex(InstantiationException, 'num_cams'):
                instantiate(cfg.agent)
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,63 @@
import os
import tempfile
import unittest
from pathlib import Path
from unittest import mock
from roboimi.assets.robots.diana_med import BiDianaMed
class _FakeKDL:
init_calls = []
reset_calls = []
def __init__(self, urdf_path):
self.__class__.init_calls.append(urdf_path)
def resetChain(self, base, end):
self.__class__.reset_calls.append((base, end))
class RobotAssetPathResolutionTest(unittest.TestCase):
    """Robot XML/URDF asset paths must resolve independently of the process CWD."""

    def setUp(self):
        # Reset the class-level call logs shared by all _FakeKDL instances.
        _FakeKDL.init_calls = []
        _FakeKDL.reset_calls = []

    def test_bidianamed_resolves_robot_asset_paths_independent_of_cwd(self):
        repo_root = Path(__file__).resolve().parents[1]
        expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_transfer_ee.xml'
        expected_urdf = repo_root / 'roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf'
        xml_calls = []

        def fake_from_xml_path(*, filename, assets=None):
            # Capture every model-load request instead of parsing real XML.
            xml_calls.append((filename, assets))
            return object()

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                # Run from an unrelated directory to prove paths are not CWD-relative.
                os.chdir(tempdir)
                with mock.patch(
                    'roboimi.assets.robots.arm_base.mujoco.MjModel.from_xml_path',
                    side_effect=fake_from_xml_path,
                ), mock.patch(
                    'roboimi.assets.robots.arm_base.mujoco.MjData',
                    return_value=object(),
                ), mock.patch(
                    'roboimi.assets.robots.arm_base.KDL_utils',
                    _FakeKDL,
                ):
                    BiDianaMed()
            finally:
                os.chdir(previous_cwd)
        self.assertEqual(len(xml_calls), 1)
        self.assertEqual(Path(xml_calls[0][0]), expected_xml)
        self.assertTrue(Path(xml_calls[0][0]).is_absolute())
        # Both arms build KDL chains, each from the same absolute URDF path.
        self.assertGreaterEqual(len(_FakeKDL.init_calls), 2)
        self.assertEqual({Path(path) for path in _FakeKDL.init_calls}, {expected_urdf})
        self.assertTrue(all(Path(path).is_absolute() for path in _FakeKDL.init_calls))
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,58 @@
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock
import h5py
import numpy as np
from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset
class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
    """SimpleRobotDataset should resize only the obs-horizon image frames."""

    def _write_episode(self, dataset_dir: Path) -> None:
        # One 4-step episode with action, qpos, task label and a front camera stream.
        episode_path = dataset_dir / "episode_0.hdf5"
        with h5py.File(episode_path, "w") as root:
            root.create_dataset("action", data=np.arange(8, dtype=np.float32).reshape(4, 2))
            root.create_dataset(
                "observations/qpos",
                data=np.arange(16, dtype=np.float32).reshape(4, 4),
            )
            root.create_dataset("task", data=np.array([b"sim_transfer"]))
            root.create_dataset(
                "observations/images/front",
                data=np.arange(4 * 8 * 8 * 3, dtype=np.uint8).reshape(4, 8, 8, 3),
            )

    def test_getitem_only_resizes_observation_horizon_images(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            self._write_episode(dataset_dir)
            dataset = SimpleRobotDataset(
                dataset_dir,
                obs_horizon=2,
                pred_horizon=3,
                camera_names=["front"],
            )
            resize_calls = []

            def fake_resize(image, size, interpolation=None):
                # Record each resize request; return the image unchanged.
                resize_calls.append(
                    {
                        "shape": tuple(image.shape),
                        "size": size,
                        "interpolation": interpolation,
                    }
                )
                return image

            # Inject a fake cv2 so the dataset's lazy `import cv2` hits the stub.
            fake_cv2 = types.SimpleNamespace(INTER_LINEAR=1, resize=fake_resize)
            with mock.patch.dict(sys.modules, {"cv2": fake_cv2}):
                sample = dataset[1]
            # Exactly obs_horizon (=2) frames resized; output is (T, C, H, W).
            self.assertEqual(len(resize_calls), 2)
            self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))

View File

@@ -0,0 +1,779 @@
import os
import tempfile
import unittest
from copy import deepcopy
from pathlib import Path
from unittest import mock
import numpy as np
import torch
from omegaconf import OmegaConf
from torch import nn
from roboimi.demos.vla_scripts import eval_vla, train_vla
class _FakeDataset:
def __len__(self):
return 4
class _FakeLoader:
def __init__(self, batch, length=1):
self._batches = [batch] * length
def __len__(self):
return len(self._batches)
def __iter__(self):
return iter(self._batches)
class _FakeOptimizer:
def __init__(self, lr=1e-3):
self.param_groups = [{'lr': lr}]
def zero_grad(self):
return None
def step(self):
return None
def state_dict(self):
return {}
def load_state_dict(self, state_dict):
del state_dict
return None
class _FakeScheduler:
def __init__(self):
self.step_calls = 0
def step(self):
self.step_calls += 1
def state_dict(self):
return {}
def load_state_dict(self, state_dict):
del state_dict
return None
class _FakeProgressBar:
def __init__(self, iterable):
self._items = list(iterable)
self.postfix_calls = []
def __iter__(self):
return iter(self._items)
def set_postfix(self, values):
self.postfix_calls.append(values)
class _FakeAgent(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.tensor(0.0))
def to(self, device):
del device
return self
def compute_loss(self, agent_input):
del agent_input
return (self.weight - torch.tensor(0.5)).pow(2)
def get_normalization_stats(self):
return {}
class _SequentialLossAgent(nn.Module):
def __init__(self, losses):
super().__init__()
self.weight = nn.Parameter(torch.tensor(0.0))
self._losses = list(losses)
self._index = 0
def to(self, device):
del device
return self
def compute_loss(self, agent_input):
del agent_input
loss_value = self._losses[self._index]
self._index += 1
return (self.weight * 0) + torch.tensor(float(loss_value))
def get_normalization_stats(self):
return {}
class _FakeEvalAgent:
def __init__(self):
self.reset_calls = 0
def eval(self):
return self
def to(self, device):
del device
return self
def reset(self):
self.reset_calls += 1
def select_action(self, observation):
del observation
return torch.zeros(2)
class _FakeEvalEnv:
def reset(self, box_pos):
self.box_pos = box_pos
def _get_image_obs(self):
return {
'images': {
'front': np.zeros((8, 8, 3), dtype=np.uint8),
}
}
def _get_qpos_obs(self):
return {'qpos': np.zeros(4, dtype=np.float32)}
def render(self):
raise AssertionError('render should not be called in this helper delegation test')
class TrainVLARolloutValidationTest(unittest.TestCase):
def test_default_train_config_uses_full_dataset_and_epoch_rollout_validation(self):
cfg = OmegaConf.load(Path('roboimi/vla/conf/config.yaml'))
self.assertEqual(cfg.train.val_split, 0.0)
self.assertGreater(cfg.train.batch_size, 8)
self.assertGreater(float(cfg.train.lr), 5e-5)
self.assertGreater(cfg.train.num_workers, 8)
self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)
def test_eval_main_delegates_to_plain_run_eval_helper(self):
cfg = OmegaConf.create(
{
'agent': {},
'eval': {
'ckpt_path': 'checkpoints/vla_model_step_1.pt',
'num_episodes': 1,
'max_timesteps': 1,
'device': 'cpu',
'task_name': 'sim_transfer',
'camera_names': ['front'],
'use_smoothing': False,
'smooth_alpha': 0.3,
'verbose_action': False,
'headless': True,
},
}
)
run_eval_mock = mock.Mock()
with mock.patch.object(eval_vla, '_run_eval', run_eval_mock, create=True), \
mock.patch.object(eval_vla, 'load_checkpoint', return_value=(_FakeEvalAgent(), None)), \
mock.patch.object(eval_vla, 'make_sim_env', return_value=_FakeEvalEnv()), \
mock.patch.object(eval_vla, 'sample_transfer_pose', return_value=np.zeros(3)), \
mock.patch.object(eval_vla, 'execute_policy_action'), \
mock.patch.object(eval_vla, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
eval_vla.main.__wrapped__(cfg)
run_eval_mock.assert_called_once_with(cfg)
def test_run_training_rollout_validation_runs_every_50_epochs_and_uses_avg_reward_metric(self):
cfg = OmegaConf.create(
{
'train': {
'device': 'cpu',
'batch_size': 1,
'num_workers': 0,
'val_split': 0.0,
'seed': 0,
'lr': 1e-3,
'max_steps': 100,
'log_freq': 1,
'save_freq': 1000,
'warmup_steps': 1,
'scheduler_type': 'constant',
'min_lr': 0.0,
'grad_clip': 1.0,
'weight_decay': 0.0,
'pretrained_ckpt': None,
'resume_ckpt': None,
'use_swanlab': False,
'rollout_val_freq_epochs': 50,
'rollout_num_episodes': 3,
},
'data': {
'camera_names': ['front'],
},
'agent': {
'_target_': 'fake.agent',
},
'eval': {
'ckpt_path': 'unused.pt',
'num_episodes': 99,
'max_timesteps': 1,
'device': 'cpu',
'task_name': 'sim_transfer',
'camera_names': ['front'],
'use_smoothing': False,
'smooth_alpha': 0.3,
'verbose_action': False,
'headless': False,
},
}
)
agent = _FakeAgent()
rollout_mock = mock.Mock(side_effect=[{'avg_reward': 2.0}, {'avg_reward': 1.0}])
swanlab_log_mock = mock.Mock()
saved_checkpoints = []
def fake_instantiate(config_node, **_kwargs):
if config_node is cfg.data:
return _FakeDataset()
if config_node is cfg.agent:
return agent
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
def fake_dataloader(_dataset, *, shuffle, **_kwargs):
del shuffle, _kwargs
return _FakeLoader(
{
'observation.front': torch.zeros(1, 3, 2, 2),
'observation.state': torch.zeros(1, 4),
'action': torch.zeros(1, 2),
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
},
length=1,
)
def fake_torch_save(payload, path):
saved_checkpoints.append((str(path), deepcopy(payload)))
return None
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
mock.patch.object(train_vla, '_log_to_swanlab', swanlab_log_mock), \
mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save), \
mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True), \
mock.patch.object(eval_vla.main, '__wrapped__', side_effect=AssertionError('training hook should call eval_vla._run_eval')):
train_vla._run_training(cfg)
finally:
os.chdir(previous_cwd)
self.assertEqual(rollout_mock.call_count, 2)
first_rollout_cfg = rollout_mock.call_args_list[0].args[0]
second_rollout_cfg = rollout_mock.call_args_list[1].args[0]
self.assertEqual(first_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_49.pt')
self.assertEqual(second_rollout_cfg.eval.ckpt_path, 'checkpoints/vla_model_step_99.pt')
self.assertEqual(first_rollout_cfg.eval.num_episodes, 3)
self.assertTrue(first_rollout_cfg.eval.headless)
self.assertEqual(first_rollout_cfg.eval.device, 'cpu')
self.assertFalse(first_rollout_cfg.eval.verbose_action)
self.assertEqual(cfg.eval.ckpt_path, 'unused.pt')
self.assertEqual(cfg.eval.num_episodes, 99)
self.assertFalse(cfg.eval.headless)
self.assertEqual(cfg.eval.device, 'cpu')
self.assertFalse(cfg.eval.verbose_action)
rollout_reward_logs = [
call.args[1]['rollout/avg_reward']
for call in swanlab_log_mock.call_args_list
if len(call.args) >= 2 and 'rollout/avg_reward' in call.args[1]
]
self.assertEqual(rollout_reward_logs, [2.0, 1.0])
best_model_saves = [
payload for path, payload in saved_checkpoints
if path.endswith('checkpoints/vla_model_best.pt')
]
self.assertEqual(len(best_model_saves), 1)
self.assertEqual(best_model_saves[0]['rollout_avg_reward'], 2.0)
    def test_run_training_keeps_loss_based_best_checkpoint_until_first_rollout_metric_exists(self):
        """Before any rollout runs, the best checkpoint is selected by loss and
        records a None rollout metric."""
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 1,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 5,
                    'log_freq': 1,
                    'save_freq': 2,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    # Rollout cadence is far beyond max_steps, so no rollout
                    # should ever trigger during this run.
                    'rollout_val_freq_epochs': 50,
                    'rollout_num_episodes': 3,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 99,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': False,
                },
            }
        )
        saved_checkpoints = []
        rollout_mock = mock.Mock()

        def fake_instantiate(config_node, **_kwargs):
            # Route hydra instantiation to in-memory fakes.
            if config_node is cfg.data:
                return _FakeDataset()
            if config_node is cfg.agent:
                return _FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_dataloader(_dataset, *, shuffle, **_kwargs):
            del shuffle, _kwargs
            return _FakeLoader(
                {
                    'observation.front': torch.zeros(1, 3, 2, 2),
                    'observation.state': torch.zeros(1, 4),
                    'action': torch.zeros(1, 2),
                    'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
                },
                length=5,
            )

        def fake_torch_save(payload, path):
            # Record checkpoints in memory instead of touching disk.
            saved_checkpoints.append((str(path), deepcopy(payload)))
            return None

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
                        mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
                        mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
                        mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
                        mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save), \
                        mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True):
                    train_vla._run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        self.assertEqual(rollout_mock.call_count, 0)
        best_model_saves = [
            payload for path, payload in saved_checkpoints
            if path.endswith('checkpoints/vla_model_best.pt')
        ]
        self.assertEqual(len(best_model_saves), 1)
        self.assertIsNone(best_model_saves[0]['rollout_avg_reward'])
    def test_run_training_disables_drop_last_when_train_set_is_smaller_than_batch_size(self):
        """A train split smaller than batch_size must not be dropped entirely."""
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 8,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 1,
                    'log_freq': 1,
                    'save_freq': 10,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 50,
                    'rollout_num_episodes': 3,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 99,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': False,
                },
            }
        )
        dataloader_calls = []

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return _FakeDataset()
            if config_node is cfg.agent:
                return _FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_dataloader(dataset, *, shuffle, drop_last, **_kwargs):
            # Capture the flags the trainer passes for later inspection.
            dataloader_calls.append({
                'shuffle': shuffle,
                'drop_last': drop_last,
                'dataset_len': len(dataset),
            })
            return _FakeLoader(
                {
                    'observation.front': torch.zeros(1, 3, 2, 2),
                    'observation.state': torch.zeros(1, 4),
                    'action': torch.zeros(1, 2),
                    'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
                },
                length=1,
            )

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
                        mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
                        mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
                        mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
                        mock.patch.object(train_vla.torch, 'save', return_value=None):
                    train_vla._run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        # Only the shuffled (training) loader is relevant here.
        train_loader_calls = [call for call in dataloader_calls if call['shuffle']]
        self.assertEqual(len(train_loader_calls), 1)
        self.assertFalse(train_loader_calls[0]['drop_last'])
    def test_run_training_disables_persistent_workers_for_train_and_val_loaders(self):
        """Both loaders inherit num_workers but must keep persistent_workers off."""
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 2,
                    'num_workers': 2,
                    'val_split': 0.25,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 1,
                    'log_freq': 1,
                    'save_freq': 10,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 50,
                    'rollout_num_episodes': 3,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 99,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': False,
                },
            }
        )
        dataloader_calls = []

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return _FakeDataset()
            if config_node is cfg.agent:
                return _FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_dataloader(_dataset, *, shuffle, persistent_workers, num_workers, **_kwargs):
            # Record the worker configuration for both loaders.
            dataloader_calls.append({
                'shuffle': shuffle,
                'num_workers': num_workers,
                'persistent_workers': persistent_workers,
            })
            return _FakeLoader(
                {
                    'observation.front': torch.zeros(1, 3, 2, 2),
                    'observation.state': torch.zeros(1, 4),
                    'action': torch.zeros(1, 2),
                    'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
                },
                length=1,
            )

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
                        mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
                        mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
                        mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
                        mock.patch.object(train_vla.torch, 'save', return_value=None):
                    train_vla._run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        # First call is the shuffled train loader, second the val loader.
        self.assertEqual(len(dataloader_calls), 2)
        self.assertEqual([call['shuffle'] for call in dataloader_calls], [True, False])
        self.assertTrue(all(call['num_workers'] == 2 for call in dataloader_calls))
        self.assertTrue(all(call['persistent_workers'] is False for call in dataloader_calls))
    def test_run_training_uses_loss_best_until_first_rollout_then_prefers_rollout_reward(self):
        """Best-checkpoint selection switches from train loss to rollout reward
        as soon as the first rollout metric becomes available."""
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 1,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 6,
                    'log_freq': 1,
                    'save_freq': 1,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 2,
                    'rollout_num_episodes': 1,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 99,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': False,
                },
            }
        )
        # Strictly decreasing losses so the loss-based best updates every step.
        agent = _SequentialLossAgent([10, 9, 8, 7, 6, 5])
        rollout_mock = mock.Mock(return_value={'avg_reward': 1.0})
        saved_checkpoints = []

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return _FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_dataloader(_dataset, *, shuffle, **_kwargs):
            del _kwargs
            return _FakeLoader(
                {
                    'observation.front': torch.zeros(1, 3, 2, 2),
                    'observation.state': torch.zeros(1, 4),
                    'action': torch.zeros(1, 2),
                    'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
                },
                # Two train batches per epoch, a single validation batch.
                length=2 if shuffle else 1,
            )

        def fake_torch_save(payload, path):
            saved_checkpoints.append((str(path), deepcopy(payload)))
            return None

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
                        mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
                        mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
                        mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
                        mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save), \
                        mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True):
                    train_vla._run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        best_model_saves = [
            (payload['step'], payload['rollout_avg_reward'])
            for path, payload in saved_checkpoints
            if path.endswith('checkpoints/vla_model_best.pt')
        ]
        self.assertEqual(
            best_model_saves,
            [
                (1, None),
                (2, None),
                (3, None),
                (3, 1.0),
            ],
        )
        self.assertEqual(rollout_mock.call_count, 1)
    def test_run_training_keeps_tiny_train_dataset_batch_when_batch_size_is_larger(self):
        """drop_last must stay off for a 1-sample dataset so training still sees a batch."""
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 8,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 1,
                    'log_freq': 1,
                    'save_freq': 1000,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 0,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
            }
        )
        agent = _FakeAgent()
        dataloader_calls = []
        saved_checkpoints = []

        class _TinyDataset:
            # Deliberately smaller than batch_size.
            def __len__(self):
                return 1

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return _TinyDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_dataloader(dataset, *, drop_last, shuffle, **_kwargs):
            del _kwargs
            dataloader_calls.append(
                {
                    'shuffle': shuffle,
                    'drop_last': drop_last,
                    'dataset_len': len(dataset),
                }
            )
            # Mirror real DataLoader semantics: drop_last on a too-small
            # dataset yields zero batches.
            loader_length = 0 if drop_last and len(dataset) < cfg.train.batch_size else 1
            return _FakeLoader(
                {
                    'observation.front': torch.zeros(1, 3, 2, 2),
                    'observation.state': torch.zeros(1, 4),
                    'action': torch.zeros(1, 2),
                    'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
                },
                length=loader_length,
            )

        def fake_torch_save(payload, path):
            saved_checkpoints.append((str(path), deepcopy(payload)))
            return None

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
                        mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
                        mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
                        mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
                        mock.patch.object(train_vla.torch, 'save', side_effect=fake_torch_save):
                    train_vla._run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        self.assertEqual(
            dataloader_calls[0],
            {
                'shuffle': True,
                'drop_last': False,
                'dataset_len': 1,
            },
        )
        self.assertEqual(
            [path for path, _payload in saved_checkpoints],
            ['checkpoints/vla_model_final.pt'],
        )
# Allow running this test module directly.
if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,699 @@
import importlib
import importlib.util
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock
import torch
from torch import nn
# Resolve paths relative to the repository root so the suite can locate the
# training script and the shipped hydra config without installing the package.
_REPO_ROOT = Path(__file__).resolve().parents[1]
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'
_CONFIG_PATH = _REPO_ROOT / 'roboimi/vla/conf/config.yaml'
class AttrDict(dict):
    """Dictionary whose keys are also readable and writable as attributes."""

    def __getattr__(self, name):
        try:
            value = self[name]
        except KeyError as missing:
            raise AttributeError(name) from missing
        return value

    def __setattr__(self, name, value):
        self[name] = value
def _to_attrdict(value):
if isinstance(value, dict):
return AttrDict({key: _to_attrdict(item) for key, item in value.items()})
if isinstance(value, list):
return [_to_attrdict(item) for item in value]
return value
class FakeDataset:
    """Dataset stand-in reporting a fixed number of samples."""

    _NUM_SAMPLES = 4

    def __len__(self):
        return self._NUM_SAMPLES
class FakeLoader:
    """Loader stub: length one, always yields the single stored batch."""

    def __init__(self, batch):
        self.batch = batch

    def __len__(self):
        return 1

    def __iter__(self):
        yield self.batch
class FakeScheduler:
    """LR-scheduler stub that only counts how often step() is invoked."""

    def __init__(self):
        # Number of step() calls made so far.
        self.step_calls = 0

    def step(self):
        self.step_calls = self.step_calls + 1

    def load_state_dict(self, state_dict):
        return None

    def state_dict(self):
        return {}
class FakeOptimizer:
    """Optimizer stub exposing only the surface the training loop touches.

    Remembers the last state dict passed to load_state_dict so tests can
    verify resume behaviour.
    """

    def __init__(self, lr=1e-3):
        self.param_groups = [{'lr': lr}]
        self.loaded_state_dict = None

    def step(self):
        pass

    def zero_grad(self):
        pass

    def load_state_dict(self, state_dict):
        self.loaded_state_dict = state_dict

    def state_dict(self):
        return {}
class FakeProgressBar:
    """tqdm stand-in: snapshots the iterable and records postfix updates."""

    def __init__(self, iterable):
        self._snapshot = list(iterable)
        self.postfix_calls = []

    def __iter__(self):
        yield from self._snapshot

    def set_postfix(self, values):
        self.postfix_calls.append(values)
class FakeAgent(nn.Module):
    """Tiny torch agent whose loss value depends on train/eval mode."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.tensor(0.0))

    def to(self, device):
        # Ignore device moves; the stub stays wherever it is.
        return self

    def compute_loss(self, agent_input):
        del agent_input
        # Distinct targets let tests tell training passes from eval passes.
        target = 0.25 if self.training else 0.1
        return (self.weight - torch.tensor(target)) ** 2

    def get_normalization_stats(self):
        return {}
class FakeSwanLab:
    """In-memory swanlab replacement recording init/log/finish activity.

    Errors can be injected: ``init_error`` raises on init, ``log_errors``
    raise one-by-one on successive log() calls, and ``finish_error`` raises
    on finish(). Every call is recorded before any injected error fires.
    """

    def __init__(self, init_error=None, log_errors=None, finish_error=None):
        self.init_error = init_error
        self.log_errors = list(log_errors or [])
        self.finish_error = finish_error
        self.init_calls = []
        self.log_calls = []
        self.finish_calls = 0

    def init(self, project, experiment_name=None, config=None):
        record = {
            'project': project,
            'experiment_name': experiment_name,
            'config': config,
        }
        self.init_calls.append(record)
        if self.init_error is None:
            return object()
        raise self.init_error

    def log(self, payload, step=None):
        self.log_calls.append((dict(payload), step))
        if not self.log_errors:
            return
        raise self.log_errors.pop(0)

    def finish(self):
        self.finish_calls += 1
        if self.finish_error is not None:
            raise self.finish_error
class TrainVLASwanLabLoggingTest(unittest.TestCase):
    """Tests for the SwanLab logging hooks of the standalone train_vla script."""

    def test_default_config_keeps_swanlab_opt_in(self):
        """The shipped YAML config must leave SwanLab disabled by default."""
        config_text = _CONFIG_PATH.read_text(encoding='utf-8')
        self.assertIn('use_swanlab: false', config_text)
    def _load_train_vla_module(self):
        """Load train_vla.py from its file path with hydra/omegaconf stubbed.

        Returns a freshly executed module object so each test patches its own
        isolated copy rather than a shared sys.modules entry.
        """
        hydra_module = types.ModuleType('hydra')
        hydra_utils_module = types.ModuleType('hydra.utils')
        hydra_utils_module.instantiate = lambda *args, **kwargs: None

        def hydra_main(**_kwargs):
            # Stand-in for hydra.main: a decorator factory that is a no-op.
            def decorator(func):
                return func
            return decorator

        hydra_module.main = hydra_main
        hydra_module.utils = hydra_utils_module

        class OmegaConfStub:
            # Minimal OmegaConf surface used by the training script.
            _resolvers = {}

            @classmethod
            def has_resolver(cls, name):
                return name in cls._resolvers

            @classmethod
            def register_new_resolver(cls, name, resolver):
                cls._resolvers[name] = resolver

            @staticmethod
            def to_yaml(_cfg):
                return 'stub-config'

            @staticmethod
            def to_container(cfg, resolve=False):
                del resolve
                return dict(cfg)

            @staticmethod
            def create(cfg):
                return _to_attrdict(cfg)

        omegaconf_module = types.ModuleType('omegaconf')
        omegaconf_module.DictConfig = dict
        omegaconf_module.OmegaConf = OmegaConfStub
        module_name = 'train_vla_swanlab_test_module'
        spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
        module = importlib.util.module_from_spec(spec)
        # Make the stubs importable only for the duration of module execution.
        with mock.patch.dict(
            sys.modules,
            {
                'hydra': hydra_module,
                'hydra.utils': hydra_utils_module,
                'omegaconf': omegaconf_module,
            },
        ):
            assert spec.loader is not None
            spec.loader.exec_module(module)
        return module
    def _make_cfg(self, *, use_swanlab=True, swanlab_run_name='smoke-run'):
        """Build a minimal AttrDict training config for the stubbed script."""
        return AttrDict(
            train=AttrDict(
                device='cpu',
                batch_size=2,
                num_workers=0,
                val_split=0.25,
                seed=0,
                lr=1e-3,
                max_steps=2,
                log_freq=1,
                save_freq=1,
                warmup_steps=1,
                scheduler_type='constant',
                min_lr=0.0,
                grad_clip=1.0,
                weight_decay=0.0,
                pretrained_ckpt=None,
                resume_ckpt=None,
                use_swanlab=use_swanlab,
                swanlab_project='roboimi-vla-tests',
                swanlab_run_name=swanlab_run_name,
            ),
            data=AttrDict(
                camera_names=('front',),
            ),
            agent=AttrDict(
                _target_='fake.agent',
            ),
            eval=AttrDict(
                ckpt_path='unused.pt',
                num_episodes=1,
                max_timesteps=1,
                device='cpu',
                task_name='sim_transfer',
                camera_names=('front',),
                use_smoothing=False,
                smooth_alpha=0.3,
                verbose_action=False,
                headless=False,
            ),
        )
    def _get_run_training(self, module):
        """Fetch the script's _run_training helper, failing with a clear message."""
        run_training = getattr(module, '_run_training', None)
        self.assertIsNotNone(run_training, 'Expected train_vla.py to expose a _run_training(cfg) helper')
        return run_training
    def _make_batch(self):
        """One minimal batch with the keys the training loop consumes."""
        return {
            'observation.front': torch.zeros(1, 3, 2, 2),
            'observation.state': torch.zeros(1, 4),
            'action': torch.zeros(1, 2),
            'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
        }
    def _loader_factory(self):
        """Return a DataLoader substitute serving fixed train/val batches."""
        train_batch = self._make_batch()
        val_batch = self._make_batch()

        def factory(_dataset, *, shuffle, **_kwargs):
            # shuffle=True identifies the training loader in the script.
            return FakeLoader(train_batch if shuffle else val_batch)

        return factory
    def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
        """Happy path: init config, per-step metrics and final checkpoint paths
        all reach the SwanLab backend, followed by exactly one finish()."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        agent = FakeAgent()
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            # Only intercept the swanlab import; everything else is real.
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        self.assertEqual(
            fake_swanlab.init_calls,
            [{
                'project': 'roboimi-vla-tests',
                'experiment_name': 'smoke-run',
                'config': {
                    'train': {
                        'device': 'cpu',
                        'batch_size': 2,
                        'num_workers': 0,
                        'val_split': 0.25,
                        'seed': 0,
                        'lr': 1e-3,
                        'max_steps': 2,
                        'log_freq': 1,
                        'save_freq': 1,
                        'warmup_steps': 1,
                        'scheduler_type': 'constant',
                        'min_lr': 0.0,
                        'grad_clip': 1.0,
                        'weight_decay': 0.0,
                        'pretrained_ckpt': None,
                        'resume_ckpt': None,
                        'use_swanlab': True,
                        'swanlab_project': 'roboimi-vla-tests',
                        'swanlab_run_name': 'smoke-run',
                    },
                    'data': {
                        'camera_names': ('front',),
                    },
                    'agent': {
                        '_target_': 'fake.agent',
                    },
                },
            }],
        )
        # Union of every metric key logged over the run.
        logged_keys = set().union(*(payload.keys() for payload, _step in fake_swanlab.log_calls))
        self.assertTrue(
            {
                'train/loss',
                'train/lr',
                'train/best_loss',
                'train/step',
                'val/loss',
                'final/checkpoint_path',
                'final/best_checkpoint_path',
            }.issubset(logged_keys)
        )
        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/checkpoint_path'], 'checkpoints/vla_model_final.pt')
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
        self.assertEqual(fake_swanlab.finish_calls, 1)
    def test_run_training_skips_swanlab_when_disabled(self):
        """With use_swanlab=False the script must never even import swanlab."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg(use_swanlab=False)
        agent = FakeAgent()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                # import_module raising proves swanlab is never imported.
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=AssertionError('swanlab import should not run')):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)
    def test_run_training_finishes_swanlab_when_exception_happens_after_init(self):
        """A crash after swanlab.init must still close the run via finish()."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                # Dataset instantiation blows up after SwanLab was initialised.
                with mock.patch.object(module, 'instantiate', side_effect=RuntimeError('dataset boom')), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    with self.assertRaisesRegex(RuntimeError, 'dataset boom'):
                        run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        self.assertEqual(fake_swanlab.finish_calls, 1)
    def test_run_training_warns_and_continues_when_swanlab_log_and_finish_fail(self):
        """Backend hiccups in log()/finish() only warn; training still completes."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        agent = FakeAgent()
        fake_swanlab = FakeSwanLab(
            log_errors=[RuntimeError('log backend hiccup')],
            finish_error=RuntimeError('finish backend hiccup'),
        )
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
                        mock.patch.object(module.log, 'warning') as warning_mock:
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        warning_messages = [call.args[0] for call in warning_mock.call_args_list]
        self.assertTrue(any('SwanLab log failed' in message for message in warning_messages))
        self.assertTrue(any('SwanLab finish failed' in message for message in warning_messages))
        self.assertEqual(fake_swanlab.finish_calls, 1)
    def test_run_training_resume_restores_best_rollout_baseline_from_best_checkpoint(self):
        """Resuming restores the rollout-reward baseline stored in the best
        checkpoint, so a worse rollout (3.0 < 5.0) must not overwrite it."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 2
        cfg.train.save_freq = 1
        cfg.train.rollout_validate_on_checkpoint = True
        fake_swanlab = FakeSwanLab()
        fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
        fake_scheduler = FakeScheduler()
        real_import_module = importlib.import_module
        saved_paths = []

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                resume_path = checkpoint_dir / 'vla_model_step_0.pt'
                resume_path.write_bytes(b'resume')
                best_path = checkpoint_dir / 'vla_model_best.pt'
                best_path.write_bytes(b'best')
                cfg.train.resume_ckpt = str(resume_path)
                resume_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                }
                best_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                    # Baseline reward the resumed run has to beat.
                    'rollout_avg_reward': 5.0,
                }

                def fake_torch_load(path, map_location=None):
                    del map_location
                    path = Path(path)
                    if path == resume_path:
                        return resume_checkpoint_state
                    if path == best_path:
                        return best_checkpoint_state
                    raise AssertionError(f'unexpected load path: {path}')

                def fake_torch_save(payload, path):
                    saved_paths.append(str(path))
                    return None

                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', side_effect=fake_torch_save), \
                        mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module), \
                        mock.patch('roboimi.demos.vla_scripts.eval_vla._run_eval', return_value={'avg_reward': 3.0}):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_best.pt')
        # The pre-existing best checkpoint file was never rewritten.
        self.assertNotIn('checkpoints/vla_model_best.pt', saved_paths)
    def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
        """A best checkpoint lacking 'rollout_avg_reward' provides no baseline,
        so the resumed run reports its own best checkpoint instead."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 1
        fake_swanlab = FakeSwanLab()
        fake_optimizer = FakeOptimizer(lr=cfg.train.lr)
        fake_scheduler = FakeScheduler()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                resume_path = checkpoint_dir / 'vla_model_step_0.pt'
                resume_path.write_bytes(b'resume')
                best_path = checkpoint_dir / 'vla_model_best.pt'
                best_path.write_bytes(b'stale')
                cfg.train.resume_ckpt = str(resume_path)
                resume_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.5,
                    'val_loss': 0.25,
                }
                # Deliberately missing the 'rollout_avg_reward' key.
                stale_best_checkpoint_state = {
                    'step': 0,
                    'model_state_dict': FakeAgent().state_dict(),
                    'optimizer_state_dict': {},
                    'scheduler_state_dict': {},
                    'loss': 0.4,
                    'val_loss': 0.2,
                }

                def fake_torch_load(path, map_location=None):
                    del map_location
                    path = Path(path)
                    if path == resume_path:
                        return resume_checkpoint_state
                    if path == best_path:
                        return stale_best_checkpoint_state
                    raise AssertionError(f'unexpected load path: {path}')

                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'build_training_optimizer', return_value=fake_optimizer), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=fake_scheduler), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.torch, 'load', side_effect=fake_torch_load), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/best_checkpoint_path'], 'checkpoints/vla_model_step_0.pt')
    def test_run_training_ignores_stale_best_checkpoint_file_on_fresh_non_resume_run(self):
        """A leftover vla_model_best.pt from an older run must not be reported
        as this run's best checkpoint."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 1
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                checkpoint_dir = Path('checkpoints')
                checkpoint_dir.mkdir()
                (checkpoint_dir / 'vla_model_best.pt').write_bytes(b'stale-best')
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)
        final_payload, final_step = fake_swanlab.log_calls[-1]
        self.assertEqual(final_step, cfg.train.max_steps)
        self.assertEqual(final_payload['final/best_checkpoint_path'], '')
    def test_run_training_fails_fast_when_swanlab_import_is_unavailable(self):
        """A missing swanlab package aborts before any training work starts."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                raise ImportError('missing swanlab')
            return real_import_module(name, package)

        # instantiate raising proves the script never got past swanlab import.
        with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
                mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
            with self.assertRaisesRegex(RuntimeError, 'SwanLab'):
                run_training(cfg)
    def test_run_training_fails_fast_when_swanlab_init_fails(self):
        """swanlab.init errors propagate and finish() is never attempted."""
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        fake_swanlab = FakeSwanLab(init_error=RuntimeError('not logged in'))
        real_import_module = importlib.import_module

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with mock.patch.object(module, 'instantiate', side_effect=AssertionError('instantiate should not run')), \
                mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
            with self.assertRaisesRegex(RuntimeError, 'not logged in'):
                run_training(cfg)
        self.assertEqual(fake_swanlab.finish_calls, 0)
# Allow running this test module directly (python <file>) as well as via a
# test runner.
if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,310 @@
import importlib.util
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock
import torch
from torch import nn
# Repository root, resolved relative to this test file's location.
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Source path of the training script under test; it is loaded straight from
# file so the tests work without installing the package.
_TRAIN_VLA_PATH = _REPO_ROOT / 'roboimi/demos/vla_scripts/train_vla.py'
class AttrDict(dict):
    """A ``dict`` whose keys are also readable/writable as attributes.

    Used as a lightweight stand-in for a Hydra/OmegaConf config node.
    """

    def __getattr__(self, name):
        try:
            value = self[name]
        except KeyError as err:
            # The attribute protocol requires AttributeError, not KeyError.
            raise AttributeError(name) from err
        return value

    def __setattr__(self, name, value):
        # Attribute writes go straight into the mapping.
        self[name] = value
class FakeDataset:
    """Dataset stub: only a nonzero length is needed for train-loop setup."""

    _NUM_SAMPLES = 4  # arbitrary small size; never actually indexed

    def __len__(self):
        return self._NUM_SAMPLES
class FakeLoader:
    """DataLoader stand-in that reports one batch but yields nothing."""

    def __len__(self):
        return 1

    def __iter__(self):
        # Empty iteration: the optimizer-setup tests never consume batches.
        yield from ()
class FakeScheduler:
    """LR-scheduler stub exposing a checkpointable (empty) state."""

    def state_dict(self):
        # Nothing worth saving in the fake.
        return dict()

    def load_state_dict(self, state_dict):
        # Accept and ignore restored state.
        pass
class RecordingAdamW:
    """Test double for ``torch.optim.AdamW`` that records every construction.

    Each instance is appended to the class-level ``created`` list so tests can
    inspect the parameter groups the training script built.  Param groups are
    normalized the way the real optimizer normalizes them: groups that omit
    ``lr`` or ``weight_decay`` inherit the constructor defaults.
    """

    created = []  # every instance ever constructed, in creation order

    def __init__(self, params, lr, weight_decay):
        self.lr = lr
        self.weight_decay = weight_decay
        self.param_groups = self._normalize_param_groups(params, lr, weight_decay)
        RecordingAdamW.created.append(self)

    @staticmethod
    def _normalize_param_groups(params, lr, weight_decay):
        """Return ``params`` as a list of fully-populated group dicts.

        ``params`` is either an iterable of parameters or a list of group
        dicts — the two forms torch optimizers accept.
        """
        if isinstance(params, (list, tuple)) and params and isinstance(params[0], dict):
            groups = []
            for group in params:
                normalized = dict(group)
                normalized['params'] = list(group['params'])
                # Mirror torch.optim: per-group values win, constructor
                # arguments fill in anything the group left unspecified.
                # (Previously only 'lr' was backfilled, so groups without an
                # explicit 'weight_decay' lacked the key entirely.)
                normalized.setdefault('lr', lr)
                normalized.setdefault('weight_decay', weight_decay)
                groups.append(normalized)
            return groups
        # Plain parameter iterable: a single group carrying the defaults.
        return [{
            'params': list(params),
            'lr': lr,
            'weight_decay': weight_decay,
        }]

    def state_dict(self):
        # Checkpointing hook; nothing worth saving in the fake.
        return {}

    def load_state_dict(self, state_dict):
        # Accept and ignore restored state.
        return None
class RecordingTransformerHead(nn.Module):
    """Stand-in for a transformer diffusion head exposing ``get_optim_groups``.

    Records every weight-decay value it was asked to build groups for, so
    tests can assert that the training script delegated optimizer grouping
    to the head exactly once.
    """

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)
        self.optim_group_calls = []

    def get_optim_groups(self, weight_decay):
        """Split parameters into a decayed and an undecayed group."""
        self.optim_group_calls.append(weight_decay)
        decayed = [self.proj.weight]
        undecayed = [self.proj.bias, self.norm.weight, self.norm.bias]
        return [
            {'params': decayed, 'weight_decay': weight_decay},
            {'params': undecayed, 'weight_decay': 0.0},
        ]
class FakeTransformerAgent(nn.Module):
    """Agent stub with a transformer head, extra trainable layers, and one
    deliberately frozen layer that must stay out of every optimizer group."""

    def __init__(self):
        super().__init__()
        self.head_type = 'transformer'
        self.noise_pred_net = RecordingTransformerHead()
        self.backbone = nn.Linear(4, 3)
        self.adapter = nn.Linear(3, 2, bias=False)
        self.frozen = nn.Linear(2, 2)
        # Freeze every parameter of the frozen layer.
        self.frozen.requires_grad_(False)

    def to(self, device):
        # Device moves are a no-op for the fake.
        return self

    def get_normalization_stats(self):
        # The training script persists these with checkpoints; empty is fine.
        return {}
class TrainVLATransformerOptimizerTest(unittest.TestCase):
    """Checks that train_vla builds the optimizer for transformer heads from
    the head's own ``get_optim_groups`` and appends the remaining trainable
    parameters as a separate decayed group."""

    def setUp(self):
        # Reset the shared recording list so each test inspects only its own
        # optimizer constructions.
        RecordingAdamW.created = []

    def _load_train_vla_module(self):
        """Load train_vla.py from source with stubbed hydra/omegaconf modules.

        The stubs keep the import side-effect free: ``hydra.main`` becomes a
        pass-through decorator and ``OmegaConf`` supports only the resolver
        registration and ``to_yaml`` calls the script makes.
        """
        hydra_module = types.ModuleType('hydra')
        hydra_utils_module = types.ModuleType('hydra.utils')
        hydra_utils_module.instantiate = lambda *args, **kwargs: None
        def hydra_main(**_kwargs):
            # Pass-through replacement for @hydra.main(...): returns the
            # function undecorated.
            def decorator(func):
                return func
            return decorator
        hydra_module.main = hydra_main
        hydra_module.utils = hydra_utils_module
        class OmegaConfStub:
            # Minimal OmegaConf surface used by the script at import time.
            _resolvers = {}
            @classmethod
            def has_resolver(cls, name):
                return name in cls._resolvers
            @classmethod
            def register_new_resolver(cls, name, resolver):
                cls._resolvers[name] = resolver
            @staticmethod
            def to_yaml(_cfg):
                return 'stub-config'
        omegaconf_module = types.ModuleType('omegaconf')
        omegaconf_module.DictConfig = dict
        omegaconf_module.OmegaConf = OmegaConfStub
        module_name = 'train_vla_optimizer_test_module'
        spec = importlib.util.spec_from_file_location(module_name, _TRAIN_VLA_PATH)
        module = importlib.util.module_from_spec(spec)
        # Patch sys.modules only while executing the script so the stubs
        # never leak into other tests.
        with mock.patch.dict(
            sys.modules,
            {
                'hydra': hydra_module,
                'hydra.utils': hydra_utils_module,
                'omegaconf': omegaconf_module,
            },
        ):
            assert spec.loader is not None
            spec.loader.exec_module(module)
        return module

    def _make_cfg(self):
        """Build the smallest config the script accepts.

        ``max_steps=0`` skips the optimization loop, so only the setup path
        (dataset/agent instantiation and optimizer construction) runs.
        """
        return AttrDict(
            train=AttrDict(
                device='cpu',
                batch_size=2,
                num_workers=0,
                val_split=0,
                seed=0,
                lr=1e-4,
                max_steps=0,
                log_freq=1,
                save_freq=100,
                warmup_steps=1,
                scheduler_type='constant',
                min_lr=0.0,
                grad_clip=1.0,
                weight_decay=0.123,
                pretrained_ckpt=None,
                resume_ckpt=None,
            ),
            data=AttrDict(
                camera_names=('front',),
            ),
            agent=AttrDict(
                _target_='fake.agent',
            ),
        )

    def _group_names(self, agent, optimizer):
        # Map each optimizer group's parameters back to their dotted names on
        # the agent, so group membership can be asserted symbolically.
        names_by_param_id = {id(param): name for name, param in agent.named_parameters()}
        return [
            {names_by_param_id[id(param)] for param in group['params']}
            for group in optimizer.param_groups
        ]

    def test_transformer_training_prefers_head_optim_groups_and_keeps_remaining_trainable_params(self):
        """Head-provided groups come first; remaining trainable (non-head,
        non-frozen) params form a third decayed group with no duplicates."""
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent()
        cfg = self._make_cfg()
        def fake_instantiate(config_node, **_kwargs):
            # Route hydra instantiation to the fakes; anything else is a bug.
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')
        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                # Run inside a scratch cwd: the script writes checkpoints
                # relative to the working directory.
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'AdamW', RecordingAdamW), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
                    module.main(cfg)
            finally:
                os.chdir(previous_cwd)
        # The head's grouping hook must be consulted exactly once, with the
        # configured weight decay.
        self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])
        optimizer = RecordingAdamW.created[-1]
        trainable_names = {
            name for name, param in agent.named_parameters() if param.requires_grad
        }
        grouped_names = self._group_names(agent, optimizer)
        optimizer_names = set().union(*grouped_names)
        expected_head_names = {
            'noise_pred_net.proj.weight',
            'noise_pred_net.proj.bias',
            'noise_pred_net.norm.weight',
            'noise_pred_net.norm.bias',
        }
        expected_non_head_names = {
            'backbone.weight',
            'backbone.bias',
            'adapter.weight',
        }
        # Groups 0/1 come from the head (decayed / undecayed); group 2 holds
        # every remaining trainable parameter with the default decay.
        self.assertEqual(grouped_names[0], {'noise_pred_net.proj.weight'})
        self.assertEqual(grouped_names[1], expected_head_names - {'noise_pred_net.proj.weight'})
        self.assertEqual(grouped_names[2], expected_non_head_names)
        self.assertEqual(optimizer.param_groups[0]['weight_decay'], cfg.train.weight_decay)
        self.assertEqual(optimizer.param_groups[1]['weight_decay'], 0.0)
        self.assertEqual(optimizer.param_groups[2]['weight_decay'], cfg.train.weight_decay)
        self.assertEqual(optimizer_names, trainable_names)
        # No parameter may appear in more than one group.
        flattened_param_ids = [
            id(param)
            for group in optimizer.param_groups
            for param in group['params']
        ]
        self.assertEqual(len(flattened_param_ids), len(set(flattened_param_ids)))
        self.assertNotIn('frozen.weight', optimizer_names)
        self.assertNotIn('frozen.bias', optimizer_names)

    def test_transformer_optimizer_ignores_frozen_head_params_returned_by_head_groups(self):
        """Even when the head's groups include a frozen parameter, the
        optimizer must contain exactly the trainable parameters."""
        module = self._load_train_vla_module()
        agent = FakeTransformerAgent()
        # Freeze one head parameter that get_optim_groups would still return.
        agent.noise_pred_net.norm.bias.requires_grad = False
        cfg = self._make_cfg()
        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')
        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=lambda *args, **kwargs: FakeLoader()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'AdamW', RecordingAdamW), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: iterable):
                    module.main(cfg)
            finally:
                os.chdir(previous_cwd)
        optimizer = RecordingAdamW.created[-1]
        optimizer_names = set().union(*self._group_names(agent, optimizer))
        trainable_names = {
            name for name, param in agent.named_parameters() if param.requires_grad
        }
        self.assertEqual(agent.noise_pred_net.optim_group_calls, [cfg.train.weight_decay])
        self.assertEqual(optimizer_names, trainable_names)
        self.assertNotIn('noise_pred_net.norm.bias', optimizer_names)
# Allow running this test module directly (python <file>) as well as via a
# test runner.
if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,262 @@
import contextlib
import importlib.util
import inspect
import sys
import types
import unittest
import warnings
from pathlib import Path
import torch
# Repository root, resolved relative to this test file.
_REPO_ROOT = Path(__file__).resolve().parents[1]
# The local transformer head implementation under test.
_LOCAL_MODULE_PATH = _REPO_ROOT / 'roboimi/vla/models/heads/transformer1d.py'
# Optional sibling checkout of the upstream diffusion_policy repo; alignment
# tests skip themselves when it is absent.
_EXTERNAL_CHECKOUT_ROOT = _REPO_ROOT.parent / 'diffusion_policy'
# Regex for the UserWarning torch.nn emits when a TransformerEncoder is built
# from pre-norm (norm_first=True) layers; suppressed during model creation.
_TRANSFORMER_WARNING_MESSAGE = (
    r'enable_nested_tensor is True, but self.use_nested_tensor is False '
    r'because encoder_layer\.norm_first was True'
)
# Sentinel marking "this name was absent from sys.modules" during restore.
_MISSING = object()
def _load_module_from_path(name: str, path: Path, *, register: bool = False):
spec = importlib.util.spec_from_file_location(name, path)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
if register:
sys.modules[name] = module
spec.loader.exec_module(module)
return module
def _resolve_external_module_paths(external_checkout_root: Path):
diffusion_policy_root = external_checkout_root / 'diffusion_policy'
paths = {
'positional_embedding': diffusion_policy_root / 'model/diffusion/positional_embedding.py',
'module_attr_mixin': diffusion_policy_root / 'model/common/module_attr_mixin.py',
'transformer_for_diffusion': diffusion_policy_root / 'model/diffusion/transformer_for_diffusion.py',
}
if not all(path.exists() for path in paths.values()):
return None
return paths
@contextlib.contextmanager
def _temporary_registered_modules():
    """Yield a loader that registers modules (plus stub parent packages) in
    ``sys.modules`` and restores the original ``sys.modules`` entries —
    including absences — on exit."""
    # Snapshot of sys.modules entries we touch; _MISSING marks "was absent".
    previous_modules = {}
    def remember(name: str) -> None:
        # Record a name's prior state only the first time we touch it.
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)
    def ensure_package(name: str) -> None:
        # Create an empty namespace-style stub package so dotted module names
        # can be registered; existing packages are left untouched.
        if not name or name in sys.modules:
            return
        remember(name)
        package = types.ModuleType(name)
        package.__path__ = []
        sys.modules[name] = package
    def load(name: str, path: Path):
        # Make sure every ancestor package exists before registering the leaf.
        package_parts = name.split('.')[:-1]
        for idx in range(1, len(package_parts) + 1):
            ensure_package('.'.join(package_parts[:idx]))
        remember(name)
        return _load_module_from_path(name, path, register=True)
    try:
        yield load
    finally:
        # Restore in reverse insertion order (leaf modules before the stub
        # parents that were created for them).
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
@contextlib.contextmanager
def _suppress_nested_tensor_warning():
    """Silence the nested-tensor UserWarning that torch.nn emits when a
    TransformerEncoder is constructed from pre-norm layers.

    Filter state is restored on exit via ``catch_warnings``.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings(
            'ignore',
            category=UserWarning,
            message=_TRANSFORMER_WARNING_MESSAGE,
            module=r'torch\.nn\.modules\.transformer',
        )
        yield
def _load_local_module():
    """Load the repo-local transformer1d module without registering it in
    ``sys.modules``."""
    return _load_module_from_path(
        'local_transformer1d_alignment',
        _LOCAL_MODULE_PATH,
        register=False,
    )
class Transformer1DExternalAlignmentTest(unittest.TestCase):
    """Compares the local Transformer1D head with the upstream
    diffusion_policy TransformerForDiffusion implementation.

    Tests that need the external checkout skip themselves when it is not
    present next to the repository.
    """

    def _load_transformer_classes_or_skip(self):
        """Return (Transformer1D, create_transformer1d, TransformerForDiffusion),
        skipping the current test when the external checkout is unavailable."""
        external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT)
        if external_paths is None:
            self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}')
        local_module = _load_local_module()
        # Register the external modules temporarily; dependencies must be
        # loaded before transformer_for_diffusion, which imports them.
        with _temporary_registered_modules() as load_external:
            load_external(
                'diffusion_policy.model.diffusion.positional_embedding',
                external_paths['positional_embedding'],
            )
            load_external(
                'diffusion_policy.model.common.module_attr_mixin',
                external_paths['module_attr_mixin'],
            )
            external_module = load_external(
                'diffusion_policy.model.diffusion.transformer_for_diffusion',
                external_paths['transformer_for_diffusion'],
            )
        return local_module.Transformer1D, local_module.create_transformer1d, external_module.TransformerForDiffusion

    def _optim_group_names(self, model, groups):
        # Map optimizer groups back to parameter names so grouping can be
        # compared across two different model instances.
        names_by_param = {id(param): name for name, param in model.named_parameters()}
        return [
            {names_by_param[id(param)] for param in group['params']}
            for group in groups
        ]

    def test_missing_external_checkout_resolution_returns_none(self):
        """Resolution must return None (not raise) for a missing checkout."""
        self.assertIsNone(_resolve_external_module_paths(_REPO_ROOT / '__missing_diffusion_policy_checkout__'))

    def test_external_loader_restores_injected_sys_modules(self):
        """Loading external modules must leave sys.modules exactly as found."""
        external_paths = _resolve_external_module_paths(_EXTERNAL_CHECKOUT_ROOT)
        if external_paths is None:
            self.skipTest(f'external diffusion_policy checkout unavailable under {_EXTERNAL_CHECKOUT_ROOT}')
        # Every name (leaf modules and stub parent packages) the loader may
        # inject.
        watched_names = [
            'diffusion_policy',
            'diffusion_policy.model',
            'diffusion_policy.model.common',
            'diffusion_policy.model.common.module_attr_mixin',
            'diffusion_policy.model.diffusion',
            'diffusion_policy.model.diffusion.positional_embedding',
            'diffusion_policy.model.diffusion.transformer_for_diffusion',
        ]
        before = {name: sys.modules.get(name, _MISSING) for name in watched_names}
        with _temporary_registered_modules() as load_external:
            load_external(
                'diffusion_policy.model.diffusion.positional_embedding',
                external_paths['positional_embedding'],
            )
            load_external(
                'diffusion_policy.model.common.module_attr_mixin',
                external_paths['module_attr_mixin'],
            )
            load_external(
                'diffusion_policy.model.diffusion.transformer_for_diffusion',
                external_paths['transformer_for_diffusion'],
            )
        after = {name: sys.modules.get(name, _MISSING) for name in watched_names}
        self.assertEqual(after, before)

    def test_transformer1d_preserves_local_direct_call_defaults(self):
        """The local constructor/helper defaults must stay at 8/8/256."""
        local_module = _load_local_module()
        ctor = inspect.signature(local_module.Transformer1D.__init__).parameters
        helper = inspect.signature(local_module.create_transformer1d).parameters
        self.assertEqual(ctor['n_layer'].default, 8)
        self.assertEqual(ctor['n_head'].default, 8)
        self.assertEqual(ctor['n_emb'].default, 256)
        self.assertEqual(helper['n_layer'].default, 8)
        self.assertEqual(helper['n_head'].default, 8)
        self.assertEqual(helper['n_emb'].default, 256)

    def test_time_as_cond_false_token_accounting_matches_external(self):
        """With time_as_cond=False, token/condition bookkeeping (T, T_cond,
        encoder_only, ...) must match the external implementation."""
        Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip()
        self.assertIn('time_as_cond', inspect.signature(Transformer1D.__init__).parameters)
        config = dict(
            input_dim=4,
            output_dim=4,
            horizon=6,
            n_obs_steps=3,
            cond_dim=0,
            n_layer=2,
            n_head=2,
            n_emb=8,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=False,
            time_as_cond=False,
            obs_as_cond=False,
            n_cond_layers=0,
        )
        torch.manual_seed(5)
        with _suppress_nested_tensor_warning():
            external_model = TransformerForDiffusion(**config)
            local_model = Transformer1D(**config)
        external_model.eval()
        local_model.eval()
        self.assertEqual(local_model.T, external_model.T)
        self.assertEqual(local_model.T_cond, external_model.T_cond)
        self.assertEqual(local_model.time_as_cond, external_model.time_as_cond)
        self.assertEqual(local_model.obs_as_cond, external_model.obs_as_cond)
        self.assertEqual(local_model.encoder_only, external_model.encoder_only)

    def test_nocausal_state_dict_forward_and_optim_groups_match_external(self):
        """With identical weights the local model must reproduce the external
        model's forward output and optimizer grouping exactly."""
        Transformer1D, _, TransformerForDiffusion = self._load_transformer_classes_or_skip()
        config = dict(
            input_dim=4,
            output_dim=4,
            horizon=6,
            n_obs_steps=3,
            cond_dim=5,
            n_layer=2,
            n_head=2,
            n_emb=8,
            p_drop_emb=0.0,
            p_drop_attn=0.0,
            causal_attn=False,
            obs_as_cond=True,
            n_cond_layers=1,
        )
        torch.manual_seed(7)
        with _suppress_nested_tensor_warning():
            external_model = TransformerForDiffusion(**config)
            local_model = Transformer1D(**config)
        external_model.eval()
        local_model.eval()
        # Identical parameter names allow copying weights across.
        external_state_dict = external_model.state_dict()
        self.assertEqual(set(local_model.state_dict().keys()), set(external_state_dict.keys()))
        local_model.load_state_dict(external_state_dict, strict=True)
        batch_size = 2
        sample = torch.randn(batch_size, config['horizon'], config['input_dim'])
        cond = torch.randn(batch_size, config['n_obs_steps'], config['cond_dim'])
        timestep = torch.tensor([11, 17], dtype=torch.long)
        with torch.no_grad():
            external_out = external_model(sample=sample, timestep=timestep, cond=cond)
            local_out = local_model(sample=sample, timestep=timestep, cond=cond)
        self.assertEqual(local_out.shape, (batch_size, config['horizon'], config['output_dim']))
        self.assertEqual(local_out.shape, external_out.shape)
        self.assertTrue(torch.allclose(local_out, external_out, atol=1e-6, rtol=1e-5))
        # Optimizer grouping (decayed vs undecayed) must also line up.
        weight_decay = 0.123
        external_groups = external_model.get_optim_groups(weight_decay=weight_decay)
        local_groups = local_model.get_optim_groups(weight_decay=weight_decay)
        self.assertEqual(len(local_groups), len(external_groups))
        self.assertEqual([group['weight_decay'] for group in local_groups], [weight_decay, 0.0])
        self.assertEqual(
            self._optim_group_names(local_model, local_groups),
            self._optim_group_names(external_model, external_groups),
        )
# Allow running this test module directly (python <file>) as well as via a
# test runner.
if __name__ == '__main__':
    unittest.main()