feat: add full attnres vision backbone
This commit is contained in:
26
tests/test_attnres_resnet2d_backbone.py
Normal file
26
tests/test_attnres_resnet2d_backbone.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AttnResResNet2DBackboneTest(unittest.TestCase):
|
||||
def test_backbone_preserves_resnet_like_stage_contract(self):
|
||||
from roboimi.vla.models.backbones.attnres_resnet2d import AttnResResNetLikeBackbone2D
|
||||
|
||||
backbone = AttnResResNetLikeBackbone2D(
|
||||
input_channels=3,
|
||||
stem_dim=16,
|
||||
stage_dims=(16, 32, 64, 128),
|
||||
stage_depths=(1, 1, 1, 1),
|
||||
stage_heads=(2, 4, 4, 8),
|
||||
stage_kv_heads=(1, 1, 1, 1),
|
||||
stage_window_sizes=(7, 7, 7, 7),
|
||||
dropout=0.0,
|
||||
)
|
||||
x = torch.randn(2, 3, 56, 56)
|
||||
y = backbone(x)
|
||||
self.assertEqual(y.shape, (2, 128, 2, 2))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -422,6 +422,32 @@ class IMFVLAAgentTest(unittest.TestCase):
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['backbone_type'], 'attnres_full')
|
||||
|
||||
def test_hydra_config_instantiates_resnet_imf_attnres_with_full_attnres_vision_backbone(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent=resnet_imf_attnres',
|
||||
'agent.vision_backbone.vision_backbone_mode=attnres_resnet',
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,56,56]',
|
||||
'agent.vision_backbone.freeze_backbone=false',
|
||||
'agent.vision_backbone.attnres_stem_dim=16',
|
||||
'agent.vision_backbone.attnres_stage_dims=[16,32,64,128]',
|
||||
'agent.vision_backbone.attnres_stage_depths=[1,1,1,1]',
|
||||
'agent.vision_backbone.attnres_stage_heads=[2,4,4,8]',
|
||||
'agent.vision_backbone.attnres_stage_kv_heads=[1,1,1,1]',
|
||||
'agent.vision_backbone.attnres_stage_window_sizes=[7,7,7,7]',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_emb=16',
|
||||
]
|
||||
)
|
||||
|
||||
with _stub_optional_modules(include_imf_head=True):
|
||||
agent = instantiate(cfg.agent)
|
||||
|
||||
self.assertEqual(agent.vision_encoder.output_dim, 64)
|
||||
self.assertEqual(agent.per_step_cond_dim, 64 * agent.num_cams + agent.obs_dim)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user