feat: add vision transfer backbones and IMF variants
This commit is contained in:
121
tests/test_siglip2_diffusion_backbone.py
Normal file
121
tests/test_siglip2_diffusion_backbone.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import types
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
# Camera views used throughout these tests. The tuple order matters: the
# expected-output checks below concatenate per-view features in this order
# and zip recorded encoder calls against it.
_CAMERA_NAMES = ("r_vis", "top", "front")
class _FakeSiglipVisionOutput:
    """Minimal substitute for the HF vision-model output object.

    Only ``pooler_output`` is carried, because that is the only attribute
    this test suite provides and inspects.
    """

    def __init__(self, pooler_output):
        # Mirror the attribute name used by the real transformers output type.
        self.pooler_output = pooler_output
class _FakeSiglipVisionConfig:
    """Stub of the SigLIP vision config, exposing only the fields read here.

    Defaults mimic a base SigLIP2 checkpoint (768 hidden units, 256px input).
    """

    def __init__(self, hidden_size=768, image_size=256):
        # Attribute names match transformers' SiglipVisionConfig.
        self.image_size = image_size
        self.hidden_size = hidden_size
class _FakeSiglipVisionModel(nn.Module):
    """Recording stand-in for transformers' ``SiglipVisionModel``.

    Every ``forward`` invocation is appended to ``forward_calls`` (a detached
    copy of the pixel tensor plus a copy of any extra kwargs), and features
    are produced by averaging each image over dims 2 and 3 — so tests can
    both predict the output and inspect exactly what the backbone fed in.
    """

    def __init__(self, hidden_size=768):
        super().__init__()
        # Expose the same ``config`` surface the real model provides.
        self.config = _FakeSiglipVisionConfig(hidden_size=hidden_size)
        # One entry per forward() call: {"pixel_values": Tensor, "kwargs": dict}.
        self.forward_calls = []

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Ignore all checkpoint arguments and hand back a fresh stub."""
        del args, kwargs
        return cls()

    def forward(self, pixel_values=None, **kwargs):
        """Record the call, then return spatially mean-pooled features."""
        call_record = {
            "pixel_values": pixel_values.detach().clone(),
            "kwargs": dict(kwargs),
        }
        self.forward_calls.append(call_record)
        # Average away dims 2 and 3 (the spatial dims for NCHW input),
        # leaving one value per channel.
        pooled = pixel_values.mean(dim=(2, 3), keepdim=False)
        return _FakeSiglipVisionOutput(pooler_output=pooled)
class SigLIP2DiffusionBackboneTest(unittest.TestCase):
    """Unit tests for SigLIP2DiffusionBackbone, driven by a recording fake encoder."""

    def test_forward_encodes_each_view_independently_and_concatenates_projected_features(self):
        from roboimi.vla.models.backbones.siglip2_diffusion_backbone import SigLIP2DiffusionBackbone

        # hidden_size=3: the fake encoder pools over H/W, so each view yields
        # one feature per channel (C=3 below).
        fake_model = _FakeSiglipVisionModel(hidden_size=3)
        # Patch the checkpoint load so construction never touches the network.
        with mock.patch(
            "roboimi.vla.models.backbones.siglip2_diffusion_backbone.SiglipVisionModel.from_pretrained",
            return_value=fake_model,
        ) as mock_from_pretrained:
            backbone = SigLIP2DiffusionBackbone(
                model_name="google/siglip2-base-patch16-256",
                camera_names=_CAMERA_NAMES,
                num_cameras=3,
                per_view_output_dim=2,
                freeze_backbone=True,
            )

        # Constructor wiring: exposed dims, resize shapes, and checkpoint load.
        self.assertEqual(backbone.camera_names, _CAMERA_NAMES)
        self.assertEqual(backbone.num_cameras, 3)
        self.assertEqual(backbone.output_dim, 2)
        self.assertEqual(backbone.joint_output_dim, 6)  # 3 cameras * 2 dims per view
        self.assertIsNone(backbone.dataset_image_resize_shape)
        self.assertEqual(backbone.eval_image_resize_shape, (256, 256))
        mock_from_pretrained.assert_called_once_with("google/siglip2-base-patch16-256")
        # freeze_backbone=True must freeze all encoder params and eval() it.
        self.assertTrue(all(not p.requires_grad for p in backbone.encoder.parameters()))
        self.assertFalse(backbone.encoder.training)

        # Rewrite the projector so it simply selects the first two feature
        # channels of each view — making the expected output hand-computable.
        with torch.no_grad():
            backbone.view_projector.weight.zero_()
            backbone.view_projector.bias.zero_()
            backbone.view_projector.weight[0, 0] = 1.0
            backbone.view_projector.weight[1, 1] = 1.0

        # Constant-valued (1, 2, 3, 256, 256) tensors per camera — presumably
        # (batch, time, C, H, W); constants make spatial pooling predictable.
        images = {
            "r_vis": torch.full((1, 2, 3, 256, 256), 0.25),
            "top": torch.full((1, 2, 3, 256, 256), 0.50),
            "front": torch.full((1, 2, 3, 256, 256), 0.75),
        }
        output = backbone(images)

        # One encoder call per camera; per-view features concatenated last-dim.
        self.assertEqual(output.shape, (1, 2, 6))
        self.assertEqual(len(fake_model.forward_calls), 3)

        # Expected pipeline per view: normalize with mean/std 0.5, mean-pool
        # H/W (what the fake encoder does), keep the first two channels
        # (projector surgery above), then concatenate across cameras in order.
        expected_per_camera = []
        for cam_name in _CAMERA_NAMES:
            img = images[cam_name].reshape(2, 3, 256, 256)
            normalized = (img - 0.5) / 0.5
            expected_per_camera.append(normalized.mean(dim=(2, 3))[:, :2])
        expected = torch.cat(expected_per_camera, dim=-1).view(1, 2, 6)
        self.assertTrue(torch.allclose(output, expected, atol=1e-6, rtol=1e-6))

        # The encoder must have received each view flattened to (2, 3, 256, 256)
        # and normalized, in camera order.
        for call, cam_name in zip(fake_model.forward_calls, _CAMERA_NAMES):
            pixels = call["pixel_values"]
            self.assertEqual(tuple(pixels.shape), (2, 3, 256, 256))
            self.assertTrue(
                torch.allclose(
                    pixels,
                    (images[cam_name].reshape(2, 3, 256, 256) - 0.5) / 0.5,
                )
            )

    def test_forward_rejects_missing_required_camera(self):
        from roboimi.vla.models.backbones.siglip2_diffusion_backbone import SigLIP2DiffusionBackbone

        # Inject the fake model directly instead of patching from_pretrained.
        backbone = SigLIP2DiffusionBackbone(
            vision_model=_FakeSiglipVisionModel(hidden_size=4),
            camera_names=_CAMERA_NAMES,
            num_cameras=3,
        )

        # "front" is absent: forward must raise ValueError mentioning "missing".
        with self.assertRaisesRegex(ValueError, "missing"):
            backbone({
                "r_vis": torch.rand(1, 1, 3, 256, 256),
                "top": torch.rand(1, 1, 3, 256, 256),
            })
if __name__ == "__main__":
    # Allow running this file directly, in addition to pytest/unittest discovery.
    unittest.main()
|
||||
Reference in New Issue
Block a user