# Unit tests for the SigLIP2 diffusion backbone, using fake SigLIP fixtures.
import types
|
|
import unittest
|
|
from unittest import mock
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
|
|
# Camera view names shared by every test; the order is relied on when the
# tests check per-view encoder calls and concatenated features.
_CAMERA_NAMES = ("r_vis", "top", "front")
|
|
|
|
|
|
class _FakeSiglipVisionOutput:
|
|
def __init__(self, pooler_output):
|
|
self.pooler_output = pooler_output
|
|
|
|
|
|
class _FakeSiglipVisionConfig:
|
|
def __init__(self, hidden_size=768, image_size=256):
|
|
self.hidden_size = hidden_size
|
|
self.image_size = image_size
|
|
|
|
|
|
class _FakeSiglipVisionModel(nn.Module):
    """Fake SigLIP vision model that records calls and mean-pools its input.

    Mimics the parts of ``SiglipVisionModel`` the backbone touches: a
    ``config`` attribute, a ``from_pretrained`` constructor, and a forward
    pass returning an object with ``pooler_output``.
    """

    def __init__(self, hidden_size=768):
        super().__init__()
        # Mirror the real model's ``config`` attribute.
        self.config = _FakeSiglipVisionConfig(hidden_size=hidden_size)
        # One record per forward() invocation, for later inspection by tests.
        self.forward_calls = []

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Ignore all checkpoint arguments and return a fresh instance."""
        del args, kwargs
        return cls()

    def forward(self, pixel_values=None, **kwargs):
        """Record the call, then mean-pool over the spatial dims (2, 3)."""
        record = {
            "pixel_values": pixel_values.detach().clone(),
            "kwargs": dict(kwargs),
        }
        self.forward_calls.append(record)
        # Pool (N, C, H, W) -> (N, C); keepdim=False drops the spatial axes.
        pooled = pixel_values.mean(dim=(2, 3), keepdim=False)
        return _FakeSiglipVisionOutput(pooler_output=pooled)
|
|
|
|
|
|
class SigLIP2DiffusionBackboneTest(unittest.TestCase):
    """Unit tests for ``SigLIP2DiffusionBackbone`` driven by a fake SigLIP encoder."""

    def test_forward_encodes_each_view_independently_and_concatenates_projected_features(self):
        # Imported lazily so the test module can be collected even when the
        # package under test is unavailable until patch time.
        from roboimi.vla.models.backbones.siglip2_diffusion_backbone import SigLIP2DiffusionBackbone

        fake_model = _FakeSiglipVisionModel(hidden_size=3)
        # Patch the checkpoint loader so construction uses the fake encoder
        # instead of downloading real weights.
        with mock.patch(
            "roboimi.vla.models.backbones.siglip2_diffusion_backbone.SiglipVisionModel.from_pretrained",
            return_value=fake_model,
        ) as mock_from_pretrained:
            backbone = SigLIP2DiffusionBackbone(
                model_name="google/siglip2-base-patch16-256",
                camera_names=_CAMERA_NAMES,
                num_cameras=3,
                per_view_output_dim=2,
                freeze_backbone=True,
            )

        self.assertEqual(backbone.camera_names, _CAMERA_NAMES)
        self.assertEqual(backbone.num_cameras, 3)
        self.assertEqual(backbone.output_dim, 2)
        # joint output = per-view dim (2) x number of cameras (3).
        self.assertEqual(backbone.joint_output_dim, 6)
        self.assertIsNone(backbone.dataset_image_resize_shape)
        self.assertEqual(backbone.eval_image_resize_shape, (256, 256))
        mock_from_pretrained.assert_called_once_with("google/siglip2-base-patch16-256")
        # freeze_backbone=True is expected to both stop gradients and put the
        # encoder in eval mode.
        self.assertTrue(all(not p.requires_grad for p in backbone.encoder.parameters()))
        self.assertFalse(backbone.encoder.training)

        # Make the projection an identity on the first two hidden channels so
        # the expected output can be computed by hand below.
        with torch.no_grad():
            backbone.view_projector.weight.zero_()
            backbone.view_projector.bias.zero_()
            backbone.view_projector.weight[0, 0] = 1.0
            backbone.view_projector.weight[1, 1] = 1.0

        # One (batch=1, time=2, channels=3, H=256, W=256) tensor per camera,
        # each filled with a distinct constant so views are distinguishable.
        images = {
            "r_vis": torch.full((1, 2, 3, 256, 256), 0.25),
            "top": torch.full((1, 2, 3, 256, 256), 0.50),
            "front": torch.full((1, 2, 3, 256, 256), 0.75),
        }
        output = backbone(images)

        self.assertEqual(output.shape, (1, 2, 6))
        # Exactly one encoder call per camera view.
        self.assertEqual(len(fake_model.forward_calls), 3)

        # Expected pipeline per view: flatten (batch, time) into one batch
        # axis, normalize pixels via (x - 0.5) / 0.5, mean-pool spatially (the
        # fake encoder), then keep the first two channels (identity projection
        # configured above); finally concatenate views along the feature axis.
        expected_per_camera = []
        for cam_name in _CAMERA_NAMES:
            img = images[cam_name].reshape(2, 3, 256, 256)
            normalized = (img - 0.5) / 0.5
            expected_per_camera.append(normalized.mean(dim=(2, 3))[:, :2])
        expected = torch.cat(expected_per_camera, dim=-1).view(1, 2, 6)
        self.assertTrue(torch.allclose(output, expected, atol=1e-6, rtol=1e-6))

        # Verify the encoder saw each view independently, in camera_names
        # order, with the normalized pixel values.
        for call, cam_name in zip(fake_model.forward_calls, _CAMERA_NAMES):
            pixels = call["pixel_values"]
            self.assertEqual(tuple(pixels.shape), (2, 3, 256, 256))
            self.assertTrue(
                torch.allclose(
                    pixels,
                    (images[cam_name].reshape(2, 3, 256, 256) - 0.5) / 0.5,
                )
            )

    def test_forward_rejects_missing_required_camera(self):
        from roboimi.vla.models.backbones.siglip2_diffusion_backbone import SigLIP2DiffusionBackbone

        # Inject the fake encoder directly; no checkpoint patching needed here.
        backbone = SigLIP2DiffusionBackbone(
            vision_model=_FakeSiglipVisionModel(hidden_size=4),
            camera_names=_CAMERA_NAMES,
            num_cameras=3,
        )

        # The "front" view is omitted, so forward should raise a ValueError
        # whose message mentions the missing camera.
        with self.assertRaisesRegex(ValueError, "missing"):
            backbone({
                "r_vis": torch.rand(1, 1, 3, 256, 256),
                "top": torch.rand(1, 1, 3, 256, 256),
            })
|
|
|
|
|
|
# Allow running this test module directly (``python <this file>``).
if __name__ == "__main__":
    unittest.main()
|