"""Tests for ``roboimi.vla.agent_imf.IMFVLAAgent``.

These tests stub out the heavy optional dependencies (``diffusers``,
``torchvision``, ``einops``, ``transformers``) with lightweight fakes so the
agent module can be imported and exercised in isolation, and verify:

* the IMF training loss and its padded-action masking,
* one-step IMF sampling at inference time,
* action-queue behaviour of ``select_action``,
* conditioning dimensions for joint / multi-token vision backbones,
* that the Hydra configs instantiate the expected agent wiring.

NOTE(fix): the original file defined ``_StubMultiTokenVisionBackbone`` twice;
the first definition (with a per-step time marker) was dead code shadowed by
the second, which is the one the tests' expected values rely on. The dead
duplicate has been removed.
"""

import contextlib
import importlib
import importlib.machinery
import sys
import types
import unittest
from pathlib import Path
from unittest import mock

import torch
from hydra import compose, initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch import nn

_REPO_ROOT = Path(__file__).resolve().parents[1]
_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve())
# Sentinel distinguishing "module absent from sys.modules" from "module is None".
_MISSING = object()
_CAMERA_NAMES = ('r_vis', 'top', 'front')


class _FakeScheduler:
    """Minimal stand-in for diffusers' DDPM/DDIM schedulers."""

    def __init__(self, num_train_timesteps=100, **kwargs):
        self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps)
        self.timesteps = []

    def add_noise(self, sample, noise, timestep):
        return sample + noise

    def set_timesteps(self, num_inference_steps):
        self.timesteps = list(range(num_inference_steps - 1, -1, -1))

    def step(self, noise_pred, timestep, sample):
        return types.SimpleNamespace(prev_sample=sample)


class _IdentityCrop:
    """No-op replacement for torchvision Center/RandomCrop transforms."""

    def __init__(self, size):
        self.size = size

    def __call__(self, x):
        return x


class _FakeResNet(nn.Module):
    """Tiny convolutional net used in place of ``torchvision.models.resnet18``."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2)
        self.relu2 = nn.ReLU()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 16)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)


class _FakeRearrange(nn.Module):
    """Identity module standing in for ``einops.layers.torch.Rearrange``."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x):
        return x


class _FakeViTConfig:
    """Attribute bag mimicking ``transformers.ViTConfig``."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)


class _FakeViTModel(nn.Module):
    """Fake ``transformers.ViTModel`` returning all-zero hidden states."""

    def __init__(self, config, add_pooling_layer=False):
        super().__init__()
        del add_pooling_layer
        self.config = config
        hidden_size = int(getattr(config, 'hidden_size', 192))
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, pixel_values=None, interpolate_pos_encoding=False, **kwargs):
        del interpolate_pos_encoding, kwargs
        batch_size = pixel_values.shape[0]
        hidden_size = int(getattr(self.config, 'hidden_size', 192))
        seq_len = 2
        last_hidden_state = torch.zeros(
            batch_size, seq_len, hidden_size,
            dtype=pixel_values.dtype, device=pixel_values.device,
        )
        return types.SimpleNamespace(last_hidden_state=last_hidden_state)


class _FakeSiglipVisionOutput:
    def __init__(self, pooler_output):
        self.pooler_output = pooler_output


class _FakeSiglipVisionConfig:
    def __init__(self, hidden_size=768, image_size=256):
        self.hidden_size = hidden_size
        self.image_size = image_size


class _FakeSiglipVisionModel(nn.Module):
    """Fake ``transformers.SiglipVisionModel`` recording loads and forward calls."""

    # Shared across instances so tests can inspect every from_pretrained call.
    load_calls = []

    def __init__(self, hidden_size=768):
        super().__init__()
        self.config = _FakeSiglipVisionConfig(hidden_size=hidden_size)
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.forward_calls = []

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        model = cls()
        cls.load_calls.append({
            'pretrained_model_name_or_path': pretrained_model_name_or_path,
            'args': args,
            'kwargs': kwargs,
        })
        return model

    def forward(self, pixel_values=None, **kwargs):
        self.forward_calls.append({
            'pixel_values': pixel_values.detach().clone(),
            'kwargs': dict(kwargs),
        })
        # Pool over spatial dims; scale makes the output depend on a parameter.
        pooled = pixel_values.mean(dim=(2, 3), keepdim=False) * self.scale
        return _FakeSiglipVisionOutput(pooler_output=pooled)


class _StubIMFHead(nn.Module):
    """Stub diffusion head that records its constructor kwargs for assertions."""

    def __init__(
        self,
        input_dim,
        output_dim,
        horizon,
        n_obs_steps,
        cond_dim,
        **kwargs,
    ):
        super().__init__()
        self.constructor_kwargs = {
            'input_dim': input_dim,
            'output_dim': output_dim,
            'horizon': horizon,
            'n_obs_steps': n_obs_steps,
            'cond_dim': cond_dim,
            **kwargs,
        }
        self.proj = nn.Linear(input_dim, output_dim)
        self.cond_obs_emb = nn.Linear(cond_dim, max(cond_dim, 1))

    def forward(self, sample, r, t, cond=None):
        return torch.zeros_like(sample)

    def get_optim_groups(self, weight_decay):
        return [
            {'params': [self.proj.weight], 'weight_decay': weight_decay},
            {'params': [self.proj.bias, self.cond_obs_emb.weight, self.cond_obs_emb.bias],
             'weight_decay': 0.0},
        ]


@contextlib.contextmanager
def _stub_optional_modules(include_imf_head=False):
    """Temporarily replace heavy optional deps in ``sys.modules`` with fakes.

    Installs fake ``diffusers``/``torchvision``/``einops``/``transformers``
    modules (and optionally a stub IMF head module) and restores the previous
    ``sys.modules`` state — including absence — on exit.
    """
    previous_modules = {}

    def remember_and_remove(name):
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)
        sys.modules.pop(name, None)

    def inject(name, module):
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)
        sys.modules[name] = module

    diffusers_module = types.ModuleType('diffusers')
    schedulers_module = types.ModuleType('diffusers.schedulers')
    ddpm_module = types.ModuleType('diffusers.schedulers.scheduling_ddpm')
    ddim_module = types.ModuleType('diffusers.schedulers.scheduling_ddim')
    ddpm_module.DDPMScheduler = _FakeScheduler
    ddim_module.DDIMScheduler = _FakeScheduler
    diffusers_module.DDPMScheduler = _FakeScheduler
    diffusers_module.DDIMScheduler = _FakeScheduler
    diffusers_module.schedulers = schedulers_module
    schedulers_module.scheduling_ddpm = ddpm_module
    schedulers_module.scheduling_ddim = ddim_module

    torchvision_module = types.ModuleType('torchvision')
    models_module = types.ModuleType('torchvision.models')
    transforms_module = types.ModuleType('torchvision.transforms')
    # __spec__ is required so importlib.util.find_spec-style probes succeed.
    torchvision_module.__spec__ = importlib.machinery.ModuleSpec('torchvision', loader=None)
    models_module.__spec__ = importlib.machinery.ModuleSpec('torchvision.models', loader=None)
    transforms_module.__spec__ = importlib.machinery.ModuleSpec('torchvision.transforms', loader=None)
    models_module.resnet18 = lambda weights=None: _FakeResNet()
    transforms_module.CenterCrop = _IdentityCrop
    transforms_module.RandomCrop = _IdentityCrop
    torchvision_module.models = models_module
    torchvision_module.transforms = transforms_module

    einops_module = types.ModuleType('einops')
    einops_module.rearrange = lambda x, *args, **kwargs: x
    einops_layers_module = types.ModuleType('einops.layers')
    einops_layers_torch_module = types.ModuleType('einops.layers.torch')
    einops_layers_torch_module.Rearrange = _FakeRearrange
    einops_module.layers = einops_layers_module
    einops_layers_module.torch = einops_layers_torch_module

    transformers_module = types.ModuleType('transformers')
    transformers_module.__spec__ = importlib.machinery.ModuleSpec('transformers', loader=None)
    transformers_module.ViTConfig = _FakeViTConfig
    transformers_module.ViTModel = _FakeViTModel
    transformers_module.SiglipVisionModel = _FakeSiglipVisionModel

    try:
        # Force re-import of the backbone so it binds to the fakes above.
        remember_and_remove('roboimi.vla.models.backbones.siglip2_diffusion_backbone')
        inject('diffusers', diffusers_module)
        inject('diffusers.schedulers', schedulers_module)
        inject('diffusers.schedulers.scheduling_ddpm', ddpm_module)
        inject('diffusers.schedulers.scheduling_ddim', ddim_module)
        inject('torchvision', torchvision_module)
        inject('torchvision.models', models_module)
        inject('torchvision.transforms', transforms_module)
        inject('einops', einops_module)
        inject('einops.layers', einops_layers_module)
        inject('einops.layers.torch', einops_layers_torch_module)
        inject('transformers', transformers_module)
        if include_imf_head:
            import roboimi.vla.models.heads as heads_package
            imf_head_module = types.ModuleType('roboimi.vla.models.heads.imf_transformer1d')
            imf_head_module.IMFTransformer1D = _StubIMFHead
            inject('roboimi.vla.models.heads.imf_transformer1d', imf_head_module)
            setattr(heads_package, 'imf_transformer1d', imf_head_module)
        yield
    finally:
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous


def _compose_cfg(overrides=None):
    """Compose the project's Hydra config with optional override strings."""
    if not OmegaConf.has_resolver('len'):
        OmegaConf.register_new_resolver('len', lambda x: len(x))
    GlobalHydra.instance().clear()
    with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR):
        return compose(config_name='config', overrides=list(overrides or []))


def _load_imf_agent_class():
    """Import ``roboimi.vla.agent_imf`` fresh under stubbed optional deps."""
    with _stub_optional_modules():
        sys.modules.pop('roboimi.vla.agent_imf', None)
        module = importlib.import_module('roboimi.vla.agent_imf')
        return module.IMFVLAAgent, module


class _StubVisionBackbone(nn.Module):
    """Backbone emitting one scalar feature per camera (the image mean)."""

    output_dim = 1

    def __init__(self, camera_names=_CAMERA_NAMES):
        super().__init__()
        self.camera_names = tuple(camera_names)
        self.num_cameras = len(self.camera_names)

    def forward(self, images):
        per_camera_features = []
        for camera_name in self.camera_names:
            image_batch = images[camera_name]
            per_camera_features.append(
                image_batch.mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1)
            )
        return torch.cat(per_camera_features, dim=-1)


class _StubJointVisionBackbone(nn.Module):
    """Backbone exposing ``joint_output_dim`` with cross-camera fused features."""

    joint_output_dim = 5
    output_dim = 5

    def __init__(self, camera_names=_CAMERA_NAMES):
        super().__init__()
        self.camera_names = tuple(camera_names)
        self.num_cameras = len(self.camera_names)

    def forward(self, images):
        batch_size, obs_horizon = next(iter(images.values())).shape[:2]
        features = []
        # Fixed fusion order, intentionally different from camera_names order.
        for camera_name in ('front', 'top', 'r_vis'):
            image_batch = images[camera_name]
            features.append(image_batch.mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1))
        joint_features = torch.cat(features, dim=-1)
        front_top_sum = joint_features[..., :2].sum(dim=-1, keepdim=True)
        r_vis_minus_front = (joint_features[..., 2:] - joint_features[..., :1])
        time_marker = torch.arange(obs_horizon, dtype=joint_features.dtype).view(1, obs_horizon, 1)
        time_marker = time_marker.expand(batch_size, -1, -1)
        return torch.cat([joint_features, front_top_sum, r_vis_minus_front + time_marker], dim=-1)


class _StubMultiTokenVisionBackbone(nn.Module):
    """Backbone emitting one 2-dim token per camera: ``[mean, mean + 0.5]``.

    (A second, dead definition with a time-marker term existed in the original
    file and was removed; this is the definition the tests depend on.)
    """

    output_dim = 2
    tokens_per_step = 3

    def __init__(self, camera_names=_CAMERA_NAMES):
        super().__init__()
        self.camera_names = tuple(camera_names)
        self.num_cameras = len(self.camera_names)

    def forward(self, images):
        per_camera = []
        for camera_name in self.camera_names:
            image_batch = images[camera_name]
            base = image_batch.mean(dim=(2, 3, 4), keepdim=False)
            per_camera.append(torch.stack([base, base + 0.5], dim=-1))
        return torch.stack(per_camera, dim=2)


class _RecordingLinearIMFHead(nn.Module):
    """Linear IMF head: ``scale * sample + r + 2 t + mean(cond)``; records calls."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(0.5))
        self.calls = []

    @staticmethod
    def _broadcast_batch_time(value, reference):
        while value.ndim < reference.ndim:
            value = value.unsqueeze(-1)
        return value

    def forward(self, sample, r, t, cond=None):
        record = {
            'sample': sample.detach().clone(),
            'r': r.detach().clone(),
            't': t.detach().clone(),
            'cond': None if cond is None else cond.detach().clone(),
        }
        self.calls.append(record)
        cond_term = 0.0
        if cond is not None:
            cond_term = cond.mean(dim=(1, 2), keepdim=True)
        r_b = self._broadcast_batch_time(r, sample)
        t_b = self._broadcast_batch_time(t, sample)
        return self.scale * sample + r_b + 2.0 * t_b + cond_term


class _ForbiddenScheduler:
    """Scheduler whose use fails the test — IMF inference must not touch it."""

    def set_timesteps(self, *args, **kwargs):  # pragma: no cover - only runs on regression
        raise AssertionError('IMF inference should not use DDIM scheduler set_timesteps')

    def step(self, *args, **kwargs):  # pragma: no cover - only runs on regression
        raise AssertionError('IMF inference should not use DDIM scheduler step')


def _make_images(batch_size, obs_horizon, per_camera_fill):
    """Build a (B, T, C, H, W) constant image tensor per camera name."""
    return {
        name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
        for name, value in per_camera_fill.items()
    }


class IMFVLAAgentTest(unittest.TestCase):

    def _make_agent(self, pred_horizon=3, obs_horizon=2, num_action_steps=2):
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=pred_horizon,
            obs_horizon=obs_horizon,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=num_action_steps,
            head_type='transformer',
        )
        return agent, head, agent_module

    def test_compute_loss_matches_imf_objective_and_masks_padded_actions(self):
        agent, head, agent_module = self._make_agent(pred_horizon=3, obs_horizon=2)
        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
        actions = torch.tensor(
            [[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
            dtype=torch.float32,
        )
        action_is_pad = torch.tensor([[False, False, True]])
        noise = torch.tensor(
            [[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
            dtype=torch.float32,
        )
        t_sample = torch.tensor([0.8], dtype=torch.float32)
        r_sample = torch.tensor([0.25], dtype=torch.float32)
        with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
                mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
            loss = agent.compute_loss(
                {
                    'images': images,
                    'qpos': qpos,
                    'action': actions,
                    'action_is_pad': action_is_pad,
                }
            )
        cond = torch.tensor([[[1.0, 2.0, 3.0, 0.25], [1.0, 2.0, 3.0, 0.75]]], dtype=torch.float32)
        cond_term = cond.mean(dim=(1, 2), keepdim=True)
        t = t_sample
        r = r_sample
        # Recompute the expected IMF compound velocity from the linear head.
        z_t = (1 - t.view(1, 1, 1)) * actions + t.view(1, 1, 1) * noise
        scale = head.scale.detach()
        u = scale * z_t + r.view(1, 1, 1) + 2.0 * t.view(1, 1, 1) + cond_term
        v = scale * z_t + 3.0 * t.view(1, 1, 1) + cond_term
        du_dt = scale * v + 2.0
        compound_velocity = u + (t - r).view(1, 1, 1) * du_dt
        target = noise - actions
        elementwise_loss = (compound_velocity - target) ** 2
        mask = (~action_is_pad).unsqueeze(-1).to(elementwise_loss.dtype)
        expected_loss = (elementwise_loss * mask).sum() / (mask.sum() * elementwise_loss.shape[-1])
        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
        self.assertEqual(len(head.calls), 2)
        self.assertTrue(torch.allclose(head.calls[0]['r'], t_sample))
        self.assertTrue(torch.allclose(head.calls[0]['t'], t_sample))
        self.assertTrue(torch.allclose(head.calls[0]['cond'], cond))

    def test_predict_action_uses_one_step_imf_sampling_and_image_conditioning(self):
        agent, head, agent_module = self._make_agent(pred_horizon=3, obs_horizon=2)
        agent.infer_scheduler = _ForbiddenScheduler()
        images = _make_images(
            batch_size=2,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor(
            [
                [[1.0], [2.0]],
                [[3.0], [4.0]],
            ],
            dtype=torch.float32,
        )
        initial_noise = torch.tensor(
            [
                [[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]],
                [[-1.0, 1.0], [2.0, -3.0], [0.5, 0.25]],
            ],
            dtype=torch.float32,
        )
        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
            predicted_actions = agent.predict_action(images, qpos)
        expected_cond = torch.tensor(
            [
                [[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]],
                [[10.0, 20.0, 30.0, 3.0], [10.0, 20.0, 30.0, 4.0]],
            ],
            dtype=torch.float32,
        )
        cond_term = expected_cond.mean(dim=(1, 2), keepdim=True)
        # One-step IMF sampling: x0 = noise - head(noise, r=0, t=1, cond).
        expected_actions = 0.5 * initial_noise - 2.0 - cond_term
        self.assertEqual(predicted_actions.shape, (2, 3, 2))
        self.assertTrue(torch.allclose(predicted_actions, expected_actions))
        self.assertEqual(len(head.calls), 1)
        self.assertTrue(torch.allclose(head.calls[0]['r'], torch.zeros(2)))
        self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))

    def test_select_action_only_regenerates_when_action_queue_is_empty(self):
        agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
        observation = {
            'qpos': torch.tensor([0.25], dtype=torch.float32),
            'images': {
                'front': torch.full((1, 2, 2), 3.0, dtype=torch.float32),
                'top': torch.full((1, 2, 2), 2.0, dtype=torch.float32),
                'r_vis': torch.full((1, 2, 2), 1.0, dtype=torch.float32),
            },
        }
        first_chunk = torch.tensor(
            [[[10.0, 11.0], [12.0, 13.0], [14.0, 15.0], [16.0, 17.0]]],
            dtype=torch.float32,
        )
        second_chunk = torch.tensor(
            [[[20.0, 21.0], [22.0, 23.0], [24.0, 25.0], [26.0, 27.0]]],
            dtype=torch.float32,
        )
        with mock.patch.object(agent, 'predict_action_chunk',
                               side_effect=[first_chunk, second_chunk]) as mock_predict_chunk:
            first_action = agent.select_action(observation)
            second_action = agent.select_action(observation)
            third_action = agent.select_action(observation)
        self.assertTrue(torch.equal(first_action, first_chunk[0, 1]))
        self.assertTrue(torch.equal(second_action, first_chunk[0, 2]))
        self.assertTrue(torch.equal(third_action, second_chunk[0, 1]))
        self.assertEqual(mock_predict_chunk.call_count, 2)

    def test_joint_visual_backbone_uses_joint_output_dim_for_conditioning(self):
        agent_cls, _agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        vision_backbone = _StubJointVisionBackbone()
        agent = agent_cls(
            vision_backbone=vision_backbone,
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
        )
        self.assertEqual(agent.per_step_cond_dim, vision_backbone.joint_output_dim + agent.obs_dim)
        self.assertEqual(
            agent.global_cond_dim,
            vision_backbone.joint_output_dim * agent.obs_horizon + agent.obs_dim * agent.obs_horizon,
        )
        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        initial_noise = torch.tensor(
            [[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
            dtype=torch.float32,
        )
        with mock.patch.object(torch, 'randn', return_value=initial_noise):
            predicted_actions = agent.predict_action(images, qpos)
        self.assertEqual(predicted_actions.shape, (1, 3, 2))
        self.assertEqual(len(head.calls), 1)
        expected_cond = torch.tensor(
            [[[30.0, 20.0, 10.0, 50.0, -20.0, 1.0], [30.0, 20.0, 10.0, 50.0, -19.0, 2.0]]],
            dtype=torch.float32,
        )
        self.assertEqual(head.calls[0]['cond'].shape[-1], 6)
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))

    def test_multitoken_visual_backbone_flattens_camera_tokens_and_projects_each_with_state(self):
        agent_cls, _agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        projector = nn.Linear(3, 4, bias=False)
        with torch.no_grad():
            projector.weight.copy_(
                torch.tensor(
                    [
                        [1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0],
                        [0.0, 0.0, 1.0],
                        [1.0, 0.0, 1.0],
                    ],
                    dtype=torch.float32,
                )
            )
        agent = agent_cls(
            vision_backbone=_StubMultiTokenVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            cond_projector=projector,
        )
        self.assertEqual(agent.condition_tokens_per_step, 3)
        self.assertEqual(agent.condition_sequence_length, 6)
        self.assertEqual(agent.per_step_cond_dim, 4)
        self.assertEqual(agent.global_cond_dim, 24)
        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        cond = agent._build_cond(images, qpos)
        expected = torch.tensor(
            [
                [
                    [10.0, 10.5, 1.0, 11.0],
                    [20.0, 20.5, 1.0, 21.0],
                    [30.0, 30.5, 1.0, 31.0],
                    [10.0, 10.5, 2.0, 12.0],
                    [20.0, 20.5, 2.0, 22.0],
                    [30.0, 30.5, 2.0, 32.0],
                ]
            ],
            dtype=torch.float32,
        )
        self.assertEqual(cond.shape, (1, 6, 4))
        self.assertTrue(torch.allclose(cond, expected))

    def test_multi_token_visual_backbone_pairs_state_per_camera_and_flattens_condition_sequence(self):
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        cond_projector = nn.Linear(3, 4, bias=False)
        with torch.no_grad():
            cond_projector.weight.copy_(torch.tensor([
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
                [1.0, 0.0, 1.0],
            ], dtype=torch.float32))
        agent = agent_cls(
            vision_backbone=_StubMultiTokenVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            cond_projector=cond_projector,
        )
        agent.infer_scheduler = _ForbiddenScheduler()
        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        initial_noise = torch.tensor([[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]], dtype=torch.float32)
        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
            predicted_actions = agent.predict_action(images, qpos)
        expected_cond = torch.tensor([[[10.0, 10.5, 1.0, 11.0],
                                       [20.0, 20.5, 1.0, 21.0],
                                       [30.0, 30.5, 1.0, 31.0],
                                       [10.0, 10.5, 2.0, 12.0],
                                       [20.0, 20.5, 2.0, 22.0],
                                       [30.0, 30.5, 2.0, 32.0]]], dtype=torch.float32)
        self.assertEqual(agent.condition_tokens_per_step, 3)
        self.assertEqual(agent.condition_sequence_length, 6)
        self.assertEqual(agent.raw_per_step_cond_dim, 3)
        self.assertEqual(agent.per_step_cond_dim, 4)
        self.assertEqual(agent.global_cond_dim, 24)
        self.assertEqual(predicted_actions.shape, (1, 3, 2))
        self.assertEqual(len(head.calls), 1)
        self.assertEqual(head.calls[0]['cond'].shape, (1, 6, 4))
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))

    def test_hydra_config_instantiates_resnet_imf_attnres_with_stub_head(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.vision_backbone.freeze_backbone=false',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
            ]
        )
        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(cfg.agent.head._target_,
                         'roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D')
        self.assertEqual(cfg.agent.head.backbone_type, 'attnres_full')
        self.assertEqual(cfg.agent.head.n_head, 1)
        self.assertEqual(cfg.agent.head.n_kv_head, 1)
        self.assertEqual(cfg.agent.head.n_cond_layers, 0)
        self.assertTrue(cfg.agent.head.time_as_cond)
        self.assertFalse(cfg.agent.head.causal_attn)
        self.assertEqual(cfg.agent.inference_steps, 1)
        self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)
        self.assertEqual(agent.head_type, 'transformer')
        self.assertEqual(agent.per_step_cond_dim,
                         agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim)
        self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['backbone_type'], 'attnres_full')

    def test_hydra_config_instantiates_resnet_imf_attnres_with_full_attnres_vision_backbone(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres',
                'agent.vision_backbone.vision_backbone_mode=attnres_resnet',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,56,56]',
                'agent.vision_backbone.freeze_backbone=false',
                'agent.vision_backbone.attnres_stem_dim=16',
                'agent.vision_backbone.attnres_stage_dims=[16,32,64,128]',
                'agent.vision_backbone.attnres_stage_depths=[1,1,1,1]',
                'agent.vision_backbone.attnres_stage_heads=[2,4,4,8]',
                'agent.vision_backbone.attnres_stage_kv_heads=[1,1,1,1]',
                'agent.vision_backbone.attnres_stage_window_sizes=[7,7,7,7]',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
            ]
        )
        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)
        self.assertEqual(agent.vision_encoder.output_dim, 64)
        self.assertEqual(agent.per_step_cond_dim, 64 * agent.num_cams + agent.obs_dim)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)

    def test_hydra_config_instantiates_lewm_imf_attnres_with_joint_visual_condition_dim(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=lewm_imf_attnres',
                'agent.vision_backbone.checkpoint_path=null',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
            ]
        )
        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(cfg.agent.vision_backbone._target_,
                         'roboimi.vla.models.backbones.lewm_vit_backbone.LEWMViTBackbone')
        self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
        self.assertEqual(list(cfg.agent.vision_backbone.camera_names), list(_CAMERA_NAMES))
        self.assertEqual(list(cfg.agent.vision_backbone.fused_camera_names), ['front', 'top', 'r_vis'])
        self.assertIsNone(cfg.agent.vision_backbone.dataset_image_resize_shape)
        self.assertEqual(list(cfg.agent.vision_backbone.eval_image_resize_shape), [256, 256])
        self.assertEqual(cfg.agent.head.cond_dim, 208)
        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)
        self.assertEqual(agent.per_step_cond_dim, agent.vision_encoder.joint_output_dim + agent.obs_dim)
        self.assertEqual(agent.per_step_cond_dim, 208)
        self.assertEqual(agent.global_cond_dim, agent.obs_horizon * 208)
        self.assertIsNone(agent.vision_encoder.dataset_image_resize_shape)
        self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
        self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 208)

    def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_projected_camera_tokens(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres_multitoken',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.head.n_layer=1',
                'agent.head.n_emb=32',
            ]
        )
        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(cfg.agent.vision_backbone.vision_backbone_mode, 'resnet')
        self.assertTrue(cfg.agent.vision_backbone.use_separate_rgb_encoder_per_camera)
        self.assertTrue(cfg.agent.vision_backbone.output_tokens_per_camera)
        self.assertEqual(cfg.agent.cond_projector.output_dim, 32)
        self.assertEqual(cfg.agent.head.cond_dim, 32)
        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)
        self.assertEqual(agent.condition_tokens_per_step, 3)
        self.assertEqual(agent.condition_sequence_length, agent.obs_horizon * 3)
        self.assertEqual(agent.per_step_cond_dim, 32)
        self.assertEqual(agent.global_cond_dim, agent.condition_sequence_length * 32)
        self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 32)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], 6)

    def test_hydra_config_instantiates_siglip2_imf_attnres_with_condition_projection(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=siglip2_imf_attnres',
                'agent.vision_backbone.per_view_output_dim=96',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
                'agent.cond_projector.output_dim=384',
            ]
        )
        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(
            cfg.agent.vision_backbone._target_,
            'roboimi.vla.models.backbones.siglip2_diffusion_backbone.SigLIP2DiffusionBackbone',
        )
        self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
        self.assertIsNone(cfg.agent.vision_backbone.dataset_image_resize_shape)
        self.assertEqual(list(cfg.agent.vision_backbone.eval_image_resize_shape), [256, 256])
        self.assertEqual(cfg.agent.head.cond_dim, 384)
        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)
        self.assertEqual(agent.raw_per_step_cond_dim, 3 * 96 + agent.obs_dim)
        self.assertEqual(agent.per_step_cond_dim, 384)
        self.assertEqual(agent.global_cond_dim, agent.obs_horizon * 384)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 384)
        self.assertEqual(agent.vision_encoder.output_dim, 96)
        self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))

    def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres_multitoken',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.vision_backbone.freeze_backbone=false',
                'agent.head.n_layer=1',
                'agent.head.n_emb=16',
            ]
        )
        self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
        self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
        self.assertTrue(cfg.agent.vision_backbone.use_separate_rgb_encoder_per_camera)
        self.assertTrue(cfg.agent.vision_backbone.output_tokens_per_camera)
        self.assertEqual(cfg.agent.vision_backbone.vision_backbone_mode, 'resnet')
        self.assertEqual(cfg.agent.cond_projector.output_dim, 16)
        self.assertEqual(cfg.agent.head.cond_dim, 16)
        with _stub_optional_modules(include_imf_head=True):
            agent = instantiate(cfg.agent)
        self.assertEqual(agent.condition_tokens_per_step, 3)
        self.assertEqual(agent.condition_sequence_length, agent.obs_horizon * 3)
        self.assertEqual(agent.per_step_cond_dim, 16)
        self.assertEqual(agent.global_cond_dim, agent.condition_sequence_length * 16)
        self.assertEqual(agent.vision_encoder.tokens_per_step, 3)
        self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 16)
        self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'],
                         agent.condition_sequence_length)


if __name__ == '__main__':
    unittest.main()