feat(vla): align transformer training stack and rollout validation
This commit is contained in:
387
tests/test_resnet_transformer_agent_wiring.py
Normal file
387
tests/test_resnet_transformer_agent_wiring.py
Normal file
@@ -0,0 +1,387 @@
|
||||
import contextlib
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from hydra import compose, initialize_config_dir
|
||||
from hydra.errors import InstantiationException
|
||||
from hydra.core.global_hydra import GlobalHydra
|
||||
from hydra.utils import instantiate
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
# Repository root, resolved relative to this test file (tests/ is one level down).
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Absolute path to the Hydra config directory consumed by _compose_cfg().
_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve())
# Camera ordering the agent configuration is expected to declare everywhere.
_EXPECTED_CAMERA_NAMES = ['r_vis', 'top', 'front']
# Sentinel marking "module was absent from sys.modules" during stubbing.
_MISSING = object()
|
||||
|
||||
|
||||
class _FakeScheduler:
|
||||
def __init__(self, num_train_timesteps=100, **kwargs):
|
||||
self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps)
|
||||
self.timesteps = []
|
||||
|
||||
def add_noise(self, sample, noise, timestep):
|
||||
return sample + noise
|
||||
|
||||
def set_timesteps(self, num_inference_steps):
|
||||
self.timesteps = list(range(num_inference_steps - 1, -1, -1))
|
||||
|
||||
def step(self, noise_pred, timestep, sample):
|
||||
return types.SimpleNamespace(prev_sample=sample)
|
||||
|
||||
|
||||
class _IdentityCrop:
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class _FakeResNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
|
||||
self.relu1 = torch.nn.ReLU()
|
||||
self.conv2 = torch.nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2)
|
||||
self.relu2 = torch.nn.ReLU()
|
||||
self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
|
||||
self.fc = torch.nn.Linear(16, 16)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu1(self.conv1(x))
|
||||
x = self.relu2(self.conv2(x))
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, start_dim=1)
|
||||
return self.fc(x)
|
||||
|
||||
|
||||
class _FakeRearrange(torch.nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class _CondCapturingHead(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.last_cond = None
|
||||
|
||||
def forward(self, sample, timestep, cond):
|
||||
self.last_cond = cond.detach().clone()
|
||||
return torch.zeros_like(sample)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _stub_optional_modules():
    """Temporarily register fake diffusers/torchvision/einops modules.

    The fakes cover only the names the agent code imports.  Any pre-existing
    ``sys.modules`` entries are remembered (``_MISSING`` marks "was not
    installed") and restored on exit in reverse insertion order, so nothing
    leaks into other tests.
    """
    saved = {}

    def inject(name, module):
        # Record the original entry only once, then overwrite it.
        if name not in saved:
            saved[name] = sys.modules.get(name, _MISSING)
        sys.modules[name] = module

    # --- diffusers stub tree ------------------------------------------------
    diffusers_module = types.ModuleType('diffusers')
    schedulers_module = types.ModuleType('diffusers.schedulers')
    ddpm_module = types.ModuleType('diffusers.schedulers.scheduling_ddpm')
    ddim_module = types.ModuleType('diffusers.schedulers.scheduling_ddim')
    ddpm_module.DDPMScheduler = _FakeScheduler
    ddim_module.DDIMScheduler = _FakeScheduler
    diffusers_module.DDPMScheduler = _FakeScheduler
    diffusers_module.DDIMScheduler = _FakeScheduler
    diffusers_module.schedulers = schedulers_module
    schedulers_module.scheduling_ddpm = ddpm_module
    schedulers_module.scheduling_ddim = ddim_module

    # --- torchvision stub tree ----------------------------------------------
    torchvision_module = types.ModuleType('torchvision')
    models_module = types.ModuleType('torchvision.models')
    transforms_module = types.ModuleType('torchvision.transforms')
    models_module.resnet18 = lambda weights=None: _FakeResNet()
    transforms_module.CenterCrop = _IdentityCrop
    transforms_module.RandomCrop = _IdentityCrop
    torchvision_module.models = models_module
    torchvision_module.transforms = transforms_module

    # --- einops stub tree ---------------------------------------------------
    einops_module = types.ModuleType('einops')
    einops_module.rearrange = lambda x, *args, **kwargs: x
    einops_layers_module = types.ModuleType('einops.layers')
    einops_layers_torch_module = types.ModuleType('einops.layers.torch')
    einops_layers_torch_module.Rearrange = _FakeRearrange
    einops_module.layers = einops_layers_module
    einops_layers_module.torch = einops_layers_torch_module

    # Insertion order matters: parents before submodules, restore reversed.
    stubs = {
        'diffusers': diffusers_module,
        'diffusers.schedulers': schedulers_module,
        'diffusers.schedulers.scheduling_ddpm': ddpm_module,
        'diffusers.schedulers.scheduling_ddim': ddim_module,
        'torchvision': torchvision_module,
        'torchvision.models': models_module,
        'torchvision.transforms': transforms_module,
        'einops': einops_module,
        'einops.layers': einops_layers_module,
        'einops.layers.torch': einops_layers_torch_module,
    }
    try:
        for name, module in stubs.items():
            inject(name, module)
        yield
    finally:
        for name, previous in reversed(list(saved.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
|
||||
|
||||
|
||||
def _compose_cfg(overrides=None):
    """Compose the repo's Hydra config, optionally applying override strings.

    Registers the custom ``len`` resolver on first use and clears Hydra's
    global state so repeated composition within one process stays clean.
    """
    # The config interpolations rely on a ``len`` resolver; register lazily.
    if not OmegaConf.has_resolver('len'):
        OmegaConf.register_new_resolver('len', len)

    # Hydra keeps process-global state; reset it so repeated composes work.
    GlobalHydra.instance().clear()
    override_list = [] if overrides is None else list(overrides)
    with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR):
        return compose(config_name='config', overrides=override_list)
|
||||
|
||||
|
||||
def _make_images(batch_size, obs_horizon, image_shape, per_camera_fill=None):
|
||||
channels, height, width = image_shape
|
||||
per_camera_fill = per_camera_fill or {
|
||||
'front': 30.0,
|
||||
'top': 20.0,
|
||||
'r_vis': 10.0,
|
||||
}
|
||||
return {
|
||||
name: torch.full(
|
||||
(batch_size, obs_horizon, channels, height, width),
|
||||
fill_value=fill_value,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
for name, fill_value in per_camera_fill.items()
|
||||
}
|
||||
|
||||
|
||||
def _patch_backbone_for_order_tracking(backbone):
|
||||
feature_dim = backbone.output_dim
|
||||
|
||||
def encode_mean(image_batch):
|
||||
mean_feature = image_batch.mean(dim=(1, 2, 3)).unsqueeze(-1)
|
||||
return mean_feature.repeat(1, feature_dim)
|
||||
|
||||
if backbone.use_separate_rgb_encoder_per_camera:
|
||||
for encoder in backbone.rgb_encoder:
|
||||
encoder.forward_single_image = encode_mean
|
||||
else:
|
||||
backbone.rgb_encoder.forward_single_image = encode_mean
|
||||
|
||||
|
||||
def _extract_camera_markers(cond, feature_dim, num_cams):
|
||||
camera_block = cond[0, 0, : feature_dim * num_cams].view(num_cams, feature_dim)
|
||||
return camera_block[:, 0]
|
||||
|
||||
|
||||
class ResNetTransformerAgentWiringTest(unittest.TestCase):
    """End-to-end wiring checks for the ResNet + transformer diffusion agent.

    All heavy third-party dependencies (diffusers, torchvision, einops) are
    replaced via ``_stub_optional_modules`` so the tests exercise only the
    project's own Hydra wiring and agent plumbing.
    """

    def test_hydra_wiring_uses_required_three_camera_transformer_conditioning_in_agent_order_and_ignores_extra_keys(self):
        """Config/camera/cond-dim agreement; extra image keys ignored, missing ones raise."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.inference_steps=1',
                'agent.head.n_layer=1',
                'agent.head.n_cond_layers=0',
                'agent.head.n_emb=32',
                'agent.head.n_head=4',
            ]
        )

        # Every camera-name list in the config tree must agree.
        self.assertEqual(list(cfg.data.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(list(cfg.eval.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(list(cfg.agent.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(list(cfg.agent.vision_backbone.camera_names), _EXPECTED_CAMERA_NAMES)
        self.assertEqual(cfg.agent.head_type, 'transformer')
        self.assertEqual(cfg.agent.num_cams, 3)
        self.assertTrue(cfg.agent.head.obs_as_cond)
        self.assertFalse(cfg.agent.head.causal_attn)

        with _stub_optional_modules():
            agent = instantiate(cfg.agent)
            # cond = concat(per-camera visual features) + proprioception.
            expected_cond_dim = agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim
            self.assertEqual(cfg.agent.head.cond_dim, expected_cond_dim)
            self.assertEqual(agent.per_step_cond_dim, expected_cond_dim)
            self.assertEqual(agent.noise_pred_net.cond_obs_emb.in_features, expected_cond_dim)

            batch_size = 2
            image_shape = tuple(cfg.agent.vision_backbone.input_shape)
            # 'left_wrist' is an extra key the agent should silently ignore.
            images = _make_images(
                batch_size,
                cfg.agent.obs_horizon,
                image_shape,
                per_camera_fill={
                    'front': 30.0,
                    'top': 20.0,
                    'r_vis': 10.0,
                    'left_wrist': 99.0,
                },
            )
            proprioception = torch.randn(batch_size, cfg.agent.obs_horizon, cfg.agent.obs_dim)
            _patch_backbone_for_order_tracking(agent.vision_encoder)
            capturing_head = _CondCapturingHead()
            agent.noise_pred_net = capturing_head
            predicted_actions = agent.predict_action(images, proprioception)
            self.assertEqual(
                predicted_actions.shape,
                (batch_size, cfg.agent.pred_horizon, cfg.agent.action_dim),
            )
            self.assertIsNotNone(capturing_head.last_cond)
            self.assertEqual(capturing_head.last_cond.shape[-1], expected_cond_dim)
            camera_markers = _extract_camera_markers(
                capturing_head.last_cond,
                agent.vision_encoder.output_dim,
                agent.num_cams,
            )
            # Markers must come back in agent camera order: r_vis, top, front.
            self.assertTrue(torch.allclose(camera_markers, torch.tensor([10.0, 20.0, 30.0])))

            # Dropping a required camera must raise and name the missing key.
            missing_images = dict(images)
            missing_images.pop('top')
            with self.assertRaisesRegex(ValueError, 'missing=.*top'):
                agent.predict_action(missing_images, proprioception)

    def test_agent_rejects_conflicting_explicit_backbone_camera_names(self):
        """A backbone camera order conflicting with the agent's must fail instantiation."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        # Deliberately reordered relative to the agent's expected order.
        cfg.agent.vision_backbone.camera_names = ['front', 'top', 'r_vis']

        with _stub_optional_modules():
            with self.assertRaisesRegex(InstantiationException, 'camera_names'):
                instantiate(cfg.agent)

    def test_backbone_uses_sorted_fallback_order_when_camera_names_unset(self):
        """With camera_names=None, the backbone falls back to sorted key order."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.vision_backbone.camera_names = None

        with _stub_optional_modules():
            backbone = instantiate(cfg.agent.vision_backbone)
            _patch_backbone_for_order_tracking(backbone)
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'top': 20.0,
                    'front': 30.0,
                    'r_vis': 10.0,
                },
            )
            ordered_features = backbone(images)
            camera_markers = _extract_camera_markers(
                ordered_features,
                backbone.output_dim,
                len(images),
            )
            # Sorted fallback order front, r_vis, top -> markers 30, 10, 20.
            self.assertTrue(torch.allclose(camera_markers, torch.tensor([30.0, 10.0, 20.0])))

    def test_agent_queue_fallback_order_is_deterministic_when_camera_names_unset(self):
        """Observation queues keep a deterministic (sorted) camera order when unset."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None

        with _stub_optional_modules():
            agent = instantiate(cfg.agent)
            observation = {
                'qpos': torch.randn(cfg.agent.obs_dim),
                'images': {
                    'top': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 20.0),
                    'front': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 30.0),
                    'r_vis': torch.full(tuple(cfg.agent.vision_backbone.input_shape), 10.0),
                },
            }
            agent._populate_queues(observation)
            batch = agent._prepare_observation_batch()
            self.assertEqual(list(batch['images'].keys()), ['front', 'r_vis', 'top'])

    def test_backbone_rejects_camera_count_mismatch_when_camera_names_unset(self):
        """Too few cameras must be rejected by the backbone's count check."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.vision_backbone.camera_names = None

        with _stub_optional_modules():
            backbone = instantiate(cfg.agent.vision_backbone)
            # Only two of the expected three cameras are supplied.
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'front': 30.0,
                    'r_vis': 10.0,
                },
            )
            with self.assertRaisesRegex(ValueError, 'num_cameras'):
                backbone(images)

    def test_agent_rejects_camera_count_mismatch_when_camera_names_unset(self):
        """The agent itself must also reject a camera-count mismatch at predict time."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
                'agent.inference_steps=1',
                'agent.head.n_layer=1',
                'agent.head.n_cond_layers=0',
                'agent.head.n_emb=32',
                'agent.head.n_head=4',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None

        with _stub_optional_modules():
            agent = instantiate(cfg.agent)
            # Only two of the expected three cameras are supplied.
            images = _make_images(
                batch_size=1,
                obs_horizon=cfg.agent.obs_horizon,
                image_shape=tuple(cfg.agent.vision_backbone.input_shape),
                per_camera_fill={
                    'front': 30.0,
                    'r_vis': 10.0,
                },
            )
            proprioception = torch.randn(1, cfg.agent.obs_horizon, cfg.agent.obs_dim)
            with self.assertRaisesRegex(ValueError, 'num_cams'):
                agent.predict_action(images, proprioception)

    def test_agent_rejects_num_cams_mismatch_with_backbone_when_camera_names_unset(self):
        """Agent num_cams disagreeing with backbone num_cameras must fail instantiation."""
        cfg = _compose_cfg(
            overrides=[
                'agent.vision_backbone.pretrained_backbone_weights=null',
                'agent.vision_backbone.input_shape=[3,16,16]',
            ]
        )
        cfg.agent.camera_names = None
        cfg.agent.vision_backbone.camera_names = None
        # Deliberately inconsistent counts.
        cfg.agent.num_cams = 2
        cfg.agent.vision_backbone.num_cameras = 3

        with _stub_optional_modules():
            with self.assertRaisesRegex(InstantiationException, 'num_cams'):
                instantiate(cfg.agent)
|
||||
|
||||
|
||||
# Allow running this module directly with `python <file>` in addition to pytest.
if __name__ == '__main__':
    unittest.main()
|
||||
Reference in New Issue
Block a user