# Unit test for roboimi.vla.models.backbones.attnres_resnet2d.AttnResResNetLikeBackbone2D.
import unittest
|
|
|
|
import torch
|
|
|
|
|
|
class AttnResResNet2DBackboneTest(unittest.TestCase):
    """Shape-contract test for ``AttnResResNetLikeBackbone2D``."""

    def test_backbone_preserves_resnet_like_stage_contract(self):
        """Build a small four-stage backbone and check its output shape.

        A random tensor of shape ``(2, 3, 56, 56)`` is passed through the
        backbone and the result must have shape ``(2, 128, 2, 2)``. The
        leading ``2`` equals the input's first dimension and ``128`` equals
        the last entry of ``stage_dims``.
        """
        # Imported inside the test rather than at module level, as in the
        # original file. NOTE(review): presumably so that a broken project
        # import fails only this test instead of module import -- confirm.
        from roboimi.vla.models.backbones.attnres_resnet2d import AttnResResNetLikeBackbone2D

        # One block per stage keeps the model small; every per-stage tuple
        # has four entries, one per stage.
        backbone = AttnResResNetLikeBackbone2D(
            input_channels=3,
            stem_dim=16,
            stage_dims=(16, 32, 64, 128),
            stage_depths=(1, 1, 1, 1),
            stage_heads=(2, 4, 4, 8),
            stage_kv_heads=(1, 1, 1, 1),
            stage_window_sizes=(7, 7, 7, 7),
            dropout=0.0,
        )

        # First dimension 2, second dimension 3 (matches input_channels),
        # then 56 x 56. NOTE(review): presumably (batch, channels, height,
        # width) -- confirm against the backbone's forward().
        x = torch.randn(2, 3, 56, 56)
        y = backbone(x)

        # torch.Size is a tuple subclass, so comparing against a plain tuple
        # is a valid equality check.
        self.assertEqual(y.shape, (2, 128, 2, 2))
|
|
|
|
|
|
# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|