# roboimi/tests/test_resnet_transformer_agent_wiring.py  (Python, ~456 lines)
import contextlib
import sys
import types
import unittest
from pathlib import Path
import torch
from hydra import compose, initialize_config_dir
from hydra.errors import InstantiationException
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from omegaconf import OmegaConf
# Repository root, resolved relative to this test file's location.
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Absolute path to the Hydra config directory consumed by _compose_cfg().
_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve())
# Camera ordering the Hydra configs are expected to declare.
_EXPECTED_CAMERA_NAMES = ['r_vis', 'top', 'front']
# Sentinel marking "name was absent from sys.modules" during stubbing.
_MISSING = object()
class _FakeScheduler:
def __init__(self, num_train_timesteps=100, **kwargs):
self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps)
self.timesteps = []
def add_noise(self, sample, noise, timestep):
return sample + noise
def set_timesteps(self, num_inference_steps):
self.timesteps = list(range(num_inference_steps - 1, -1, -1))
def step(self, noise_pred, timestep, sample):
return types.SimpleNamespace(prev_sample=sample)
class _IdentityCrop:
def __init__(self, size):
self.size = size
def __call__(self, x):
return x
class _FakeResNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
self.relu1 = torch.nn.ReLU()
self.conv2 = torch.nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2)
self.relu2 = torch.nn.ReLU()
self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
self.fc = torch.nn.Linear(16, 16)
def forward(self, x):
x = self.relu1(self.conv1(x))
x = self.relu2(self.conv2(x))
x = self.avgpool(x)
x = torch.flatten(x, start_dim=1)
return self.fc(x)
class _FakeRearrange(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, x):
return x
class _CondCapturingHead(torch.nn.Module):
def __init__(self):
super().__init__()
self.last_cond = None
def forward(self, sample, timestep, cond):
self.last_cond = cond.detach().clone()
return torch.zeros_like(sample)
@contextlib.contextmanager
def _stub_optional_modules():
    """Temporarily replace heavy optional deps in ``sys.modules`` with stubs.

    Installs fake ``diffusers``, ``torchvision`` and ``einops`` module trees
    so the agent can be instantiated without those packages, then restores
    the previous ``sys.modules`` state (including absence) on exit.
    """
    saved = {}

    def install(name, module):
        # Remember only the FIRST pre-existing value per name so nested or
        # repeated injections still restore the true original.
        if name not in saved:
            saved[name] = sys.modules.get(name, _MISSING)
        sys.modules[name] = module

    def make_module(name, **attrs):
        mod = types.ModuleType(name)
        for key, value in attrs.items():
            setattr(mod, key, value)
        return mod

    ddpm_mod = make_module(
        'diffusers.schedulers.scheduling_ddpm', DDPMScheduler=_FakeScheduler
    )
    ddim_mod = make_module(
        'diffusers.schedulers.scheduling_ddim', DDIMScheduler=_FakeScheduler
    )
    schedulers_mod = make_module(
        'diffusers.schedulers', scheduling_ddpm=ddpm_mod, scheduling_ddim=ddim_mod
    )
    diffusers_mod = make_module(
        'diffusers',
        DDPMScheduler=_FakeScheduler,
        DDIMScheduler=_FakeScheduler,
        schedulers=schedulers_mod,
    )
    models_mod = make_module(
        'torchvision.models', resnet18=lambda weights=None: _FakeResNet()
    )
    transforms_mod = make_module(
        'torchvision.transforms', CenterCrop=_IdentityCrop, RandomCrop=_IdentityCrop
    )
    torchvision_mod = make_module(
        'torchvision', models=models_mod, transforms=transforms_mod
    )
    layers_torch_mod = make_module('einops.layers.torch', Rearrange=_FakeRearrange)
    layers_mod = make_module('einops.layers', torch=layers_torch_mod)
    einops_mod = make_module(
        'einops', rearrange=lambda x, *args, **kwargs: x, layers=layers_mod
    )
    # Injection order matters only in that restoration walks it in reverse.
    stubs = [
        ('diffusers', diffusers_mod),
        ('diffusers.schedulers', schedulers_mod),
        ('diffusers.schedulers.scheduling_ddpm', ddpm_mod),
        ('diffusers.schedulers.scheduling_ddim', ddim_mod),
        ('torchvision', torchvision_mod),
        ('torchvision.models', models_mod),
        ('torchvision.transforms', transforms_mod),
        ('einops', einops_mod),
        ('einops.layers', layers_mod),
        ('einops.layers.torch', layers_torch_mod),
    ]
    try:
        for name, module in stubs:
            install(name, module)
        yield
    finally:
        # Restore in reverse insertion order; drop names that were absent.
        for name, previous in reversed(list(saved.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
def _compose_cfg(overrides=None):
    """Compose the Hydra config from ``_CONFIG_DIR`` with the given overrides.

    Registers the ``len`` resolver (idempotently) and clears any previous
    global Hydra state so each test composes from a clean slate.
    """
    if not OmegaConf.has_resolver('len'):
        # Builtin len is the resolver body; registration must happen once.
        OmegaConf.register_new_resolver('len', len)
    GlobalHydra.instance().clear()
    with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR):
        overrides = [] if overrides is None else list(overrides)
        return compose(config_name='config', overrides=overrides)
def _make_images(batch_size, obs_horizon, image_shape, per_camera_fill=None):
channels, height, width = image_shape
per_camera_fill = per_camera_fill or {
'front': 30.0,
'top': 20.0,
'r_vis': 10.0,
}
return {
name: torch.full(
(batch_size, obs_horizon, channels, height, width),
fill_value=fill_value,
dtype=torch.float32,
)
for name, fill_value in per_camera_fill.items()
}
def _patch_backbone_for_order_tracking(backbone):
feature_dim = backbone.output_dim
def encode_mean(image_batch):
mean_feature = image_batch.mean(dim=(1, 2, 3)).unsqueeze(-1)
return mean_feature.repeat(1, feature_dim)
if backbone.use_separate_rgb_encoder_per_camera:
for encoder in backbone.rgb_encoder:
encoder.forward_single_image = encode_mean
else:
backbone.rgb_encoder.forward_single_image = encode_mean
def _extract_camera_markers(cond, feature_dim, num_cams):
camera_block = cond[0, 0, : feature_dim * num_cams].view(num_cams, feature_dim)
return camera_block[:, 0]
def _extract_token_camera_markers(tokens):
return tokens[0, 0, :, 0]
def _extract_token_markers(token_sequence):
return token_sequence[0, 0, :, 0]
class ResNetTransformerAgentWiringTest(unittest.TestCase):
    """Hydra wiring tests for the ResNet + transformer agent.

    Fix applied: the original file defined
    ``test_multitoken_resnet_backbone_emits_one_token_per_camera_in_agent_order``
    twice; the second class-body definition silently shadowed the first, so
    only one of the two bodies was ever collected by unittest. The second
    occurrence is renamed (head-override variant) so both tests run.
    """

    def test_hydra_wiring_uses_required_three_camera_transformer_conditioning_in_agent_order_and_ignores_extra_keys(self):
        """Agent consumes exactly the configured cameras, in config order."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.inference_steps=1',
                'agent.head.n_layer=1',
                'agent.head.n_cond_layers=0',
                'agent.head.n_emb=32',
                'agent.head.n_head=4',
            ]
        )
        self.assertEqual(list(cfg.data.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(list(cfg.eval.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(list(cfg.agent.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(
            list(cfg.agent.vision_backbone.camera_names), _EXPECTED_CAMERA_NAMES
        )
        self.assertEqual(cfg.agent.head_type, 'transformer')
        self.assertEqual(cfg.agent.num_cams, 3)
        self.assertTrue(cfg.agent.head.obs_as_cond)
        self.assertFalse(cfg.agent.head.causal_attn)
        with _stub_optional_modules():
            agent = instantiate(cfg.agent)
            expected_cond_dim = (
                agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim
            )
            self.assertEqual(cfg.agent.head.cond_dim, expected_cond_dim)
            self.assertEqual(agent.per_step_cond_dim, expected_cond_dim)
            # Check the real head's input width BEFORE swapping in the stub.
            self.assertEqual(
                agent.noise_pred_net.cond_obs_emb.in_features, expected_cond_dim
            )
            batch_size = 2
            image_shape = tuple(cfg.agent.vision_backbone.input_shape)
            # 'left_wrist' is an extra camera the agent must ignore.
            images = _make_images(
                batch_size,
                cfg.agent.obs_horizon,
                image_shape,
                per_camera_fill={
                    'front': 30.0,
                    'top': 20.0,
                    'r_vis': 10.0,
                    'left_wrist': 99.0,
                },
            )
            proprioception = torch.randn(
                batch_size, cfg.agent.obs_horizon, cfg.agent.obs_dim
            )
            _patch_backbone_for_order_tracking(agent.vision_encoder)
            capturing_head = _CondCapturingHead()
            agent.noise_pred_net = capturing_head
            predicted_actions = agent.predict_action(images, proprioception)
            self.assertEqual(
                predicted_actions.shape,
                (batch_size, cfg.agent.pred_horizon, cfg.agent.action_dim),
            )
            self.assertIsNotNone(capturing_head.last_cond)
            self.assertEqual(capturing_head.last_cond.shape[-1], expected_cond_dim)
            # Marker values prove the per-camera feature order is config order.
            camera_markers = _extract_camera_markers(
                capturing_head.last_cond,
                agent.vision_encoder.output_dim,
                agent.num_cams,
            )
            self.assertTrue(
                torch.allclose(camera_markers, torch.tensor([10.0, 20.0, 30.0]))
            )
            # Dropping a required camera must raise and name the missing key.
            missing_images = dict(images)
            missing_images.pop('top')
            with self.assertRaisesRegex(ValueError, 'missing=.*top'):
                agent.predict_action(missing_images, proprioception)

    def test_multitoken_resnet_backbone_emits_one_token_per_camera_in_agent_order(self):
        """Multitoken backbone yields one ordered token per configured camera."""
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres_multitoken',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        with _stub_optional_modules():
            backbone = instantiate(cfg.agent.vision_backbone)
            _patch_backbone_for_order_tracking(backbone)
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'front': 30.0,
                    'top': 20.0,
                    'r_vis': 10.0,
                    'left_wrist': 99.0,
                },
            )
            tokens = backbone(images)
            self.assertEqual(
                tokens.shape, (1, cfg.agent.obs_horizon, 3, backbone.output_dim)
            )
            self.assertEqual(backbone.tokens_per_step, 3)
            camera_markers = _extract_token_camera_markers(tokens)
            self.assertTrue(
                torch.allclose(camera_markers, torch.tensor([10.0, 20.0, 30.0]))
            )

    def test_agent_rejects_conflicting_explicit_backbone_camera_names(self):
        """Backbone camera order conflicting with the agent's must fail fast."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.vision_backbone.camera_names = ['front', 'top', 'r_vis']
        with _stub_optional_modules():
            with self.assertRaisesRegex(InstantiationException, 'camera_names'):
                instantiate(cfg.agent)

    def test_backbone_uses_sorted_fallback_order_when_camera_names_unset(self):
        """With camera_names=None the backbone falls back to sorted key order."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.vision_backbone.camera_names = None
        with _stub_optional_modules():
            backbone = instantiate(cfg.agent.vision_backbone)
            _patch_backbone_for_order_tracking(backbone)
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'top': 20.0,
                    'front': 30.0,
                    'r_vis': 10.0,
                },
            )
            ordered_features = backbone(images)
            # Sorted order is front, r_vis, top -> markers 30, 10, 20.
            camera_markers = _extract_camera_markers(
                ordered_features,
                backbone.output_dim,
                len(images),
            )
            self.assertTrue(
                torch.allclose(camera_markers, torch.tensor([30.0, 10.0, 20.0]))
            )

    def test_agent_queue_fallback_order_is_deterministic_when_camera_names_unset(self):
        """Observation queues keep a deterministic (sorted) camera key order."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None
        with _stub_optional_modules():
            agent = instantiate(cfg.agent)
            observation = {
                'qpos': torch.randn(cfg.agent.obs_dim),
                'images': {
                    'top': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 20.0),
                    'front': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 30.0),
                    'r_vis': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 10.0),
                },
            }
            agent._populate_queues(observation)
            batch = agent._prepare_observation_batch()
            self.assertEqual(list(batch['images'].keys()), ['front', 'r_vis', 'top'])

    def test_backbone_rejects_camera_count_mismatch_when_camera_names_unset(self):
        """Backbone must reject inputs with the wrong number of cameras."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.vision_backbone.camera_names = None
        with _stub_optional_modules():
            backbone = instantiate(cfg.agent.vision_backbone)
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'front': 30.0,
                    'r_vis': 10.0,
                },
            )
            with self.assertRaisesRegex(ValueError, 'num_cameras'):
                backbone(images)

    def test_agent_rejects_camera_count_mismatch_when_camera_names_unset(self):
        """Agent-level inference must also reject a wrong camera count."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.inference_steps=1',
                'agent.head.n_layer=1',
                'agent.head.n_cond_layers=0',
                'agent.head.n_emb=32',
                'agent.head.n_head=4',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None
        with _stub_optional_modules():
            agent = instantiate(cfg.agent)
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'front': 30.0,
                    'r_vis': 10.0,
                },
            )
            proprioception = torch.randn(1, cfg.agent.obs_horizon, cfg.agent.obs_dim)
            with self.assertRaisesRegex(ValueError, 'num_cams'):
                agent.predict_action(images, proprioception)

    def test_agent_rejects_num_cams_mismatch_with_backbone_when_camera_names_unset(self):
        """Conflicting num_cams between agent and backbone must fail fast."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None
        cfg.agent.num_cams = 2
        cfg.agent.vision_backbone.num_cameras = 3
        with _stub_optional_modules():
            with self.assertRaisesRegex(InstantiationException, 'num_cams'):
                instantiate(cfg.agent)

    def test_multitoken_backbone_token_order_with_compact_head_overrides(self):
        """Renamed from the duplicate multitoken test: same ordering check
        with compact transformer-head overrides applied."""
        cfg = _compose_cfg(
            overrides=[
                'agent=resnet_imf_attnres_multitoken',
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.head.n_layer=1',
                'agent.head.n_emb=32',
            ]
        )
        with _stub_optional_modules():
            backbone = instantiate(cfg.agent.vision_backbone)
            _patch_backbone_for_order_tracking(backbone)
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'front': 30.0,
                    'top': 20.0,
                    'r_vis': 10.0,
                },
            )
            output = backbone(images)
            self.assertEqual(
                output.shape, (1, cfg.agent.obs_horizon, 3, backbone.output_dim)
            )
            token_markers = _extract_token_markers(output)
            self.assertTrue(
                torch.allclose(token_markers, torch.tensor([10.0, 20.0, 30.0]))
            )
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()