388 lines
14 KiB
Python
388 lines
14 KiB
Python
import contextlib
|
|
import sys
|
|
import types
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from hydra import compose, initialize_config_dir
|
|
from hydra.errors import InstantiationException
|
|
from hydra.core.global_hydra import GlobalHydra
|
|
from hydra.utils import instantiate
|
|
from omegaconf import OmegaConf
|
|
|
|
|
|
# Repository root, resolved relative to this test file's location.
_REPO_ROOT = Path(__file__).resolve().parents[1]

# Absolute path to the project's Hydra config directory used by _compose_cfg().
_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve())

# Camera ordering the project config is expected to declare everywhere.
_EXPECTED_CAMERA_NAMES = ['r_vis', 'top', 'front']

# Sentinel marking "name was absent from sys.modules" in _stub_optional_modules.
_MISSING = object()
|
|
|
|
|
|
class _FakeScheduler:
|
|
def __init__(self, num_train_timesteps=100, **kwargs):
|
|
self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps)
|
|
self.timesteps = []
|
|
|
|
def add_noise(self, sample, noise, timestep):
|
|
return sample + noise
|
|
|
|
def set_timesteps(self, num_inference_steps):
|
|
self.timesteps = list(range(num_inference_steps - 1, -1, -1))
|
|
|
|
def step(self, noise_pred, timestep, sample):
|
|
return types.SimpleNamespace(prev_sample=sample)
|
|
|
|
|
|
class _IdentityCrop:
|
|
def __init__(self, size):
|
|
self.size = size
|
|
|
|
def __call__(self, x):
|
|
return x
|
|
|
|
|
|
class _FakeResNet(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
|
|
self.relu1 = torch.nn.ReLU()
|
|
self.conv2 = torch.nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2)
|
|
self.relu2 = torch.nn.ReLU()
|
|
self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
|
|
self.fc = torch.nn.Linear(16, 16)
|
|
|
|
def forward(self, x):
|
|
x = self.relu1(self.conv1(x))
|
|
x = self.relu2(self.conv2(x))
|
|
x = self.avgpool(x)
|
|
x = torch.flatten(x, start_dim=1)
|
|
return self.fc(x)
|
|
|
|
|
|
class _FakeRearrange(torch.nn.Module):
|
|
def __init__(self, *args, **kwargs):
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
|
|
class _CondCapturingHead(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.last_cond = None
|
|
|
|
def forward(self, sample, timestep, cond):
|
|
self.last_cond = cond.detach().clone()
|
|
return torch.zeros_like(sample)
|
|
|
|
|
|
@contextlib.contextmanager
def _stub_optional_modules():
    """Temporarily install fake diffusers/torchvision/einops modules.

    The fakes provide exactly the names the agent code imports, so tests can
    instantiate the Hydra config without the real (heavy) dependencies.
    Whatever was previously present in ``sys.modules`` is restored on exit,
    in reverse injection order.
    """
    saved = {}

    def install(name, module):
        # Record the pre-existing entry (or _MISSING) exactly once per name.
        if name not in saved:
            saved[name] = sys.modules.get(name, _MISSING)
        sys.modules[name] = module

    # --- diffusers: both scheduler classes map to _FakeScheduler ---
    diffusers_mod = types.ModuleType('diffusers')
    schedulers_mod = types.ModuleType('diffusers.schedulers')
    ddpm_mod = types.ModuleType('diffusers.schedulers.scheduling_ddpm')
    ddim_mod = types.ModuleType('diffusers.schedulers.scheduling_ddim')
    ddpm_mod.DDPMScheduler = _FakeScheduler
    ddim_mod.DDIMScheduler = _FakeScheduler
    diffusers_mod.DDPMScheduler = _FakeScheduler
    diffusers_mod.DDIMScheduler = _FakeScheduler
    diffusers_mod.schedulers = schedulers_mod
    schedulers_mod.scheduling_ddpm = ddpm_mod
    schedulers_mod.scheduling_ddim = ddim_mod

    # --- torchvision: tiny CNN backbone plus identity crop transforms ---
    torchvision_mod = types.ModuleType('torchvision')
    models_mod = types.ModuleType('torchvision.models')
    transforms_mod = types.ModuleType('torchvision.transforms')
    models_mod.resnet18 = lambda weights=None: _FakeResNet()
    transforms_mod.CenterCrop = _IdentityCrop
    transforms_mod.RandomCrop = _IdentityCrop
    torchvision_mod.models = models_mod
    torchvision_mod.transforms = transforms_mod

    # --- einops: rearrange (function and layer) become the identity ---
    einops_mod = types.ModuleType('einops')
    einops_mod.rearrange = lambda x, *args, **kwargs: x
    einops_layers_mod = types.ModuleType('einops.layers')
    einops_layers_torch_mod = types.ModuleType('einops.layers.torch')
    einops_layers_torch_mod.Rearrange = _FakeRearrange
    einops_mod.layers = einops_layers_mod
    einops_layers_mod.torch = einops_layers_torch_mod

    # Injection order: parent packages before their submodules.
    fakes = (
        ('diffusers', diffusers_mod),
        ('diffusers.schedulers', schedulers_mod),
        ('diffusers.schedulers.scheduling_ddpm', ddpm_mod),
        ('diffusers.schedulers.scheduling_ddim', ddim_mod),
        ('torchvision', torchvision_mod),
        ('torchvision.models', models_mod),
        ('torchvision.transforms', transforms_mod),
        ('einops', einops_mod),
        ('einops.layers', einops_layers_mod),
        ('einops.layers.torch', einops_layers_torch_mod),
    )
    try:
        for name, module in fakes:
            install(name, module)
        yield
    finally:
        # Undo in reverse insertion order so parent packages are restored last.
        for name, previous in reversed(list(saved.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
|
|
|
|
|
|
def _compose_cfg(overrides=None):
    """Compose the project Hydra config, applying optional override strings.

    Registers the custom ``${len:...}`` resolver on first use and clears any
    Hydra state a previous test may have left behind before composing.
    """
    if not OmegaConf.has_resolver('len'):
        OmegaConf.register_new_resolver('len', lambda seq: len(seq))

    # Hydra is a process-wide singleton; reset it so repeated composes work.
    GlobalHydra.instance().clear()
    with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR):
        override_list = [] if overrides is None else list(overrides)
        return compose(config_name='config', overrides=override_list)
|
|
|
|
|
|
def _make_images(batch_size, obs_horizon, image_shape, per_camera_fill=None):
|
|
channels, height, width = image_shape
|
|
per_camera_fill = per_camera_fill or {
|
|
'front': 30.0,
|
|
'top': 20.0,
|
|
'r_vis': 10.0,
|
|
}
|
|
return {
|
|
name: torch.full(
|
|
(batch_size, obs_horizon, channels, height, width),
|
|
fill_value=fill_value,
|
|
dtype=torch.float32,
|
|
)
|
|
for name, fill_value in per_camera_fill.items()
|
|
}
|
|
|
|
|
|
def _patch_backbone_for_order_tracking(backbone):
|
|
feature_dim = backbone.output_dim
|
|
|
|
def encode_mean(image_batch):
|
|
mean_feature = image_batch.mean(dim=(1, 2, 3)).unsqueeze(-1)
|
|
return mean_feature.repeat(1, feature_dim)
|
|
|
|
if backbone.use_separate_rgb_encoder_per_camera:
|
|
for encoder in backbone.rgb_encoder:
|
|
encoder.forward_single_image = encode_mean
|
|
else:
|
|
backbone.rgb_encoder.forward_single_image = encode_mean
|
|
|
|
|
|
def _extract_camera_markers(cond, feature_dim, num_cams):
|
|
camera_block = cond[0, 0, : feature_dim * num_cams].view(num_cams, feature_dim)
|
|
return camera_block[:, 0]
|
|
|
|
|
|
class ResNetTransformerAgentWiringTest(unittest.TestCase):
    """End-to-end wiring checks for the Hydra-configured ResNet+Transformer agent.

    All tests compose the real project config from ``_CONFIG_DIR`` and
    instantiate the agent with the heavy optional dependencies
    (diffusers/torchvision/einops) replaced by lightweight fakes.
    """

    def test_hydra_wiring_uses_required_three_camera_transformer_conditioning_in_agent_order_and_ignores_extra_keys(self):
        """Agent conditions on the three configured cameras, in config order,
        ignoring extra image keys and raising on missing ones."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.inference_steps=1',
                'agent.head.n_layer=1',
                'agent.head.n_cond_layers=0',
                'agent.head.n_emb=32',
                'agent.head.n_head=4',
            ]
        )

        # Camera names must agree across data/eval/agent/backbone configs.
        self.assertEqual(list(cfg.data.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(list(cfg.eval.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(list(cfg.agent.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(list(cfg.agent.vision_backbone.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(cfg.agent.head_type, 'transformer')
        self.assertEqual(cfg.agent.num_cams, 3)
        self.assertTrue(cfg.agent.head.obs_as_cond)
        self.assertFalse(cfg.agent.head.causal_attn)

        with _stub_optional_modules():
            agent = instantiate(cfg.agent)
            # Conditioning dim = per-camera features * num cameras + proprio dim.
            expected_cond_dim = agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim
            self.assertEqual(cfg.agent.head.cond_dim, expected_cond_dim)
            self.assertEqual(agent.per_step_cond_dim, expected_cond_dim)
            self.assertEqual(agent.noise_pred_net.cond_obs_emb.in_features, expected_cond_dim)

            batch_size = 2
            image_shape = tuple(cfg.agent.vision_backbone.input_shape)
            # 'left_wrist' is an extra key the agent is expected to ignore.
            images = _make_images(
                batch_size,
                cfg.agent.obs_horizon,
                image_shape,
                per_camera_fill={
                    'front': 30.0,
                    'top': 20.0,
                    'r_vis': 10.0,
                    'left_wrist': 99.0,
                },
            )
            proprioception = torch.randn(batch_size, cfg.agent.obs_horizon, cfg.agent.obs_dim)
            # Mean-marker encoders + capturing head expose the camera ordering.
            _patch_backbone_for_order_tracking(agent.vision_encoder)
            capturing_head = _CondCapturingHead()
            agent.noise_pred_net = capturing_head
            predicted_actions = agent.predict_action(images, proprioception)
            self.assertEqual(
                predicted_actions.shape,
                (batch_size, cfg.agent.pred_horizon, cfg.agent.action_dim),
            )
            self.assertIsNotNone(capturing_head.last_cond)
            self.assertEqual(capturing_head.last_cond.shape[-1], expected_cond_dim)
            camera_markers = _extract_camera_markers(
                capturing_head.last_cond,
                agent.vision_encoder.output_dim,
                agent.num_cams,
            )
            # Markers must appear in the configured order: r_vis, top, front.
            self.assertTrue(torch.allclose(camera_markers, torch.tensor([10.0, 20.0, 30.0])))

            # Dropping a required camera must raise a descriptive error.
            missing_images = dict(images)
            missing_images.pop('top')
            with self.assertRaisesRegex(ValueError, 'missing=.*top'):
                agent.predict_action(missing_images, proprioception)

    def test_agent_rejects_conflicting_explicit_backbone_camera_names(self):
        """Instantiation fails when the backbone's explicit camera order
        contradicts the agent's configured order."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        # Deliberately reorder the backbone's list to conflict with the agent.
        cfg.agent.vision_backbone.camera_names = ['front', 'top', 'r_vis']

        with _stub_optional_modules():
            with self.assertRaisesRegex(InstantiationException, 'camera_names'):
                instantiate(cfg.agent)

    def test_backbone_uses_sorted_fallback_order_when_camera_names_unset(self):
        """With camera_names unset, the backbone orders cameras by sorted name
        (front, r_vis, top)."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.vision_backbone.camera_names = None

        with _stub_optional_modules():
            backbone = instantiate(cfg.agent.vision_backbone)
            _patch_backbone_for_order_tracking(backbone)
            # Dict insertion order differs from sorted order on purpose.
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'top': 20.0,
                    'front': 30.0,
                    'r_vis': 10.0,
                },
            )
            ordered_features = backbone(images)
            camera_markers = _extract_camera_markers(
                ordered_features,
                backbone.output_dim,
                len(images),
            )
            # Sorted fallback order front, r_vis, top -> markers 30, 10, 20.
            self.assertTrue(torch.allclose(camera_markers, torch.tensor([30.0, 10.0, 20.0])))

    def test_agent_queue_fallback_order_is_deterministic_when_camera_names_unset(self):
        """Observation queues fall back to sorted camera order when both the
        agent and the backbone leave camera_names unset."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None

        with _stub_optional_modules():
            agent = instantiate(cfg.agent)
            observation = {
                'qpos': torch.randn(cfg.agent.obs_dim),
                'images': {
                    'top': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 20.0),
                    'front': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 30.0),
                    'r_vis': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 10.0),
                },
            }
            agent._populate_queues(observation)
            batch = agent._prepare_observation_batch()
            # Sorted fallback order must be deterministic.
            self.assertEqual(list(batch['images'].keys()), ['front', 'r_vis', 'top'])

    def test_backbone_rejects_camera_count_mismatch_when_camera_names_unset(self):
        """With camera_names unset, the backbone still validates the number of
        cameras it receives."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.vision_backbone.camera_names = None

        with _stub_optional_modules():
            backbone = instantiate(cfg.agent.vision_backbone)
            # Only two cameras instead of the expected three.
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'front': 30.0,
                    'r_vis': 10.0,
                },
            )
            with self.assertRaisesRegex(ValueError, 'num_cameras'):
                backbone(images)

    def test_agent_rejects_camera_count_mismatch_when_camera_names_unset(self):
        """With camera_names unset, predict_action validates the number of
        camera images against the agent's configured num_cams."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.inference_steps=1',
                'agent.head.n_layer=1',
                'agent.head.n_cond_layers=0',
                'agent.head.n_emb=32',
                'agent.head.n_head=4',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None

        with _stub_optional_modules():
            agent = instantiate(cfg.agent)
            # Only two cameras instead of the expected three.
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'front': 30.0,
                    'r_vis': 10.0,
                },
            )
            proprioception = torch.randn(1, cfg.agent.obs_horizon, cfg.agent.obs_dim)
            with self.assertRaisesRegex(ValueError, 'num_cams'):
                agent.predict_action(images, proprioception)

    def test_agent_rejects_num_cams_mismatch_with_backbone_when_camera_names_unset(self):
        """Instantiation fails when agent.num_cams disagrees with the
        backbone's num_cameras and no camera names are set to arbitrate."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None
        # Deliberately inconsistent camera counts.
        cfg.agent.num_cams = 2
        cfg.agent.vision_backbone.num_cameras = 3

        with _stub_optional_modules():
            with self.assertRaisesRegex(InstantiationException, 'num_cams'):
                instantiate(cfg.agent)
|
|
|
|
def test_agent_rejects_num_cams_mismatch_with_backbone_when_camera_names_unset(self):
|
|
cfg = _compose_cfg(
|
|
overrides=[
|
|
'agent.vision_backbone.pretrained_backbone_weights=null',
|
|
'agent.vision_backbone.input_shape=[3,16,16]',
|
|
]
|
|
)
|
|
cfg.agent.camera_names = None
|
|
cfg.agent.vision_backbone.camera_names = None
|
|
cfg.agent.num_cams = 2
|
|
cfg.agent.vision_backbone.num_cameras = 3
|
|
|
|
with _stub_optional_modules():
|
|
with self.assertRaisesRegex(InstantiationException, 'num_cams'):
|
|
instantiate(cfg.agent)
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow running this test module directly with the unittest runner.
    unittest.main()
|