Files
roboimi/tests/test_attnres_resnet2d_backbone.py
2026-04-05 00:07:59 +08:00

27 lines
751 B
Python

import unittest
import torch
class AttnResResNet2DBackboneTest(unittest.TestCase):
def test_backbone_preserves_resnet_like_stage_contract(self):
from roboimi.vla.models.backbones.attnres_resnet2d import AttnResResNetLikeBackbone2D
backbone = AttnResResNetLikeBackbone2D(
input_channels=3,
stem_dim=16,
stage_dims=(16, 32, 64, 128),
stage_depths=(1, 1, 1, 1),
stage_heads=(2, 4, 4, 8),
stage_kv_heads=(1, 1, 1, 1),
stage_window_sizes=(7, 7, 7, 7),
dropout=0.0,
)
x = torch.randn(2, 3, 56, 56)
y = backbone(x)
self.assertEqual(y.shape, (2, 128, 2, 2))
# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()