feat: add vision transformer backbones and IMF variants
tests/test_lewm_vit_backbone.py (new file, 220 lines)
@@ -0,0 +1,220 @@
import tempfile
import types
import unittest
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTConfig, ViTModel


_INPUT_CAMERA_NAMES = ("r_vis", "top", "front")
_FUSED_CAMERA_NAMES = ("front", "top", "r_vis")


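# NOTE: assumed layout of the LEWM pretraining projection head; the
# 192 -> 2048 -> 192 MLP with BatchNorm1d has to match the weights stored
# under the "model.projector." prefix in the synthetic checkpoint below.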
class _ReferenceProjector(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(192, 2048),
            nn.BatchNorm1d(2048),
            nn.GELU(),
            nn.Linear(2048, 192),
        )

    def forward(self, x):
        return self.net(x)


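# ViT-Tiny-scale encoder (hidden_size 192, 3 heads, 12 layers, patch size 14),
# presumably sized to keep the synthetic checkpoint small while matching the
# config assertions in the first test.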
def _build_reference_encoder() -> ViTModel:
    return ViTModel(
        ViTConfig(
            image_size=224,
            patch_size=14,
            num_channels=3,
            hidden_size=192,
            intermediate_size=768,
            num_hidden_layers=12,
            num_attention_heads=3,
            qkv_bias=True,
        ),
        add_pooling_layer=False,
    )


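# PyTorch Lightning checkpoints store weights under a top-level "state_dict"
# key, with submodule prefixes ("model.encoder.", "model.projector.") that the
# backbone is expected to strip when loading.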
def _write_synthetic_lightning_ckpt(path: Path):
    torch.manual_seed(7)
    encoder = _build_reference_encoder()
    projector = _ReferenceProjector()
    lightning_state_dict = {}
    for key, value in encoder.state_dict().items():
        lightning_state_dict[f"model.encoder.{key}"] = value.detach().clone()
    for key, value in projector.state_dict().items():
        lightning_state_dict[f"model.projector.{key}"] = value.detach().clone()
    torch.save({"state_dict": lightning_state_dict}, path)
    return encoder.state_dict(), projector.state_dict()


class LEWMViTBackboneTest(unittest.TestCase):
    def test_loads_lightning_encoder_and_projector_checkpoint_and_emits_joint_embedding(self):
        from roboimi.vla.models.backbones.lewm_vit_backbone import LEWMViTBackbone

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_path = Path(tmpdir) / "synthetic-lewm.ckpt"
            reference_encoder_state, reference_projector_state = _write_synthetic_lightning_ckpt(
                ckpt_path
            )

            backbone = LEWMViTBackbone(
                checkpoint_path=ckpt_path,
                camera_names=_INPUT_CAMERA_NAMES,
                fused_camera_names=_FUSED_CAMERA_NAMES,
                freeze_backbone=True,
            )

            self.assertEqual(backbone.camera_names, _INPUT_CAMERA_NAMES)
            self.assertEqual(backbone.fused_camera_names, _FUSED_CAMERA_NAMES)
            self.assertEqual(backbone.num_cameras, 3)
            self.assertEqual(backbone.joint_output_dim, 192)
            self.assertEqual(backbone.output_dim, 192)
            self.assertEqual(backbone.encoder.config.hidden_size, 192)
            self.assertEqual(backbone.encoder.config.patch_size, 14)
            self.assertEqual(backbone.encoder.config.num_hidden_layers, 12)
            self.assertEqual(backbone.encoder.config.num_attention_heads, 3)

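            # Exact tensor equality (torch.equal, not allclose) verifies the
            # checkpoint weights are loaded verbatim, with no re-initialization.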
            for key, value in reference_encoder_state.items():
                self.assertTrue(torch.equal(backbone.encoder.state_dict()[key], value), key)
            for key, value in reference_projector_state.items():
                self.assertTrue(torch.equal(backbone.projector.state_dict()[key], value), key)

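            # Inputs are (batch, time, channels, H, W); the (1, 1, 192) output
            # assertion below implies the backbone preserves the batch/time dims.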
            images = {
                cam_name: torch.rand(1, 1, 3, 224, 224)
                for cam_name in _INPUT_CAMERA_NAMES
            }
            output = backbone(images)

            self.assertEqual(output.shape, (1, 1, 192))
            self.assertFalse(output.requires_grad)

    def test_forward_uses_front_top_rvis_fusion_order_and_exact_lewm_cwh_resize_path(self):
        from roboimi.vla.models.backbones.lewm_vit_backbone import LEWMViTBackbone

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_path = Path(tmpdir) / "synthetic-lewm.ckpt"
            _write_synthetic_lightning_ckpt(ckpt_path)

            backbone = LEWMViTBackbone(
                checkpoint_path=ckpt_path,
                camera_names=_INPUT_CAMERA_NAMES,
                fused_camera_names=_FUSED_CAMERA_NAMES,
                freeze_backbone=True,
            )
            captured = {}

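            # Stub the ViT forward so the test can capture the exact pixel_values
            # the backbone builds; the stub returns zeroed hidden states with a
            # deterministic arange CLS token, sized for a patch size of 14.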
            def fake_encoder_forward(module, pixel_values, interpolate_pos_encoding=False, **kwargs):
                del module, kwargs
                captured["pixel_values"] = pixel_values.detach().clone()
                captured["interpolate_pos_encoding"] = interpolate_pos_encoding
                batch = pixel_values.shape[0]
                patch_tokens = (pixel_values.shape[-2] // 14) * (pixel_values.shape[-1] // 14)
                cls = (
                    torch.arange(192, dtype=pixel_values.dtype, device=pixel_values.device)
                    .unsqueeze(0)
                    .expand(batch, -1)
                )
                last_hidden_state = torch.zeros(
                    batch,
                    patch_tokens + 1,
                    192,
                    dtype=pixel_values.dtype,
                    device=pixel_values.device,
                )
                last_hidden_state[:, 0] = cls
                return types.SimpleNamespace(last_hidden_state=last_hidden_state)

            backbone.encoder.forward = types.MethodType(fake_encoder_forward, backbone.encoder)

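            # Distinct constant fills per camera make the fused tensor's row
            # blocks distinguishable, so the front/top/r_vis stacking order is
            # observable in the captured pixel_values.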
            r_vis = torch.full((1, 1, 3, 256, 256), 0.30)
            top = torch.full((1, 1, 3, 256, 256), 0.20)
            front = torch.full((1, 1, 3, 256, 256), 0.10)
            bn = backbone.projector.net[1]
            running_mean_before = bn.running_mean.detach().clone()
            running_var_before = bn.running_var.detach().clone()

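            # With freeze_backbone=True, train() should leave the encoder and
            # projector in eval mode, so the BatchNorm running stats (checked at
            # the end of the test) must not drift during the forward pass.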
            backbone.train()
            self.assertFalse(backbone.encoder.training)
            self.assertFalse(backbone.projector.training)

            output = backbone({"r_vis": r_vis, "top": top, "front": front})

            self.assertEqual(output.shape, (1, 1, 192))
            self.assertEqual(captured["pixel_values"].shape, (1, 3, 672, 224))
            self.assertTrue(captured["interpolate_pos_encoding"])

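            # LEWM-style "fuse then resize": the three normalized views are
            # stacked along height (3 x 256 = 768 rows) and then bilinear-resized
            # with antialiasing to (672, 224). Resizing each view to 224 x 224
            # first and then stacking yields a subtly different tensor, which the
            # negative assertions below distinguish.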
            normalized_views = [
                ((view.reshape(-1, *view.shape[2:]).float()).clamp(0.0, 1.0) - backbone.mean) / backbone.std
                for view in (front, top, r_vis)
            ]
            expected_fuse_then_resize = F.interpolate(
                torch.cat(normalized_views, dim=-2),
                size=(672, 224),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
            expected_pre_resize_then_fuse = torch.cat(
                [
                    F.interpolate(
                        view,
                        size=(224, 224),
                        mode="bilinear",
                        align_corners=False,
                        antialias=True,
                    )
                    for view in normalized_views
                ],
                dim=-2,
            )

            self.assertTrue(
                torch.allclose(captured["pixel_values"], expected_fuse_then_resize, atol=1e-6, rtol=1e-6)
            )
            self.assertFalse(
                torch.allclose(
                    expected_fuse_then_resize,
                    expected_pre_resize_then_fuse,
                    atol=1e-6,
                    rtol=1e-6,
                )
            )
            self.assertFalse(
                torch.allclose(
                    captured["pixel_values"],
                    expected_pre_resize_then_fuse,
                    atol=1e-6,
                    rtol=1e-6,
                )
            )
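            # Rows 223 and 447 sit at the seams between the stacked views, where
            # fuse-then-resize blends across view boundaries and so differs most
            # from the resize-then-fuse path.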
            self.assertTrue(
                torch.allclose(
                    captured["pixel_values"][0, :, 223, :],
                    expected_fuse_then_resize[0, :, 223, :],
                    atol=1e-6,
                    rtol=1e-6,
                )
            )
            self.assertTrue(
                torch.allclose(
                    captured["pixel_values"][0, :, 447, :],
                    expected_fuse_then_resize[0, :, 447, :],
                    atol=1e-6,
                    rtol=1e-6,
                )
            )
            self.assertTrue(torch.equal(bn.running_mean, running_mean_before))
            self.assertTrue(torch.equal(bn.running_var, running_var_before))


if __name__ == "__main__":
    unittest.main()