"""Unit tests for LEWMViTBackbone: Lightning checkpoint loading and the
fuse-then-resize multi-camera preprocessing path."""

import tempfile
import types
import unittest
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTConfig, ViTModel

_INPUT_CAMERA_NAMES = ("r_vis", "top", "front")
_FUSED_CAMERA_NAMES = ("front", "top", "r_vis")


class _ReferenceProjector(nn.Module):
    """Mirrors the projector architecture expected in the LEWM checkpoint."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(192, 2048),
            nn.BatchNorm1d(2048),
            nn.GELU(),
            nn.Linear(2048, 192),
        )

    def forward(self, x):
        return self.net(x)


def _build_reference_encoder() -> ViTModel:
    return ViTModel(
        ViTConfig(
            image_size=224,
            patch_size=14,
            num_channels=3,
            hidden_size=192,
            intermediate_size=768,
            num_hidden_layers=12,
            num_attention_heads=3,
            qkv_bias=True,
        ),
        add_pooling_layer=False,
    )


def _write_synthetic_lightning_ckpt(path: Path):
    """Save a synthetic Lightning checkpoint whose ``state_dict`` keys carry the
    ``model.encoder.`` / ``model.projector.`` prefixes the backbone must strip."""
    torch.manual_seed(7)
    encoder = _build_reference_encoder()
    projector = _ReferenceProjector()
    lightning_state_dict = {}
    for key, value in encoder.state_dict().items():
        lightning_state_dict[f"model.encoder.{key}"] = value.detach().clone()
    for key, value in projector.state_dict().items():
        lightning_state_dict[f"model.projector.{key}"] = value.detach().clone()
    torch.save({"state_dict": lightning_state_dict}, path)
    return encoder.state_dict(), projector.state_dict()


class LEWMViTBackboneTest(unittest.TestCase):
    def test_loads_lightning_encoder_and_projector_checkpoint_and_emits_joint_embedding(self):
        from roboimi.vla.models.backbones.lewm_vit_backbone import LEWMViTBackbone

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_path = Path(tmpdir) / "synthetic-lewm.ckpt"
            reference_encoder_state, reference_projector_state = _write_synthetic_lightning_ckpt(
                ckpt_path
            )

            backbone = LEWMViTBackbone(
                checkpoint_path=ckpt_path,
                camera_names=_INPUT_CAMERA_NAMES,
                fused_camera_names=_FUSED_CAMERA_NAMES,
                freeze_backbone=True,
            )

            self.assertEqual(backbone.camera_names, _INPUT_CAMERA_NAMES)
            self.assertEqual(backbone.fused_camera_names, _FUSED_CAMERA_NAMES)
            self.assertEqual(backbone.num_cameras, 3)
            self.assertEqual(backbone.joint_output_dim, 192)
            self.assertEqual(backbone.output_dim, 192)
            self.assertEqual(backbone.encoder.config.hidden_size, 192)
            self.assertEqual(backbone.encoder.config.patch_size, 14)
            self.assertEqual(backbone.encoder.config.num_hidden_layers, 12)
            self.assertEqual(backbone.encoder.config.num_attention_heads, 3)

            # Checkpoint weights must be loaded verbatim into encoder and projector.
            for key, value in reference_encoder_state.items():
                self.assertTrue(torch.equal(backbone.encoder.state_dict()[key], value), key)
            for key, value in reference_projector_state.items():
                self.assertTrue(torch.equal(backbone.projector.state_dict()[key], value), key)

            images = {
                cam_name: torch.rand(1, 1, 3, 224, 224) for cam_name in _INPUT_CAMERA_NAMES
            }
            output = backbone(images)
            self.assertEqual(output.shape, (1, 1, 192))
            # A frozen backbone should not propagate gradients to its output.
            self.assertFalse(output.requires_grad)

    def test_forward_uses_front_top_rvis_fusion_order_and_exact_lewm_cwh_resize_path(self):
        from roboimi.vla.models.backbones.lewm_vit_backbone import LEWMViTBackbone

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_path = Path(tmpdir) / "synthetic-lewm.ckpt"
            _write_synthetic_lightning_ckpt(ckpt_path)

            backbone = LEWMViTBackbone(
                checkpoint_path=ckpt_path,
                camera_names=_INPUT_CAMERA_NAMES,
                fused_camera_names=_FUSED_CAMERA_NAMES,
                freeze_backbone=True,
            )

            captured = {}

            # Stub encoder: records the pixel values it receives and returns a
            # deterministic CLS token so the test can focus on preprocessing.
            def fake_encoder_forward(module, pixel_values, interpolate_pos_encoding=False, **kwargs):
                del module, kwargs
                captured["pixel_values"] = pixel_values.detach().clone()
                captured["interpolate_pos_encoding"] = interpolate_pos_encoding
                batch = pixel_values.shape[0]
                patch_tokens = (pixel_values.shape[-2] // 14) * (pixel_values.shape[-1] // 14)
                cls = (
                    torch.arange(192, dtype=pixel_values.dtype, device=pixel_values.device)
                    .unsqueeze(0)
                    .expand(batch, -1)
                )
                last_hidden_state = torch.zeros(
                    batch,
                    patch_tokens + 1,
                    192,
                    dtype=pixel_values.dtype,
                    device=pixel_values.device,
                )
                last_hidden_state[:, 0] = cls
                return types.SimpleNamespace(last_hidden_state=last_hidden_state)

            backbone.encoder.forward = types.MethodType(fake_encoder_forward, backbone.encoder)

            # Distinct constant values per camera make the fused stacking order
            # recoverable from the captured pixels.
            r_vis = torch.full((1, 1, 3, 256, 256), 0.30)
            top = torch.full((1, 1, 3, 256, 256), 0.20)
            front = torch.full((1, 1, 3, 256, 256), 0.10)

            bn = backbone.projector.net[1]
            running_mean_before = bn.running_mean.detach().clone()
            running_var_before = bn.running_var.detach().clone()

            # Frozen submodules must stay in eval mode even after .train().
            backbone.train()
            self.assertFalse(backbone.encoder.training)
            self.assertFalse(backbone.projector.training)

            output = backbone({"r_vis": r_vis, "top": top, "front": front})

            self.assertEqual(output.shape, (1, 1, 192))
            self.assertEqual(captured["pixel_values"].shape, (1, 3, 672, 224))
            self.assertTrue(captured["interpolate_pos_encoding"])

            normalized_views = [
                ((view.reshape(-1, *view.shape[2:]).float()).clamp(0.0, 1.0) - backbone.mean)
                / backbone.std
                for view in (front, top, r_vis)
            ]
            # Reference path: concatenate views along height first, then resize once.
            expected_fuse_then_resize = F.interpolate(
                torch.cat(normalized_views, dim=-2),
                size=(672, 224),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
            # Alternative path: resize each view to 224x224, then concatenate.
            expected_pre_resize_then_fuse = torch.cat(
                [
                    F.interpolate(
                        view,
                        size=(224, 224),
                        mode="bilinear",
                        align_corners=False,
                        antialias=True,
                    )
                    for view in normalized_views
                ],
                dim=-2,
            )

            self.assertTrue(
                torch.allclose(
                    captured["pixel_values"], expected_fuse_then_resize, atol=1e-6, rtol=1e-6
                )
            )
            # The two paths must be distinguishable, and the backbone must take
            # the fuse-then-resize one.
            self.assertFalse(
                torch.allclose(
                    expected_fuse_then_resize,
                    expected_pre_resize_then_fuse,
                    atol=1e-6,
                    rtol=1e-6,
                )
            )
            self.assertFalse(
                torch.allclose(
                    captured["pixel_values"],
                    expected_pre_resize_then_fuse,
                    atol=1e-6,
                    rtol=1e-6,
                )
            )
            # Rows 223 and 447 sit at the seams between the stacked views, where
            # the two resize paths differ most.
            self.assertTrue(
                torch.allclose(
                    captured["pixel_values"][0, :, 223, :],
                    expected_fuse_then_resize[0, :, 223, :],
                    atol=1e-6,
                    rtol=1e-6,
                )
            )
            self.assertTrue(
                torch.allclose(
                    captured["pixel_values"][0, :, 447, :],
                    expected_fuse_then_resize[0, :, 447, :],
                    atol=1e-6,
                    rtol=1e-6,
                )
            )

            # The frozen projector's BatchNorm running statistics must be untouched.
            self.assertTrue(torch.equal(bn.running_mean, running_mean_before))
            self.assertTrue(torch.equal(bn.running_var, running_var_before))


if __name__ == "__main__":
    unittest.main()
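
# ---------------------------------------------------------------------------
# Illustrative appendix, not part of the test suite. A minimal sketch of the
# fuse-then-resize preprocessing contract the second test pins down, assuming
# the normalization attributes (`mean`, `std`) and fused camera order asserted
# above; the real LEWMViTBackbone internals may differ. Views are normalized,
# vertically concatenated in fused order, and resized ONCE to
# (num_views * 224, 224). Resizing each view to 224x224 first and then
# concatenating keeps views independent, so pixels near the seam rows
# (223/447 above) come out differently, which is what the negative allclose
# checks distinguish. `_fuse_then_resize_sketch` is a hypothetical name.
def _fuse_then_resize_sketch(views, mean, std):
    """views: sequence of (B, 3, H, W) tensors already in fused camera order."""
    normalized = [(v.clamp(0.0, 1.0) - mean) / std for v in views]
    fused = torch.cat(normalized, dim=-2)  # stack along the height axis
    # A single antialiased bilinear resize blends pixels across view seams.
    return F.interpolate(
        fused,
        size=(len(views) * 224, 224),
        mode="bilinear",
        align_corners=False,
        antialias=True,
    )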