"""Unit tests for the AttnRes ResNet-like 2D backbone."""

import unittest

import torch


class AttnResResNet2DBackboneTest(unittest.TestCase):
    """Shape-contract tests for ``AttnResResNetLikeBackbone2D``."""

    def test_backbone_preserves_resnet_like_stage_contract(self):
        """A (2, 3, 56, 56) input must yield a (2, stage_dims[-1], 2, 2) output."""
        # Imported inside the test so that a broken/missing backbone module is
        # reported as an error of this test instead of aborting collection of
        # the whole test module.
        from roboimi.vla.models.backbones.attnres_resnet2d import (
            AttnResResNetLikeBackbone2D,
        )

        # Seed so weight init and the random input are reproducible run to run.
        torch.manual_seed(0)

        stage_dims = (16, 32, 64, 128)
        backbone = AttnResResNetLikeBackbone2D(
            input_channels=3,
            stem_dim=16,
            stage_dims=stage_dims,
            stage_depths=(1, 1, 1, 1),
            stage_heads=(2, 4, 4, 8),
            stage_kv_heads=(1, 1, 1, 1),
            stage_window_sizes=(7, 7, 7, 7),
            dropout=0.0,
        )

        x = torch.randn(2, 3, 56, 56)
        # Only the forward output is inspected; no autograd graph is needed.
        with torch.no_grad():
            y = backbone(x)

        # Expected channels are the last stage width (128). The 2x2 spatial
        # size is the expected result for a 56x56 input; the exact per-stage
        # downsampling is defined in the backbone, not here.
        self.assertEqual(tuple(y.shape), (2, stage_dims[-1], 2, 2))
        self.assertTrue(
            torch.isfinite(y).all().item(),
            "backbone produced non-finite values",
        )


if __name__ == "__main__":
    unittest.main()