"""Hydra wiring tests for the ResNet + Transformer agent.

Optional heavy dependencies (diffusers, torchvision, einops) are replaced with
lightweight fakes so the agent can be instantiated from the Hydra config and
its camera ordering, conditioning dimensions, and fallback behaviour checked
deterministically.
"""
import contextlib
import sys
import types
import unittest
from pathlib import Path

import torch
from hydra import compose, initialize_config_dir
from hydra.errors import InstantiationException
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from omegaconf import OmegaConf

_REPO_ROOT = Path(__file__).resolve().parents[1]
_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve())
_EXPECTED_CAMERA_NAMES = ['r_vis', 'top', 'front']
_MISSING = object()


class _FakeScheduler:
    """Minimal stand-in for the diffusers DDPM/DDIM schedulers."""

    def __init__(self, num_train_timesteps=100, **kwargs):
        self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps)
        self.timesteps = []

    def add_noise(self, sample, noise, timestep):
        return sample + noise

    def set_timesteps(self, num_inference_steps):
        self.timesteps = list(range(num_inference_steps - 1, -1, -1))

    def step(self, noise_pred, timestep, sample):
        return types.SimpleNamespace(prev_sample=sample)


class _IdentityCrop:
    """Crop-transform replacement that returns its input unchanged."""

    def __init__(self, size):
        self.size = size

    def __call__(self, x):
        return x


class _FakeResNet(torch.nn.Module):
    """Tiny convolutional network used in place of torchvision's resnet18."""

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu1 = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2)
        self.relu2 = torch.nn.ReLU()
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.fc = torch.nn.Linear(16, 16)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)


class _FakeRearrange(torch.nn.Module):
    """No-op replacement for einops.layers.torch.Rearrange."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        return x


class _CondCapturingHead(torch.nn.Module):
    """Noise-prediction head stub that records the conditioning tensor it receives."""

    def __init__(self):
        super().__init__()
        self.last_cond = None

    def forward(self, sample, timestep, cond):
        self.last_cond = cond.detach().clone()
        return torch.zeros_like(sample)


@contextlib.contextmanager
def _stub_optional_modules():
    """Temporarily install fake diffusers/torchvision/einops modules in sys.modules."""
    previous_modules = {}

    def inject(name, module):
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)
        sys.modules[name] = module

    diffusers_module = types.ModuleType('diffusers')
    schedulers_module = types.ModuleType('diffusers.schedulers')
    ddpm_module = types.ModuleType('diffusers.schedulers.scheduling_ddpm')
    ddim_module = types.ModuleType('diffusers.schedulers.scheduling_ddim')
    ddpm_module.DDPMScheduler = _FakeScheduler
    ddim_module.DDIMScheduler = _FakeScheduler
    diffusers_module.DDPMScheduler = _FakeScheduler
    diffusers_module.DDIMScheduler = _FakeScheduler
    diffusers_module.schedulers = schedulers_module
    schedulers_module.scheduling_ddpm = ddpm_module
    schedulers_module.scheduling_ddim = ddim_module

    torchvision_module = types.ModuleType('torchvision')
    models_module = types.ModuleType('torchvision.models')
    transforms_module = types.ModuleType('torchvision.transforms')
    models_module.resnet18 = lambda weights=None: _FakeResNet()
    transforms_module.CenterCrop = _IdentityCrop
    transforms_module.RandomCrop = _IdentityCrop
    torchvision_module.models = models_module
    torchvision_module.transforms = transforms_module

    einops_module = types.ModuleType('einops')
    einops_module.rearrange = lambda x, *args, **kwargs: x
    einops_layers_module = types.ModuleType('einops.layers')
    einops_layers_torch_module = types.ModuleType('einops.layers.torch')
    einops_layers_torch_module.Rearrange = _FakeRearrange
    einops_module.layers = einops_layers_module
    einops_layers_module.torch = einops_layers_torch_module

    try:
        inject('diffusers', diffusers_module)
        inject('diffusers.schedulers', schedulers_module)
        inject('diffusers.schedulers.scheduling_ddpm', ddpm_module)
        inject('diffusers.schedulers.scheduling_ddim', ddim_module)
        inject('torchvision', torchvision_module)
        inject('torchvision.models', models_module)
        inject('torchvision.transforms', transforms_module)
        inject('einops', einops_module)
        inject('einops.layers', einops_layers_module)
        inject('einops.layers.torch', einops_layers_torch_module)
        yield
    finally:
        # Restore (or remove) whatever was in sys.modules before the stubs.
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
def _compose_cfg(overrides=None):
    """Compose the repository config with optional Hydra overrides."""
    if not OmegaConf.has_resolver('len'):
        OmegaConf.register_new_resolver('len', lambda x: len(x))
    GlobalHydra.instance().clear()
    with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR):
        return compose(config_name='config', overrides=list(overrides or []))


def _make_images(batch_size, obs_horizon, image_shape, per_camera_fill=None):
    """Build constant-valued image tensors keyed by camera name."""
    channels, height, width = image_shape
    per_camera_fill = per_camera_fill or {
        'front': 30.0,
        'top': 20.0,
        'r_vis': 10.0,
    }
    return {
        name: torch.full(
            (batch_size, obs_horizon, channels, height, width),
            fill_value=fill_value,
            dtype=torch.float32,
        )
        for name, fill_value in per_camera_fill.items()
    }


def _patch_backbone_for_order_tracking(backbone):
    """Make each camera's feature vector equal its mean pixel value so camera order is observable."""
    feature_dim = backbone.output_dim

    def encode_mean(image_batch):
        mean_feature = image_batch.mean(dim=(1, 2, 3)).unsqueeze(-1)
        return mean_feature.repeat(1, feature_dim)

    if backbone.use_separate_rgb_encoder_per_camera:
        for encoder in backbone.rgb_encoder:
            encoder.forward_single_image = encode_mean
    else:
        backbone.rgb_encoder.forward_single_image = encode_mean


def _extract_camera_markers(cond, feature_dim, num_cams):
    """Read one marker value per camera from the first timestep of the conditioning tensor."""
    camera_block = cond[0, 0, : feature_dim * num_cams].view(num_cams, feature_dim)
    return camera_block[:, 0]


class ResNetTransformerAgentWiringTest(unittest.TestCase):
    def test_hydra_wiring_uses_required_three_camera_transformer_conditioning_in_agent_order_and_ignores_extra_keys(self):
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.inference_steps=1',
                'agent.head.n_layer=1',
                'agent.head.n_cond_layers=0',
                'agent.head.n_emb=32',
                'agent.head.n_head=4',
            ]
        )
        self.assertEqual(list(cfg.data.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(list(cfg.eval.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(list(cfg.agent.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(list(cfg.agent.vision_backbone.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(cfg.agent.head_type, 'transformer')
        self.assertEqual(cfg.agent.num_cams, 3)
        self.assertTrue(cfg.agent.head.obs_as_cond)
        self.assertFalse(cfg.agent.head.causal_attn)

        with _stub_optional_modules():
            agent = instantiate(cfg.agent)

            expected_cond_dim = agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim
            self.assertEqual(cfg.agent.head.cond_dim, expected_cond_dim)
            self.assertEqual(agent.per_step_cond_dim, expected_cond_dim)
            self.assertEqual(agent.noise_pred_net.cond_obs_emb.in_features, expected_cond_dim)

            batch_size = 2
            image_shape = tuple(cfg.agent.vision_backbone.input_shape)
            images = _make_images(
                batch_size,
                cfg.agent.obs_horizon,
                image_shape,
                per_camera_fill={
                    'front': 30.0,
                    'top': 20.0,
                    'r_vis': 10.0,
                    'left_wrist': 99.0,
                },
            )
            proprioception = torch.randn(batch_size, cfg.agent.obs_horizon, cfg.agent.obs_dim)

            _patch_backbone_for_order_tracking(agent.vision_encoder)
            capturing_head = _CondCapturingHead()
            agent.noise_pred_net = capturing_head
            predicted_actions = agent.predict_action(images, proprioception)

            self.assertEqual(
                predicted_actions.shape,
                (batch_size, cfg.agent.pred_horizon, cfg.agent.action_dim),
            )
            self.assertIsNotNone(capturing_head.last_cond)
            self.assertEqual(capturing_head.last_cond.shape[-1], expected_cond_dim)

            camera_markers = _extract_camera_markers(
                capturing_head.last_cond,
                agent.vision_encoder.output_dim,
                agent.num_cams,
            )
            self.assertTrue(torch.allclose(camera_markers, torch.tensor([10.0, 20.0, 30.0])))

            missing_images = dict(images)
            missing_images.pop('top')
            with self.assertRaisesRegex(ValueError, 'missing=.*top'):
                agent.predict_action(missing_images, proprioception)

    def test_agent_rejects_conflicting_explicit_backbone_camera_names(self):
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.vision_backbone.camera_names = ['front', 'top', 'r_vis']
        with _stub_optional_modules():
            with self.assertRaisesRegex(InstantiationException, 'camera_names'):
                instantiate(cfg.agent)

    def test_backbone_uses_sorted_fallback_order_when_camera_names_unset(self):
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.vision_backbone.camera_names = None
        with _stub_optional_modules():
            backbone = instantiate(cfg.agent.vision_backbone)
            _patch_backbone_for_order_tracking(backbone)
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'top': 20.0,
                    'front': 30.0,
                    'r_vis': 10.0,
                },
            )
            ordered_features = backbone(images)
            camera_markers = _extract_camera_markers(
                ordered_features,
                backbone.output_dim,
                len(images),
            )
            self.assertTrue(torch.allclose(camera_markers, torch.tensor([30.0, 10.0, 20.0])))

    def test_agent_queue_fallback_order_is_deterministic_when_camera_names_unset(self):
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None
        with _stub_optional_modules():
            agent = instantiate(cfg.agent)
            observation = {
                'qpos': torch.randn(cfg.agent.obs_dim),
                'images': {
                    'top': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 20.0),
                    'front': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 30.0),
                    'r_vis': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 10.0),
                },
            }
            agent._populate_queues(observation)
            batch = agent._prepare_observation_batch()
            self.assertEqual(list(batch['images'].keys()), ['front', 'r_vis', 'top'])

    def test_backbone_rejects_camera_count_mismatch_when_camera_names_unset(self):
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.vision_backbone.camera_names = None
        with _stub_optional_modules():
            backbone = instantiate(cfg.agent.vision_backbone)
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'front': 30.0,
                    'r_vis': 10.0,
                },
            )
            with self.assertRaisesRegex(ValueError, 'num_cameras'):
                backbone(images)

    def test_agent_rejects_camera_count_mismatch_when_camera_names_unset(self):
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.inference_steps=1',
                'agent.head.n_layer=1',
                'agent.head.n_cond_layers=0',
                'agent.head.n_emb=32',
                'agent.head.n_head=4',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None
        with _stub_optional_modules():
            agent = instantiate(cfg.agent)
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'front': 30.0,
                    'r_vis': 10.0,
                },
            )
            proprioception = torch.randn(1, cfg.agent.obs_horizon, cfg.agent.obs_dim)
            with self.assertRaisesRegex(ValueError, 'num_cams'):
                agent.predict_action(images, proprioception)

    def test_agent_rejects_num_cams_mismatch_with_backbone_when_camera_names_unset(self):
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None
        cfg.agent.num_cams = 2
        cfg.agent.vision_backbone.num_cameras = 3
        with _stub_optional_modules():
            with self.assertRaisesRegex(InstantiationException, 'num_cams'):
                instantiate(cfg.agent)


if __name__ == '__main__':
    unittest.main()