"""Tests for the LEWM ViT backbone: Lightning checkpoint loading and multi-camera fusion."""
import tempfile
import types
import unittest
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTConfig, ViTModel


_INPUT_CAMERA_NAMES = ("r_vis", "top", "front")
|
|
_FUSED_CAMERA_NAMES = ("front", "top", "r_vis")
|
|
|
|
|
|
class _ReferenceProjector(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.net = nn.Sequential(
|
|
nn.Linear(192, 2048),
|
|
nn.BatchNorm1d(2048),
|
|
nn.GELU(),
|
|
nn.Linear(2048, 192),
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.net(x)
|
|
|
|
|
|
def _build_reference_encoder() -> ViTModel:
    """Build the small reference ViT (224px, patch 14, hidden 192, 12 layers,
    3 heads) whose geometry the backbone under test must reproduce."""
    config = ViTConfig(
        image_size=224,
        patch_size=14,
        num_channels=3,
        hidden_size=192,
        intermediate_size=768,
        num_hidden_layers=12,
        num_attention_heads=3,
        qkv_bias=True,
    )
    # No pooler: the tests only ever read encoder hidden states.
    return ViTModel(config, add_pooling_layer=False)


def _write_synthetic_lightning_ckpt(path: Path):
    """Write a Lightning-style checkpoint (`{"state_dict": {...}}`) to `path`.

    Encoder weights are stored under `model.encoder.*` and projector weights
    under `model.projector.*`. Returns the reference (encoder, projector)
    state dicts so tests can compare against what the backbone loads.
    """
    # Seed before construction so the reference weights are reproducible.
    torch.manual_seed(7)
    encoder = _build_reference_encoder()
    projector = _ReferenceProjector()
    lightning_state_dict = {
        f"model.{prefix}.{key}": value.detach().clone()
        for prefix, module in (("encoder", encoder), ("projector", projector))
        for key, value in module.state_dict().items()
    }
    torch.save({"state_dict": lightning_state_dict}, path)
    return encoder.state_dict(), projector.state_dict()


class LEWMViTBackboneTest(unittest.TestCase):
    """End-to-end tests for LEWMViTBackbone against a synthetic Lightning checkpoint."""

    def test_loads_lightning_encoder_and_projector_checkpoint_and_emits_joint_embedding(self):
        """Backbone must load `model.encoder.*` / `model.projector.*` weights verbatim
        and emit a single 192-d joint embedding per (batch, time) step."""
        # Imported lazily so collecting this module does not require the project package.
        from roboimi.vla.models.backbones.lewm_vit_backbone import LEWMViTBackbone

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_path = Path(tmpdir) / "synthetic-lewm.ckpt"
            reference_encoder_state, reference_projector_state = _write_synthetic_lightning_ckpt(
                ckpt_path
            )

            backbone = LEWMViTBackbone(
                checkpoint_path=ckpt_path,
                camera_names=_INPUT_CAMERA_NAMES,
                fused_camera_names=_FUSED_CAMERA_NAMES,
                freeze_backbone=True,
            )

            # Construction metadata must mirror the synthetic ViT config exactly.
            self.assertEqual(backbone.camera_names, _INPUT_CAMERA_NAMES)
            self.assertEqual(backbone.fused_camera_names, _FUSED_CAMERA_NAMES)
            self.assertEqual(backbone.num_cameras, 3)
            self.assertEqual(backbone.joint_output_dim, 192)
            self.assertEqual(backbone.output_dim, 192)
            self.assertEqual(backbone.encoder.config.hidden_size, 192)
            self.assertEqual(backbone.encoder.config.patch_size, 14)
            self.assertEqual(backbone.encoder.config.num_hidden_layers, 12)
            self.assertEqual(backbone.encoder.config.num_attention_heads, 3)

            # Every checkpointed tensor must be loaded bitwise-identically
            # (torch.equal, not allclose); the key is passed as the failure message.
            for key, value in reference_encoder_state.items():
                self.assertTrue(torch.equal(backbone.encoder.state_dict()[key], value), key)
            for key, value in reference_projector_state.items():
                self.assertTrue(torch.equal(backbone.projector.state_dict()[key], value), key)

            # Input layout is (batch=1, time=1, C, H, W) per camera — the native
            # 224x224 encoder resolution, so no resizing is exercised here.
            images = {
                cam_name: torch.rand(1, 1, 3, 224, 224)
                for cam_name in _INPUT_CAMERA_NAMES
            }
            output = backbone(images)

            self.assertEqual(output.shape, (1, 1, 192))
            # freeze_backbone=True: the output must carry no gradient graph.
            self.assertFalse(output.requires_grad)

    def test_forward_uses_front_top_rvis_fusion_order_and_exact_lewm_cwh_resize_path(self):
        """Forward must (a) stack views along height in fused order (front, top,
        r_vis), (b) fuse *before* resizing to (672, 224) with antialiased bilinear
        interpolation, and (c) keep frozen submodules in eval mode under .train()."""
        from roboimi.vla.models.backbones.lewm_vit_backbone import LEWMViTBackbone

        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt_path = Path(tmpdir) / "synthetic-lewm.ckpt"
            _write_synthetic_lightning_ckpt(ckpt_path)

            backbone = LEWMViTBackbone(
                checkpoint_path=ckpt_path,
                camera_names=_INPUT_CAMERA_NAMES,
                fused_camera_names=_FUSED_CAMERA_NAMES,
                freeze_backbone=True,
            )
            captured = {}

            # Stub encoder: records the exact pixel_values and the
            # interpolate_pos_encoding flag, then returns a deterministic
            # ViT-like output whose CLS token (position 0) is arange(192).
            def fake_encoder_forward(module, pixel_values, interpolate_pos_encoding=False, **kwargs):
                del module, kwargs
                captured["pixel_values"] = pixel_values.detach().clone()
                captured["interpolate_pos_encoding"] = interpolate_pos_encoding
                batch = pixel_values.shape[0]
                # Token count for patch size 14, matching the reference config.
                patch_tokens = (pixel_values.shape[-2] // 14) * (pixel_values.shape[-1] // 14)
                cls = (
                    torch.arange(192, dtype=pixel_values.dtype, device=pixel_values.device)
                    .unsqueeze(0)
                    .expand(batch, -1)
                )
                last_hidden_state = torch.zeros(
                    batch,
                    patch_tokens + 1,
                    192,
                    dtype=pixel_values.dtype,
                    device=pixel_values.device,
                )
                last_hidden_state[:, 0] = cls
                # Mimic the HF ViTModel output attribute the backbone reads.
                return types.SimpleNamespace(last_hidden_state=last_hidden_state)

            # Bind the stub as a method so it replaces encoder.forward in place.
            backbone.encoder.forward = types.MethodType(fake_encoder_forward, backbone.encoder)

            # Distinct constant fills per camera make fusion order observable.
            r_vis = torch.full((1, 1, 3, 256, 256), 0.30)
            top = torch.full((1, 1, 3, 256, 256), 0.20)
            front = torch.full((1, 1, 3, 256, 256), 0.10)
            # net[1] is the projector's BatchNorm1d; snapshot its running stats.
            bn = backbone.projector.net[1]
            running_mean_before = bn.running_mean.detach().clone()
            running_var_before = bn.running_var.detach().clone()

            # Even after .train(), frozen encoder/projector must stay in eval mode.
            backbone.train()
            self.assertFalse(backbone.encoder.training)
            self.assertFalse(backbone.projector.training)

            output = backbone({"r_vis": r_vis, "top": top, "front": front})

            self.assertEqual(output.shape, (1, 1, 192))
            # 3 views of 224 height each, stacked vertically: (1, 3, 672, 224).
            self.assertEqual(captured["pixel_values"].shape, (1, 3, 672, 224))
            # Non-square input to the ViT requires position-embedding interpolation.
            self.assertTrue(captured["interpolate_pos_encoding"])

            # Reference preprocessing: flatten (B, T) -> batch, clamp to [0, 1],
            # normalize with the backbone's mean/std, in fused order front/top/r_vis.
            normalized_views = [
                ((view.reshape(-1, *view.shape[2:]).float()).clamp(0.0, 1.0) - backbone.mean) / backbone.std
                for view in (front, top, r_vis)
            ]
            # Path A (expected): concatenate along height first, then resize once.
            expected_fuse_then_resize = F.interpolate(
                torch.cat(normalized_views, dim=-2),
                size=(672, 224),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
            # Path B (rejected): resize each view to 224x224, then concatenate.
            expected_pre_resize_then_fuse = torch.cat(
                [
                    F.interpolate(
                        view,
                        size=(224, 224),
                        mode="bilinear",
                        align_corners=False,
                        antialias=True,
                    )
                    for view in normalized_views
                ],
                dim=-2,
            )

            self.assertTrue(
                torch.allclose(captured["pixel_values"], expected_fuse_then_resize, atol=1e-6, rtol=1e-6)
            )
            # Sanity: the two paths genuinely differ, so the assertion above
            # actually discriminates between them.
            self.assertFalse(
                torch.allclose(
                    expected_fuse_then_resize,
                    expected_pre_resize_then_fuse,
                    atol=1e-6,
                    rtol=1e-6,
                )
            )
            self.assertFalse(
                torch.allclose(
                    captured["pixel_values"],
                    expected_pre_resize_then_fuse,
                    atol=1e-6,
                    rtol=1e-6,
                )
            )
            # Seam rows between stacked views (223|224 and 447|448) are where the
            # two resize paths blend differently; pin the exact expected values there.
            self.assertTrue(
                torch.allclose(
                    captured["pixel_values"][0, :, 223, :],
                    expected_fuse_then_resize[0, :, 223, :],
                    atol=1e-6,
                    rtol=1e-6,
                )
            )
            self.assertTrue(
                torch.allclose(
                    captured["pixel_values"][0, :, 447, :],
                    expected_fuse_then_resize[0, :, 447, :],
                    atol=1e-6,
                    rtol=1e-6,
                )
            )
            # Frozen BatchNorm must not update its running statistics on forward.
            self.assertTrue(torch.equal(bn.running_mean, running_mean_before))
            self.assertTrue(torch.equal(bn.running_var, running_var_before))


if __name__ == "__main__":
    # Allow running this test module directly (outside a pytest/unittest runner).
    unittest.main()