feat: add vision transformer backbones and IMF variants

This commit is contained in:
Logic
2026-04-09 14:02:24 +08:00
parent d51b3ecafa
commit ff7c9c1f2a
58 changed files with 2788 additions and 26 deletions

View File

@@ -90,6 +90,24 @@ class _FakeRenderer:
class EvalVLAHeadlessTest(unittest.TestCase):
def test_prepare_observation_skips_resize_when_image_resize_shape_is_none(self):
obs = {
"images": {
"front": np.arange(8 * 8 * 3, dtype=np.uint8).reshape(8, 8, 3),
},
"qpos": np.zeros(16, dtype=np.float32),
}
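# cv2.resize is patched to raise, so the test fails loudly if
# prepare_observation attempts a resize despite image_resize_shape=None.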
with mock.patch("cv2.resize", side_effect=AssertionError("resize should be skipped")):
prepared = eval_vla.prepare_observation(
obs,
["front"],
image_resize_shape=None,
)
self.assertEqual(tuple(prepared["images"]["front"].shape), (3, 8, 8))
self.assertEqual(tuple(prepared["qpos"].shape), (16,))
def test_headless_eval_sets_mujoco_gl_to_egl_when_display_missing(self):
cfg = OmegaConf.create({"eval": {"headless": True}})
with mock.patch.dict(eval_vla.os.environ, {}, clear=True):

View File

@@ -1,5 +1,6 @@
import contextlib
import importlib
import importlib.machinery
import sys
import types
import unittest
@@ -69,6 +70,68 @@ class _FakeRearrange(nn.Module):
return x
class _FakeViTConfig:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
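# Minimal stand-in for transformers.ViTModel: ignores pooling and returns a
# zero (batch, 2, hidden_size) last_hidden_state, so backbone wiring can be
# exercised without real weights.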
class _FakeViTModel(nn.Module):
def __init__(self, config, add_pooling_layer=False):
super().__init__()
del add_pooling_layer
self.config = config
hidden_size = int(getattr(config, 'hidden_size', 192))
self.proj = nn.Linear(hidden_size, hidden_size)
def forward(self, pixel_values=None, interpolate_pos_encoding=False, **kwargs):
del interpolate_pos_encoding, kwargs
batch_size = pixel_values.shape[0]
hidden_size = int(getattr(self.config, 'hidden_size', 192))
seq_len = 2
last_hidden_state = torch.zeros(batch_size, seq_len, hidden_size, dtype=pixel_values.dtype, device=pixel_values.device)
return types.SimpleNamespace(last_hidden_state=last_hidden_state)
class _FakeSiglipVisionOutput:
def __init__(self, pooler_output):
self.pooler_output = pooler_output
class _FakeSiglipVisionConfig:
def __init__(self, hidden_size=768, image_size=256):
self.hidden_size = hidden_size
self.image_size = image_size
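# Fake SiglipVisionModel that records from_pretrained and forward calls,
# letting tests assert on the checkpoint name and the normalized pixel
# batches it receives.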
class _FakeSiglipVisionModel(nn.Module):
load_calls = []
def __init__(self, hidden_size=768):
super().__init__()
self.config = _FakeSiglipVisionConfig(hidden_size=hidden_size)
self.scale = nn.Parameter(torch.tensor(1.0))
self.forward_calls = []
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
model = cls()
cls.load_calls.append({
'pretrained_model_name_or_path': pretrained_model_name_or_path,
'args': args,
'kwargs': kwargs,
})
return model
def forward(self, pixel_values=None, **kwargs):
self.forward_calls.append({
'pixel_values': pixel_values.detach().clone(),
'kwargs': dict(kwargs),
})
pooled = pixel_values.mean(dim=(2, 3), keepdim=False) * self.scale
return _FakeSiglipVisionOutput(pooler_output=pooled)
class _StubIMFHead(nn.Module):
def __init__(
self,
@@ -105,6 +168,11 @@ class _StubIMFHead(nn.Module):
def _stub_optional_modules(include_imf_head=False):
previous_modules = {}
def remember_and_remove(name):
if name not in previous_modules:
previous_modules[name] = sys.modules.get(name, _MISSING)
sys.modules.pop(name, None)
def inject(name, module):
if name not in previous_modules:
previous_modules[name] = sys.modules.get(name, _MISSING)
@@ -125,6 +193,9 @@ def _stub_optional_modules(include_imf_head=False):
torchvision_module = types.ModuleType('torchvision')
models_module = types.ModuleType('torchvision.models')
transforms_module = types.ModuleType('torchvision.transforms')
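# Stub modules need a real ModuleSpec: importlib.util.find_spec raises
# ValueError for entries already in sys.modules whose __spec__ is unset.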
torchvision_module.__spec__ = importlib.machinery.ModuleSpec('torchvision', loader=None)
models_module.__spec__ = importlib.machinery.ModuleSpec('torchvision.models', loader=None)
transforms_module.__spec__ = importlib.machinery.ModuleSpec('torchvision.transforms', loader=None)
models_module.resnet18 = lambda weights=None: _FakeResNet()
transforms_module.CenterCrop = _IdentityCrop
transforms_module.RandomCrop = _IdentityCrop
@@ -139,7 +210,14 @@ def _stub_optional_modules(include_imf_head=False):
einops_module.layers = einops_layers_module
einops_layers_module.torch = einops_layers_torch_module
transformers_module = types.ModuleType('transformers')
transformers_module.__spec__ = importlib.machinery.ModuleSpec('transformers', loader=None)
transformers_module.ViTConfig = _FakeViTConfig
transformers_module.ViTModel = _FakeViTModel
transformers_module.SiglipVisionModel = _FakeSiglipVisionModel
try:
remember_and_remove('roboimi.vla.models.backbones.siglip2_diffusion_backbone')
inject('diffusers', diffusers_module)
inject('diffusers.schedulers', schedulers_module)
inject('diffusers.schedulers.scheduling_ddpm', ddpm_module)
@@ -150,6 +228,7 @@ def _stub_optional_modules(include_imf_head=False):
inject('einops', einops_module)
inject('einops.layers', einops_layers_module)
inject('einops.layers.torch', einops_layers_torch_module)
inject('transformers', transformers_module)
if include_imf_head:
import roboimi.vla.models.heads as heads_package
@@ -200,6 +279,67 @@ class _StubVisionBackbone(nn.Module):
return torch.cat(per_camera_features, dim=-1)
class _StubJointVisionBackbone(nn.Module):
joint_output_dim = 5
output_dim = 5
def __init__(self, camera_names=_CAMERA_NAMES):
super().__init__()
self.camera_names = tuple(camera_names)
self.num_cameras = len(self.camera_names)
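# Emits five features per step: per-camera means in fused (front, top, r_vis)
# order, their front+top sum, and (r_vis - front) shifted by the step index.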
def forward(self, images):
batch_size, obs_horizon = next(iter(images.values())).shape[:2]
features = []
for camera_name in ('front', 'top', 'r_vis'):
image_batch = images[camera_name]
features.append(image_batch.mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1))
joint_features = torch.cat(features, dim=-1)
front_top_sum = joint_features[..., :2].sum(dim=-1, keepdim=True)
r_vis_minus_front = (joint_features[..., 2:] - joint_features[..., :1])
time_marker = torch.arange(obs_horizon, dtype=joint_features.dtype).view(1, obs_horizon, 1)
time_marker = time_marker.expand(batch_size, -1, -1)
return torch.cat([joint_features, front_top_sum, r_vis_minus_front + time_marker], dim=-1)
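# Emits (batch, obs_horizon, num_cameras, 2) tokens: each camera contributes
# [mean, mean + 0.5] per step, in agent camera order.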
class _StubMultiTokenVisionBackbone(nn.Module):
output_dim = 2
tokens_per_step = 3
def __init__(self, camera_names=_CAMERA_NAMES):
super().__init__()
self.camera_names = tuple(camera_names)
self.num_cameras = len(self.camera_names)
def forward(self, images):
per_camera = []
for camera_name in self.camera_names:
image_batch = images[camera_name]
base = image_batch.mean(dim=(2, 3, 4), keepdim=False)
per_camera.append(torch.stack([base, base + 0.5], dim=-1))
return torch.stack(per_camera, dim=2)
class _RecordingLinearIMFHead(nn.Module):
def __init__(self):
super().__init__()
@@ -390,6 +530,178 @@ class IMFVLAAgentTest(unittest.TestCase):
self.assertTrue(torch.equal(third_action, second_chunk[0, 1]))
self.assertEqual(mock_predict_chunk.call_count, 2)
def test_joint_visual_backbone_uses_joint_output_dim_for_conditioning(self):
agent_cls, _agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
vision_backbone = _StubJointVisionBackbone()
agent = agent_cls(
vision_backbone=vision_backbone,
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
)
self.assertEqual(agent.per_step_cond_dim, vision_backbone.joint_output_dim + agent.obs_dim)
self.assertEqual(
agent.global_cond_dim,
vision_backbone.joint_output_dim * agent.obs_horizon + agent.obs_dim * agent.obs_horizon,
)
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
)
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
initial_noise = torch.tensor(
[[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
dtype=torch.float32,
)
with mock.patch.object(torch, 'randn', return_value=initial_noise):
predicted_actions = agent.predict_action(images, qpos)
self.assertEqual(predicted_actions.shape, (1, 3, 2))
self.assertEqual(len(head.calls), 1)
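# Per step the stub yields [front, top, r_vis, front+top, r_vis-front+t] and
# the agent appends qpos, so step 0 (t=0) ends in 1.0 and step 1 (t=1) in 2.0.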
expected_cond = torch.tensor(
[[[30.0, 20.0, 10.0, 50.0, -20.0, 1.0], [30.0, 20.0, 10.0, 50.0, -19.0, 2.0]]],
dtype=torch.float32,
)
self.assertEqual(head.calls[0]['cond'].shape[-1], 6)
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
def test_multitoken_visual_backbone_flattens_camera_tokens_and_projects_each_with_state(self):
agent_cls, _agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
projector = nn.Linear(3, 4, bias=False)
with torch.no_grad():
projector.weight.copy_(
torch.tensor(
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
[1.0, 0.0, 1.0],
],
dtype=torch.float32,
)
)
agent = agent_cls(
vision_backbone=_StubMultiTokenVisionBackbone(),
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
cond_projector=projector,
)
self.assertEqual(agent.condition_tokens_per_step, 3)
self.assertEqual(agent.condition_sequence_length, 6)
self.assertEqual(agent.per_step_cond_dim, 4)
self.assertEqual(agent.global_cond_dim, 24)
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
)
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
cond = agent._build_cond(images, qpos)
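# Each token is projector([camera_mean, camera_mean + 0.5, qpos]); the weight
# above passes the inputs through and appends camera_mean + qpos as a fourth
# feature.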
expected = torch.tensor(
[
[
[10.0, 10.5, 1.0, 11.0],
[20.0, 20.5, 1.0, 21.0],
[30.0, 30.5, 1.0, 31.0],
[10.0, 10.5, 2.0, 12.0],
[20.0, 20.5, 2.0, 22.0],
[30.0, 30.5, 2.0, 32.0],
]
],
dtype=torch.float32,
)
self.assertEqual(cond.shape, (1, 6, 4))
self.assertTrue(torch.allclose(cond, expected))
def test_multi_token_visual_backbone_pairs_state_per_camera_and_flattens_condition_sequence(self):
agent_cls, agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
cond_projector = nn.Linear(3, 4, bias=False)
with torch.no_grad():
cond_projector.weight.copy_(torch.tensor([
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
[1.0, 0.0, 1.0],
], dtype=torch.float32))
agent = agent_cls(
vision_backbone=_StubMultiTokenVisionBackbone(),
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
cond_projector=cond_projector,
)
agent.infer_scheduler = _ForbiddenScheduler()
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
)
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
initial_noise = torch.tensor([[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]], dtype=torch.float32)
with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
predicted_actions = agent.predict_action(images, qpos)
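# Same token layout as the flatten test above: the projector appends
# camera_mean + qpos as the fourth feature of each camera token.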
expected_cond = torch.tensor([[[10.0, 10.5, 1.0, 11.0],
[20.0, 20.5, 1.0, 21.0],
[30.0, 30.5, 1.0, 31.0],
[10.0, 10.5, 2.0, 12.0],
[20.0, 20.5, 2.0, 22.0],
[30.0, 30.5, 2.0, 32.0]]], dtype=torch.float32)
self.assertEqual(agent.condition_tokens_per_step, 3)
self.assertEqual(agent.condition_sequence_length, 6)
self.assertEqual(agent.raw_per_step_cond_dim, 3)
self.assertEqual(agent.per_step_cond_dim, 4)
self.assertEqual(agent.global_cond_dim, 24)
self.assertEqual(predicted_actions.shape, (1, 3, 2))
self.assertEqual(len(head.calls), 1)
self.assertEqual(head.calls[0]['cond'].shape, (1, 6, 4))
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
def test_hydra_config_instantiates_resnet_imf_attnres_with_stub_head(self):
cfg = _compose_cfg(
overrides=[
@@ -448,6 +760,130 @@ class IMFVLAAgentTest(unittest.TestCase):
self.assertEqual(agent.per_step_cond_dim, 64 * agent.num_cams + agent.obs_dim)
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
def test_hydra_config_instantiates_lewm_imf_attnres_with_joint_visual_condition_dim(self):
cfg = _compose_cfg(
overrides=[
'agent=lewm_imf_attnres',
'agent.vision_backbone.checkpoint_path=null',
'agent.head.n_layer=1',
'agent.head.n_emb=16',
]
)
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
self.assertEqual(cfg.agent.vision_backbone._target_, 'roboimi.vla.models.backbones.lewm_vit_backbone.LEWMViTBackbone')
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
self.assertEqual(list(cfg.agent.vision_backbone.camera_names), list(_CAMERA_NAMES))
self.assertEqual(list(cfg.agent.vision_backbone.fused_camera_names), ['front', 'top', 'r_vis'])
self.assertIsNone(cfg.agent.vision_backbone.dataset_image_resize_shape)
self.assertEqual(list(cfg.agent.vision_backbone.eval_image_resize_shape), [256, 256])
self.assertEqual(cfg.agent.head.cond_dim, 208)
with _stub_optional_modules(include_imf_head=True):
agent = instantiate(cfg.agent)
self.assertEqual(agent.per_step_cond_dim, agent.vision_encoder.joint_output_dim + agent.obs_dim)
self.assertEqual(agent.per_step_cond_dim, 208)
self.assertEqual(agent.global_cond_dim, agent.obs_horizon * 208)
self.assertIsNone(agent.vision_encoder.dataset_image_resize_shape)
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 208)
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_projected_camera_tokens(self):
cfg = _compose_cfg(
overrides=[
'agent=resnet_imf_attnres_multitoken',
'agent.vision_backbone.pretrained_backbone_weights=null',
'agent.vision_backbone.input_shape=[3,16,16]',
'agent.head.n_layer=1',
'agent.head.n_emb=32',
]
)
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
self.assertEqual(cfg.agent.vision_backbone.vision_backbone_mode, 'resnet')
self.assertTrue(cfg.agent.vision_backbone.use_separate_rgb_encoder_per_camera)
self.assertTrue(cfg.agent.vision_backbone.output_tokens_per_camera)
self.assertEqual(cfg.agent.cond_projector.output_dim, 32)
self.assertEqual(cfg.agent.head.cond_dim, 32)
with _stub_optional_modules(include_imf_head=True):
agent = instantiate(cfg.agent)
self.assertEqual(agent.condition_tokens_per_step, 3)
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon * 3)
self.assertEqual(agent.per_step_cond_dim, 32)
self.assertEqual(agent.global_cond_dim, agent.condition_sequence_length * 32)
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 32)
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], 6)
def test_hydra_config_instantiates_siglip2_imf_attnres_with_condition_projection(self):
cfg = _compose_cfg(
overrides=[
'agent=siglip2_imf_attnres',
'agent.vision_backbone.per_view_output_dim=96',
'agent.head.n_layer=1',
'agent.head.n_emb=16',
'agent.cond_projector.output_dim=384',
]
)
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
self.assertEqual(
cfg.agent.vision_backbone._target_,
'roboimi.vla.models.backbones.siglip2_diffusion_backbone.SigLIP2DiffusionBackbone',
)
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
self.assertIsNone(cfg.agent.vision_backbone.dataset_image_resize_shape)
self.assertEqual(list(cfg.agent.vision_backbone.eval_image_resize_shape), [256, 256])
self.assertEqual(cfg.agent.head.cond_dim, 384)
with _stub_optional_modules(include_imf_head=True):
agent = instantiate(cfg.agent)
self.assertEqual(agent.raw_per_step_cond_dim, 3 * 96 + agent.obs_dim)
self.assertEqual(agent.per_step_cond_dim, 384)
self.assertEqual(agent.global_cond_dim, agent.obs_horizon * 384)
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 384)
self.assertEqual(agent.vision_encoder.output_dim, 96)
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
cfg = _compose_cfg(
overrides=[
'agent=resnet_imf_attnres_multitoken',
'agent.vision_backbone.pretrained_backbone_weights=null',
'agent.vision_backbone.input_shape=[3,16,16]',
'agent.vision_backbone.freeze_backbone=false',
'agent.head.n_layer=1',
'agent.head.n_emb=16',
]
)
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
self.assertTrue(cfg.agent.vision_backbone.use_separate_rgb_encoder_per_camera)
self.assertTrue(cfg.agent.vision_backbone.output_tokens_per_camera)
self.assertEqual(cfg.agent.vision_backbone.vision_backbone_mode, 'resnet')
self.assertEqual(cfg.agent.cond_projector.output_dim, 16)
self.assertEqual(cfg.agent.head.cond_dim, 16)
with _stub_optional_modules(include_imf_head=True):
agent = instantiate(cfg.agent)
self.assertEqual(agent.condition_tokens_per_step, 3)
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon * 3)
self.assertEqual(agent.per_step_cond_dim, 16)
self.assertEqual(agent.global_cond_dim, agent.condition_sequence_length * 16)
self.assertEqual(agent.vision_encoder.tokens_per_step, 3)
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 16)
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.condition_sequence_length)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,220 @@
import tempfile
import types
import unittest
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTConfig, ViTModel
_INPUT_CAMERA_NAMES = ("r_vis", "top", "front")
_FUSED_CAMERA_NAMES = ("front", "top", "r_vis")
class _ReferenceProjector(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(192, 2048),
nn.BatchNorm1d(2048),
nn.GELU(),
nn.Linear(2048, 192),
)
def forward(self, x):
return self.net(x)
def _build_reference_encoder() -> ViTModel:
return ViTModel(
ViTConfig(
image_size=224,
patch_size=14,
num_channels=3,
hidden_size=192,
intermediate_size=768,
num_hidden_layers=12,
num_attention_heads=3,
qkv_bias=True,
),
add_pooling_layer=False,
)
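# Mimics a PyTorch Lightning checkpoint: every weight sits under "state_dict"
# with "model.encoder." / "model.projector." prefixes the backbone must strip.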
def _write_synthetic_lightning_ckpt(path: Path):
torch.manual_seed(7)
encoder = _build_reference_encoder()
projector = _ReferenceProjector()
lightning_state_dict = {}
for key, value in encoder.state_dict().items():
lightning_state_dict[f"model.encoder.{key}"] = value.detach().clone()
for key, value in projector.state_dict().items():
lightning_state_dict[f"model.projector.{key}"] = value.detach().clone()
torch.save({"state_dict": lightning_state_dict}, path)
return encoder.state_dict(), projector.state_dict()
class LEWMViTBackboneTest(unittest.TestCase):
def test_loads_lightning_encoder_and_projector_checkpoint_and_emits_joint_embedding(self):
from roboimi.vla.models.backbones.lewm_vit_backbone import LEWMViTBackbone
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_path = Path(tmpdir) / "synthetic-lewm.ckpt"
reference_encoder_state, reference_projector_state = _write_synthetic_lightning_ckpt(
ckpt_path
)
backbone = LEWMViTBackbone(
checkpoint_path=ckpt_path,
camera_names=_INPUT_CAMERA_NAMES,
fused_camera_names=_FUSED_CAMERA_NAMES,
freeze_backbone=True,
)
self.assertEqual(backbone.camera_names, _INPUT_CAMERA_NAMES)
self.assertEqual(backbone.fused_camera_names, _FUSED_CAMERA_NAMES)
self.assertEqual(backbone.num_cameras, 3)
self.assertEqual(backbone.joint_output_dim, 192)
self.assertEqual(backbone.output_dim, 192)
self.assertEqual(backbone.encoder.config.hidden_size, 192)
self.assertEqual(backbone.encoder.config.patch_size, 14)
self.assertEqual(backbone.encoder.config.num_hidden_layers, 12)
self.assertEqual(backbone.encoder.config.num_attention_heads, 3)
for key, value in reference_encoder_state.items():
self.assertTrue(torch.equal(backbone.encoder.state_dict()[key], value), key)
for key, value in reference_projector_state.items():
self.assertTrue(torch.equal(backbone.projector.state_dict()[key], value), key)
images = {
cam_name: torch.rand(1, 1, 3, 224, 224)
for cam_name in _INPUT_CAMERA_NAMES
}
output = backbone(images)
self.assertEqual(output.shape, (1, 1, 192))
self.assertFalse(output.requires_grad)
def test_forward_uses_front_top_rvis_fusion_order_and_exact_lewm_cwh_resize_path(self):
from roboimi.vla.models.backbones.lewm_vit_backbone import LEWMViTBackbone
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_path = Path(tmpdir) / "synthetic-lewm.ckpt"
_write_synthetic_lightning_ckpt(ckpt_path)
backbone = LEWMViTBackbone(
checkpoint_path=ckpt_path,
camera_names=_INPUT_CAMERA_NAMES,
fused_camera_names=_FUSED_CAMERA_NAMES,
freeze_backbone=True,
)
captured = {}
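# Replace the encoder forward so the test can capture the exact pixel tensor
# and interpolate_pos_encoding flag the backbone passes in.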
def fake_encoder_forward(module, pixel_values, interpolate_pos_encoding=False, **kwargs):
del module, kwargs
captured["pixel_values"] = pixel_values.detach().clone()
captured["interpolate_pos_encoding"] = interpolate_pos_encoding
batch = pixel_values.shape[0]
patch_tokens = (pixel_values.shape[-2] // 14) * (pixel_values.shape[-1] // 14)
cls = (
torch.arange(192, dtype=pixel_values.dtype, device=pixel_values.device)
.unsqueeze(0)
.expand(batch, -1)
)
last_hidden_state = torch.zeros(
batch,
patch_tokens + 1,
192,
dtype=pixel_values.dtype,
device=pixel_values.device,
)
last_hidden_state[:, 0] = cls
return types.SimpleNamespace(last_hidden_state=last_hidden_state)
backbone.encoder.forward = types.MethodType(fake_encoder_forward, backbone.encoder)
r_vis = torch.full((1, 1, 3, 256, 256), 0.30)
top = torch.full((1, 1, 3, 256, 256), 0.20)
front = torch.full((1, 1, 3, 256, 256), 0.10)
bn = backbone.projector.net[1]
running_mean_before = bn.running_mean.detach().clone()
running_var_before = bn.running_var.detach().clone()
backbone.train()
self.assertFalse(backbone.encoder.training)
self.assertFalse(backbone.projector.training)
output = backbone({"r_vis": r_vis, "top": top, "front": front})
self.assertEqual(output.shape, (1, 1, 192))
self.assertEqual(captured["pixel_values"].shape, (1, 3, 672, 224))
self.assertTrue(captured["interpolate_pos_encoding"])
normalized_views = [
((view.reshape(-1, *view.shape[2:]).float()).clamp(0.0, 1.0) - backbone.mean) / backbone.std
for view in (front, top, r_vis)
]
expected_fuse_then_resize = F.interpolate(
torch.cat(normalized_views, dim=-2),
size=(672, 224),
mode="bilinear",
align_corners=False,
antialias=True,
)
expected_pre_resize_then_fuse = torch.cat(
[
F.interpolate(
view,
size=(224, 224),
mode="bilinear",
align_corners=False,
antialias=True,
)
for view in normalized_views
],
dim=-2,
)
self.assertTrue(
torch.allclose(captured["pixel_values"], expected_fuse_then_resize, atol=1e-6, rtol=1e-6)
)
self.assertFalse(
torch.allclose(
expected_fuse_then_resize,
expected_pre_resize_then_fuse,
atol=1e-6,
rtol=1e-6,
)
)
self.assertFalse(
torch.allclose(
captured["pixel_values"],
expected_pre_resize_then_fuse,
atol=1e-6,
rtol=1e-6,
)
)
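# Rows 223 and 447 are the seams between the three stacked 224-px views;
# after an antialiased resize of the fused image they blend neighbouring
# views, which distinguishes fuse-then-resize from resize-then-fuse.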
self.assertTrue(
torch.allclose(
captured["pixel_values"][0, :, 223, :],
expected_fuse_then_resize[0, :, 223, :],
atol=1e-6,
rtol=1e-6,
)
)
self.assertTrue(
torch.allclose(
captured["pixel_values"][0, :, 447, :],
expected_fuse_then_resize[0, :, 447, :],
atol=1e-6,
rtol=1e-6,
)
)
self.assertTrue(torch.equal(bn.running_mean, running_mean_before))
self.assertTrue(torch.equal(bn.running_var, running_var_before))
if __name__ == "__main__":
unittest.main()

View File

@@ -180,6 +180,14 @@ def _extract_camera_markers(cond, feature_dim, num_cams):
return camera_block[:, 0]
def _extract_token_camera_markers(tokens):
return tokens[0, 0, :, 0]
class ResNetTransformerAgentWiringTest(unittest.TestCase):
def test_hydra_wiring_uses_required_three_camera_transformer_conditioning_in_agent_order_and_ignores_extra_keys(self):
cfg = _compose_cfg(
@@ -246,6 +254,36 @@ class ResNetTransformerAgentWiringTest(unittest.TestCase):
with self.assertRaisesRegex(ValueError, 'missing=.*top'):
agent.predict_action(missing_images, proprioception)
def test_multitoken_resnet_backbone_emits_one_token_per_camera_in_agent_order(self):
cfg = _compose_cfg(
overrides=[
'agent=resnet_imf_attnres_multitoken',
'agent.vision_backbone.pretrained_backbone_weights=null',
'agent.vision_backbone.input_shape=[3,16,16]',
]
)
with _stub_optional_modules():
backbone = instantiate(cfg.agent.vision_backbone)
_patch_backbone_for_order_tracking(backbone)
images = _make_images(
batch_size=1,
obs_horizon=cfg.agent.obs_horizon,
image_shape=tuple(cfg.agent.vision_backbone.input_shape),
per_camera_fill={
'front': 30.0,
'top': 20.0,
'r_vis': 10.0,
'left_wrist': 99.0,
},
)
tokens = backbone(images)
self.assertEqual(tokens.shape, (1, cfg.agent.obs_horizon, 3, backbone.output_dim))
self.assertEqual(backbone.tokens_per_step, 3)
camera_markers = _extract_token_camera_markers(tokens)
self.assertTrue(torch.allclose(camera_markers, torch.tensor([10.0, 20.0, 30.0])))
def test_agent_rejects_conflicting_explicit_backbone_camera_names(self):
cfg = _compose_cfg(
overrides=[
@@ -382,6 +420,36 @@ class ResNetTransformerAgentWiringTest(unittest.TestCase):
with self.assertRaisesRegex(InstantiationException, 'num_cams'):
instantiate(cfg.agent)
def test_multitoken_resnet_backbone_emits_one_token_per_camera_with_stub_head_overrides(self):
cfg = _compose_cfg(
overrides=[
'agent=resnet_imf_attnres_multitoken',
'agent.vision_backbone.pretrained_backbone_weights=null',
'agent.vision_backbone.input_shape=[3,16,16]',
'agent.head.n_layer=1',
'agent.head.n_emb=32',
]
)
with _stub_optional_modules():
backbone = instantiate(cfg.agent.vision_backbone)
_patch_backbone_for_order_tracking(backbone)
images = _make_images(
batch_size=1,
obs_horizon=cfg.agent.obs_horizon,
image_shape=tuple(cfg.agent.vision_backbone.input_shape),
per_camera_fill={
'front': 30.0,
'top': 20.0,
'r_vis': 10.0,
},
)
output = backbone(images)
self.assertEqual(output.shape, (1, cfg.agent.obs_horizon, 3, backbone.output_dim))
token_markers = _extract_token_camera_markers(output)
self.assertTrue(torch.allclose(token_markers, torch.tensor([10.0, 20.0, 30.0])))
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,121 @@
import types
import unittest
from unittest import mock
import torch
from torch import nn
_CAMERA_NAMES = ("r_vis", "top", "front")
class _FakeSiglipVisionOutput:
def __init__(self, pooler_output):
self.pooler_output = pooler_output
class _FakeSiglipVisionConfig:
def __init__(self, hidden_size=768, image_size=256):
self.hidden_size = hidden_size
self.image_size = image_size
class _FakeSiglipVisionModel(nn.Module):
def __init__(self, hidden_size=768):
super().__init__()
self.config = _FakeSiglipVisionConfig(hidden_size=hidden_size)
self.forward_calls = []
@classmethod
def from_pretrained(cls, *args, **kwargs):
del args, kwargs
return cls()
def forward(self, pixel_values=None, **kwargs):
self.forward_calls.append({
"pixel_values": pixel_values.detach().clone(),
"kwargs": dict(kwargs),
})
pooled = pixel_values.mean(dim=(2, 3), keepdim=False)
return _FakeSiglipVisionOutput(pooler_output=pooled)
class SigLIP2DiffusionBackboneTest(unittest.TestCase):
def test_forward_encodes_each_view_independently_and_concatenates_projected_features(self):
from roboimi.vla.models.backbones.siglip2_diffusion_backbone import SigLIP2DiffusionBackbone
fake_model = _FakeSiglipVisionModel(hidden_size=3)
with mock.patch(
"roboimi.vla.models.backbones.siglip2_diffusion_backbone.SiglipVisionModel.from_pretrained",
return_value=fake_model,
) as mock_from_pretrained:
backbone = SigLIP2DiffusionBackbone(
model_name="google/siglip2-base-patch16-256",
camera_names=_CAMERA_NAMES,
num_cameras=3,
per_view_output_dim=2,
freeze_backbone=True,
)
self.assertEqual(backbone.camera_names, _CAMERA_NAMES)
self.assertEqual(backbone.num_cameras, 3)
self.assertEqual(backbone.output_dim, 2)
self.assertEqual(backbone.joint_output_dim, 6)
self.assertIsNone(backbone.dataset_image_resize_shape)
self.assertEqual(backbone.eval_image_resize_shape, (256, 256))
mock_from_pretrained.assert_called_once_with("google/siglip2-base-patch16-256")
self.assertTrue(all(not p.requires_grad for p in backbone.encoder.parameters()))
self.assertFalse(backbone.encoder.training)
with torch.no_grad():
backbone.view_projector.weight.zero_()
backbone.view_projector.bias.zero_()
backbone.view_projector.weight[0, 0] = 1.0
backbone.view_projector.weight[1, 1] = 1.0
images = {
"r_vis": torch.full((1, 2, 3, 256, 256), 0.25),
"top": torch.full((1, 2, 3, 256, 256), 0.50),
"front": torch.full((1, 2, 3, 256, 256), 0.75),
}
output = backbone(images)
self.assertEqual(output.shape, (1, 2, 6))
self.assertEqual(len(fake_model.forward_calls), 3)
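# The backbone maps [0, 1] pixels to [-1, 1] via (x - 0.5) / 0.5 (SigLIP-style
# normalization); the zeroed projector keeps only the first two pooled
# channels per view.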
expected_per_camera = []
for cam_name in _CAMERA_NAMES:
img = images[cam_name].reshape(2, 3, 256, 256)
normalized = (img - 0.5) / 0.5
expected_per_camera.append(normalized.mean(dim=(2, 3))[:, :2])
expected = torch.cat(expected_per_camera, dim=-1).view(1, 2, 6)
self.assertTrue(torch.allclose(output, expected, atol=1e-6, rtol=1e-6))
for call, cam_name in zip(fake_model.forward_calls, _CAMERA_NAMES):
pixels = call["pixel_values"]
self.assertEqual(tuple(pixels.shape), (2, 3, 256, 256))
self.assertTrue(
torch.allclose(
pixels,
(images[cam_name].reshape(2, 3, 256, 256) - 0.5) / 0.5,
)
)
def test_forward_rejects_missing_required_camera(self):
from roboimi.vla.models.backbones.siglip2_diffusion_backbone import SigLIP2DiffusionBackbone
backbone = SigLIP2DiffusionBackbone(
vision_model=_FakeSiglipVisionModel(hidden_size=4),
camera_names=_CAMERA_NAMES,
num_cameras=3,
)
with self.assertRaisesRegex(ValueError, "missing"):
backbone({
"r_vis": torch.rand(1, 1, 3, 256, 256),
"top": torch.rand(1, 1, 3, 256, 256),
})
if __name__ == "__main__":
unittest.main()

View File

@@ -56,3 +56,26 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
self.assertEqual(len(resize_calls), 2)
self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
def test_getitem_skips_resize_when_image_resize_shape_is_none(self):
with tempfile.TemporaryDirectory() as tmpdir:
dataset_dir = Path(tmpdir)
self._write_episode(dataset_dir)
dataset = SimpleRobotDataset(
dataset_dir,
obs_horizon=2,
pred_horizon=3,
camera_names=["front"],
image_resize_shape=None,
)
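# Install a fake cv2 whose resize raises, so any resize attempt fails loudly.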
fake_cv2 = types.SimpleNamespace(
INTER_LINEAR=1,
resize=mock.Mock(side_effect=AssertionError("resize should be skipped when image_resize_shape=None")),
)
with mock.patch.dict(sys.modules, {"cv2": fake_cv2}):
sample = dataset[1]
fake_cv2.resize.assert_not_called()
self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))

View File

@@ -159,6 +159,92 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
self.assertGreater(cfg.train.num_workers, 8)
self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)
def test_training_passes_backbone_image_resize_override_to_dataset_instantiation(self):
cfg = OmegaConf.create(
{
'agent': {
'vision_backbone': {
'dataset_image_resize_shape': None,
},
'normalization_type': 'min_max',
},
'data': {
'dataset_dir': 'unused',
'camera_names': ['front'],
},
'train': {
'batch_size': 2,
'lr': 1e-4,
'max_steps': 0,
'device': 'cpu',
'disable_cudnn': False,
'num_workers': 0,
'val_split': 0.0,
'seed': 42,
'log_freq': 1,
'save_freq': 10,
'use_swanlab': False,
'rollout_val_freq_epochs': 0,
'rollout_validate_on_checkpoint': False,
'rollout_num_episodes': 1,
'warmup_steps': 1,
'scheduler_type': 'constant',
'min_lr': 1e-6,
'weight_decay': 1e-5,
'grad_clip': 1.0,
'pretrained_ckpt': None,
},
'eval': {
'ckpt_path': 'unused.pt',
'num_episodes': 1,
'headless': True,
'device': 'cpu',
'verbose_action': False,
},
'experiment': {},
}
)
captured_dataset_kwargs = {}
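# Capture the kwargs passed to instantiate for the dataset node so the test
# can assert the backbone's dataset_image_resize_shape override is forwarded.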
def fake_instantiate(config_node, **kwargs):
if config_node is cfg.data:
captured_dataset_kwargs.update(kwargs)
return _FakeDataset()
if config_node is cfg.agent:
return _FakeAgent()
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
def fake_dataloader(_dataset, *, shuffle, **_kwargs):
del shuffle, _kwargs
return _FakeLoader(
{
'observation.front': torch.zeros(1, 3, 2, 2),
'observation.state': torch.zeros(1, 4),
'action': torch.zeros(1, 2),
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
},
length=1,
)
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
mock.patch.object(train_vla, '_init_swanlab', return_value=None), \
mock.patch.object(train_vla, '_finish_swanlab', return_value=None), \
mock.patch.object(train_vla.torch, 'save', return_value=None):
train_vla._run_training(cfg)
finally:
os.chdir(previous_cwd)
self.assertIn('image_resize_shape', captured_dataset_kwargs)
self.assertIsNone(captured_dataset_kwargs['image_resize_shape'])
def test_eval_main_delegates_to_plain_run_eval_helper(self):
cfg = OmegaConf.create(
{