feat: add vision transfer backbones and IMF variants
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import contextlib
|
||||
import importlib
|
||||
import importlib.machinery
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
@@ -69,6 +70,68 @@ class _FakeRearrange(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class _FakeViTConfig:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class _FakeViTModel(nn.Module):
|
||||
def __init__(self, config, add_pooling_layer=False):
|
||||
super().__init__()
|
||||
del add_pooling_layer
|
||||
self.config = config
|
||||
hidden_size = int(getattr(config, 'hidden_size', 192))
|
||||
self.proj = nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
def forward(self, pixel_values=None, interpolate_pos_encoding=False, **kwargs):
|
||||
del interpolate_pos_encoding, kwargs
|
||||
batch_size = pixel_values.shape[0]
|
||||
hidden_size = int(getattr(self.config, 'hidden_size', 192))
|
||||
seq_len = 2
|
||||
last_hidden_state = torch.zeros(batch_size, seq_len, hidden_size, dtype=pixel_values.dtype, device=pixel_values.device)
|
||||
return types.SimpleNamespace(last_hidden_state=last_hidden_state)
|
||||
|
||||
|
||||
class _FakeSiglipVisionOutput:
|
||||
def __init__(self, pooler_output):
|
||||
self.pooler_output = pooler_output
|
||||
|
||||
|
||||
class _FakeSiglipVisionConfig:
|
||||
def __init__(self, hidden_size=768, image_size=256):
|
||||
self.hidden_size = hidden_size
|
||||
self.image_size = image_size
|
||||
|
||||
|
||||
class _FakeSiglipVisionModel(nn.Module):
|
||||
load_calls = []
|
||||
|
||||
def __init__(self, hidden_size=768):
|
||||
super().__init__()
|
||||
self.config = _FakeSiglipVisionConfig(hidden_size=hidden_size)
|
||||
self.scale = nn.Parameter(torch.tensor(1.0))
|
||||
self.forward_calls = []
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
||||
model = cls()
|
||||
cls.load_calls.append({
|
||||
'pretrained_model_name_or_path': pretrained_model_name_or_path,
|
||||
'args': args,
|
||||
'kwargs': kwargs,
|
||||
})
|
||||
return model
|
||||
|
||||
def forward(self, pixel_values=None, **kwargs):
|
||||
self.forward_calls.append({
|
||||
'pixel_values': pixel_values.detach().clone(),
|
||||
'kwargs': dict(kwargs),
|
||||
})
|
||||
pooled = pixel_values.mean(dim=(2, 3), keepdim=False) * self.scale
|
||||
return _FakeSiglipVisionOutput(pooler_output=pooled)
|
||||
|
||||
|
||||
class _StubIMFHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -105,6 +168,11 @@ class _StubIMFHead(nn.Module):
|
||||
def _stub_optional_modules(include_imf_head=False):
|
||||
previous_modules = {}
|
||||
|
||||
def remember_and_remove(name):
|
||||
if name not in previous_modules:
|
||||
previous_modules[name] = sys.modules.get(name, _MISSING)
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
def inject(name, module):
|
||||
if name not in previous_modules:
|
||||
previous_modules[name] = sys.modules.get(name, _MISSING)
|
||||
@@ -125,6 +193,9 @@ def _stub_optional_modules(include_imf_head=False):
|
||||
torchvision_module = types.ModuleType('torchvision')
|
||||
models_module = types.ModuleType('torchvision.models')
|
||||
transforms_module = types.ModuleType('torchvision.transforms')
|
||||
torchvision_module.__spec__ = importlib.machinery.ModuleSpec('torchvision', loader=None)
|
||||
models_module.__spec__ = importlib.machinery.ModuleSpec('torchvision.models', loader=None)
|
||||
transforms_module.__spec__ = importlib.machinery.ModuleSpec('torchvision.transforms', loader=None)
|
||||
models_module.resnet18 = lambda weights=None: _FakeResNet()
|
||||
transforms_module.CenterCrop = _IdentityCrop
|
||||
transforms_module.RandomCrop = _IdentityCrop
|
||||
@@ -139,7 +210,14 @@ def _stub_optional_modules(include_imf_head=False):
|
||||
einops_module.layers = einops_layers_module
|
||||
einops_layers_module.torch = einops_layers_torch_module
|
||||
|
||||
transformers_module = types.ModuleType('transformers')
|
||||
transformers_module.__spec__ = importlib.machinery.ModuleSpec('transformers', loader=None)
|
||||
transformers_module.ViTConfig = _FakeViTConfig
|
||||
transformers_module.ViTModel = _FakeViTModel
|
||||
transformers_module.SiglipVisionModel = _FakeSiglipVisionModel
|
||||
|
||||
try:
|
||||
remember_and_remove('roboimi.vla.models.backbones.siglip2_diffusion_backbone')
|
||||
inject('diffusers', diffusers_module)
|
||||
inject('diffusers.schedulers', schedulers_module)
|
||||
inject('diffusers.schedulers.scheduling_ddpm', ddpm_module)
|
||||
@@ -150,6 +228,7 @@ def _stub_optional_modules(include_imf_head=False):
|
||||
inject('einops', einops_module)
|
||||
inject('einops.layers', einops_layers_module)
|
||||
inject('einops.layers.torch', einops_layers_torch_module)
|
||||
inject('transformers', transformers_module)
|
||||
|
||||
if include_imf_head:
|
||||
import roboimi.vla.models.heads as heads_package
|
||||
@@ -200,6 +279,67 @@ class _StubVisionBackbone(nn.Module):
|
||||
return torch.cat(per_camera_features, dim=-1)
|
||||
|
||||
|
||||
class _StubJointVisionBackbone(nn.Module):
|
||||
joint_output_dim = 5
|
||||
output_dim = 5
|
||||
|
||||
def __init__(self, camera_names=_CAMERA_NAMES):
|
||||
super().__init__()
|
||||
self.camera_names = tuple(camera_names)
|
||||
self.num_cameras = len(self.camera_names)
|
||||
|
||||
def forward(self, images):
|
||||
batch_size, obs_horizon = next(iter(images.values())).shape[:2]
|
||||
features = []
|
||||
for camera_name in ('front', 'top', 'r_vis'):
|
||||
image_batch = images[camera_name]
|
||||
features.append(image_batch.mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1))
|
||||
joint_features = torch.cat(features, dim=-1)
|
||||
front_top_sum = joint_features[..., :2].sum(dim=-1, keepdim=True)
|
||||
r_vis_minus_front = (joint_features[..., 2:] - joint_features[..., :1])
|
||||
time_marker = torch.arange(obs_horizon, dtype=joint_features.dtype).view(1, obs_horizon, 1)
|
||||
time_marker = time_marker.expand(batch_size, -1, -1)
|
||||
return torch.cat([joint_features, front_top_sum, r_vis_minus_front + time_marker], dim=-1)
|
||||
|
||||
|
||||
class _StubMultiTokenVisionBackbone(nn.Module):
|
||||
output_dim = 2
|
||||
tokens_per_step = 3
|
||||
|
||||
def __init__(self, camera_names=_CAMERA_NAMES):
|
||||
super().__init__()
|
||||
self.camera_names = tuple(camera_names)
|
||||
self.num_cameras = len(self.camera_names)
|
||||
|
||||
def forward(self, images):
|
||||
batch_size, obs_horizon = next(iter(images.values())).shape[:2]
|
||||
features = []
|
||||
time_marker = torch.arange(obs_horizon, dtype=torch.float32).view(1, obs_horizon, 1).expand(batch_size, -1, -1)
|
||||
for camera_name in self.camera_names:
|
||||
image_batch = images[camera_name]
|
||||
camera_marker = image_batch.mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1)
|
||||
features.append(torch.cat([camera_marker, camera_marker + time_marker], dim=-1))
|
||||
return torch.stack(features, dim=2)
|
||||
|
||||
|
||||
class _StubMultiTokenVisionBackbone(nn.Module):
|
||||
output_dim = 2
|
||||
tokens_per_step = 3
|
||||
|
||||
def __init__(self, camera_names=_CAMERA_NAMES):
|
||||
super().__init__()
|
||||
self.camera_names = tuple(camera_names)
|
||||
self.num_cameras = len(self.camera_names)
|
||||
|
||||
def forward(self, images):
|
||||
per_camera = []
|
||||
for camera_name in self.camera_names:
|
||||
image_batch = images[camera_name]
|
||||
base = image_batch.mean(dim=(2, 3, 4), keepdim=False)
|
||||
per_camera.append(torch.stack([base, base + 0.5], dim=-1))
|
||||
return torch.stack(per_camera, dim=2)
|
||||
|
||||
|
||||
class _RecordingLinearIMFHead(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -390,6 +530,178 @@ class IMFVLAAgentTest(unittest.TestCase):
|
||||
self.assertTrue(torch.equal(third_action, second_chunk[0, 1]))
|
||||
self.assertEqual(mock_predict_chunk.call_count, 2)
|
||||
|
||||
def test_joint_visual_backbone_uses_joint_output_dim_for_conditioning(self):
|
||||
agent_cls, _agent_module = _load_imf_agent_class()
|
||||
head = _RecordingLinearIMFHead()
|
||||
vision_backbone = _StubJointVisionBackbone()
|
||||
agent = agent_cls(
|
||||
vision_backbone=vision_backbone,
|
||||
state_encoder=nn.Identity(),
|
||||
action_encoder=nn.Identity(),
|
||||
head=head,
|
||||
action_dim=2,
|
||||
obs_dim=1,
|
||||
pred_horizon=3,
|
||||
obs_horizon=2,
|
||||
diffusion_steps=10,
|
||||
inference_steps=1,
|
||||
num_cams=len(_CAMERA_NAMES),
|
||||
camera_names=_CAMERA_NAMES,
|
||||
num_action_steps=2,
|
||||
head_type='transformer',
|
||||
)
|
||||
|
||||
self.assertEqual(agent.per_step_cond_dim, vision_backbone.joint_output_dim + agent.obs_dim)
|
||||
self.assertEqual(
|
||||
agent.global_cond_dim,
|
||||
vision_backbone.joint_output_dim * agent.obs_horizon + agent.obs_dim * agent.obs_horizon,
|
||||
)
|
||||
|
||||
images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=2,
|
||||
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
|
||||
)
|
||||
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
|
||||
initial_noise = torch.tensor(
|
||||
[[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
with mock.patch.object(torch, 'randn', return_value=initial_noise):
|
||||
predicted_actions = agent.predict_action(images, qpos)
|
||||
|
||||
self.assertEqual(predicted_actions.shape, (1, 3, 2))
|
||||
self.assertEqual(len(head.calls), 1)
|
||||
expected_cond = torch.tensor(
|
||||
[[[30.0, 20.0, 10.0, 50.0, -20.0, 1.0], [30.0, 20.0, 10.0, 50.0, -19.0, 2.0]]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
self.assertEqual(head.calls[0]['cond'].shape[-1], 6)
|
||||
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
|
||||
|
||||
def test_multitoken_visual_backbone_flattens_camera_tokens_and_projects_each_with_state(self):
|
||||
agent_cls, _agent_module = _load_imf_agent_class()
|
||||
head = _RecordingLinearIMFHead()
|
||||
projector = nn.Linear(3, 4, bias=False)
|
||||
with torch.no_grad():
|
||||
projector.weight.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
[1.0, 0.0, 0.0],
|
||||
[0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
[1.0, 0.0, 1.0],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
agent = agent_cls(
|
||||
vision_backbone=_StubMultiTokenVisionBackbone(),
|
||||
state_encoder=nn.Identity(),
|
||||
action_encoder=nn.Identity(),
|
||||
head=head,
|
||||
action_dim=2,
|
||||
obs_dim=1,
|
||||
pred_horizon=3,
|
||||
obs_horizon=2,
|
||||
diffusion_steps=10,
|
||||
inference_steps=1,
|
||||
num_cams=len(_CAMERA_NAMES),
|
||||
camera_names=_CAMERA_NAMES,
|
||||
num_action_steps=2,
|
||||
head_type='transformer',
|
||||
cond_projector=projector,
|
||||
)
|
||||
|
||||
self.assertEqual(agent.condition_tokens_per_step, 3)
|
||||
self.assertEqual(agent.condition_sequence_length, 6)
|
||||
self.assertEqual(agent.per_step_cond_dim, 4)
|
||||
self.assertEqual(agent.global_cond_dim, 24)
|
||||
|
||||
images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=2,
|
||||
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
|
||||
)
|
||||
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
|
||||
cond = agent._build_cond(images, qpos)
|
||||
|
||||
expected = torch.tensor(
|
||||
[
|
||||
[
|
||||
[10.0, 10.5, 1.0, 11.0],
|
||||
[20.0, 20.5, 1.0, 21.0],
|
||||
[30.0, 30.5, 1.0, 31.0],
|
||||
[10.0, 10.5, 2.0, 12.0],
|
||||
[20.0, 20.5, 2.0, 22.0],
|
||||
[30.0, 30.5, 2.0, 32.0],
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
self.assertEqual(cond.shape, (1, 6, 4))
|
||||
self.assertTrue(torch.allclose(cond, expected))
|
||||
|
||||
def test_multi_token_visual_backbone_pairs_state_per_camera_and_flattens_condition_sequence(self):
|
||||
agent_cls, agent_module = _load_imf_agent_class()
|
||||
head = _RecordingLinearIMFHead()
|
||||
cond_projector = nn.Linear(3, 4, bias=False)
|
||||
with torch.no_grad():
|
||||
cond_projector.weight.copy_(torch.tensor([
|
||||
[1.0, 0.0, 0.0],
|
||||
[0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
[1.0, 0.0, 1.0],
|
||||
], dtype=torch.float32))
|
||||
|
||||
agent = agent_cls(
|
||||
vision_backbone=_StubMultiTokenVisionBackbone(),
|
||||
state_encoder=nn.Identity(),
|
||||
action_encoder=nn.Identity(),
|
||||
head=head,
|
||||
action_dim=2,
|
||||
obs_dim=1,
|
||||
pred_horizon=3,
|
||||
obs_horizon=2,
|
||||
diffusion_steps=10,
|
||||
inference_steps=1,
|
||||
num_cams=len(_CAMERA_NAMES),
|
||||
camera_names=_CAMERA_NAMES,
|
||||
num_action_steps=2,
|
||||
head_type='transformer',
|
||||
cond_projector=cond_projector,
|
||||
)
|
||||
agent.infer_scheduler = _ForbiddenScheduler()
|
||||
|
||||
images = _make_images(
|
||||
batch_size=1,
|
||||
obs_horizon=2,
|
||||
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
|
||||
)
|
||||
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
|
||||
initial_noise = torch.tensor([[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]], dtype=torch.float32)
|
||||
|
||||
with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
|
||||
predicted_actions = agent.predict_action(images, qpos)
|
||||
|
||||
expected_cond = torch.tensor([[[10.0, 10.5, 1.0, 11.0],
|
||||
[20.0, 20.5, 1.0, 21.0],
|
||||
[30.0, 30.5, 1.0, 31.0],
|
||||
[10.0, 10.5, 2.0, 12.0],
|
||||
[20.0, 20.5, 2.0, 22.0],
|
||||
[30.0, 30.5, 2.0, 32.0]]], dtype=torch.float32)
|
||||
|
||||
self.assertEqual(agent.condition_tokens_per_step, 3)
|
||||
self.assertEqual(agent.condition_sequence_length, 6)
|
||||
self.assertEqual(agent.raw_per_step_cond_dim, 3)
|
||||
self.assertEqual(agent.per_step_cond_dim, 4)
|
||||
self.assertEqual(agent.global_cond_dim, 24)
|
||||
self.assertEqual(predicted_actions.shape, (1, 3, 2))
|
||||
self.assertEqual(len(head.calls), 1)
|
||||
self.assertEqual(head.calls[0]['cond'].shape, (1, 6, 4))
|
||||
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
|
||||
|
||||
def test_hydra_config_instantiates_resnet_imf_attnres_with_stub_head(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
@@ -448,6 +760,130 @@ class IMFVLAAgentTest(unittest.TestCase):
|
||||
self.assertEqual(agent.per_step_cond_dim, 64 * agent.num_cams + agent.obs_dim)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
|
||||
|
||||
def test_hydra_config_instantiates_lewm_imf_attnres_with_joint_visual_condition_dim(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent=lewm_imf_attnres',
|
||||
'agent.vision_backbone.checkpoint_path=null',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_emb=16',
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
||||
self.assertEqual(cfg.agent.vision_backbone._target_, 'roboimi.vla.models.backbones.lewm_vit_backbone.LEWMViTBackbone')
|
||||
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
|
||||
self.assertEqual(list(cfg.agent.vision_backbone.camera_names), list(_CAMERA_NAMES))
|
||||
self.assertEqual(list(cfg.agent.vision_backbone.fused_camera_names), ['front', 'top', 'r_vis'])
|
||||
self.assertIsNone(cfg.agent.vision_backbone.dataset_image_resize_shape)
|
||||
self.assertEqual(list(cfg.agent.vision_backbone.eval_image_resize_shape), [256, 256])
|
||||
self.assertEqual(cfg.agent.head.cond_dim, 208)
|
||||
|
||||
with _stub_optional_modules(include_imf_head=True):
|
||||
agent = instantiate(cfg.agent)
|
||||
|
||||
self.assertEqual(agent.per_step_cond_dim, agent.vision_encoder.joint_output_dim + agent.obs_dim)
|
||||
self.assertEqual(agent.per_step_cond_dim, 208)
|
||||
self.assertEqual(agent.global_cond_dim, agent.obs_horizon * 208)
|
||||
self.assertIsNone(agent.vision_encoder.dataset_image_resize_shape)
|
||||
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
|
||||
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 208)
|
||||
|
||||
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_projected_camera_tokens(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent=resnet_imf_attnres_multitoken',
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_emb=32',
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
||||
self.assertEqual(cfg.agent.vision_backbone.vision_backbone_mode, 'resnet')
|
||||
self.assertTrue(cfg.agent.vision_backbone.use_separate_rgb_encoder_per_camera)
|
||||
self.assertTrue(cfg.agent.vision_backbone.output_tokens_per_camera)
|
||||
self.assertEqual(cfg.agent.cond_projector.output_dim, 32)
|
||||
self.assertEqual(cfg.agent.head.cond_dim, 32)
|
||||
|
||||
with _stub_optional_modules(include_imf_head=True):
|
||||
agent = instantiate(cfg.agent)
|
||||
|
||||
self.assertEqual(agent.condition_tokens_per_step, 3)
|
||||
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon * 3)
|
||||
self.assertEqual(agent.per_step_cond_dim, 32)
|
||||
self.assertEqual(agent.global_cond_dim, agent.condition_sequence_length * 32)
|
||||
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 32)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], 6)
|
||||
|
||||
|
||||
def test_hydra_config_instantiates_siglip2_imf_attnres_with_condition_projection(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent=siglip2_imf_attnres',
|
||||
'agent.vision_backbone.per_view_output_dim=96',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_emb=16',
|
||||
'agent.cond_projector.output_dim=384',
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
||||
self.assertEqual(
|
||||
cfg.agent.vision_backbone._target_,
|
||||
'roboimi.vla.models.backbones.siglip2_diffusion_backbone.SigLIP2DiffusionBackbone',
|
||||
)
|
||||
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
|
||||
self.assertIsNone(cfg.agent.vision_backbone.dataset_image_resize_shape)
|
||||
self.assertEqual(list(cfg.agent.vision_backbone.eval_image_resize_shape), [256, 256])
|
||||
self.assertEqual(cfg.agent.head.cond_dim, 384)
|
||||
|
||||
with _stub_optional_modules(include_imf_head=True):
|
||||
agent = instantiate(cfg.agent)
|
||||
|
||||
self.assertEqual(agent.raw_per_step_cond_dim, 3 * 96 + agent.obs_dim)
|
||||
self.assertEqual(agent.per_step_cond_dim, 384)
|
||||
self.assertEqual(agent.global_cond_dim, agent.obs_horizon * 384)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 384)
|
||||
self.assertEqual(agent.vision_encoder.output_dim, 96)
|
||||
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
|
||||
|
||||
|
||||
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
|
||||
cfg = _compose_cfg(
|
||||
overrides=[
|
||||
'agent=resnet_imf_attnres_multitoken',
|
||||
'agent.vision_backbone.pretrained_backbone_weights=null',
|
||||
'agent.vision_backbone.input_shape=[3,16,16]',
|
||||
'agent.vision_backbone.freeze_backbone=false',
|
||||
'agent.head.n_layer=1',
|
||||
'agent.head.n_emb=16',
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
||||
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
|
||||
self.assertTrue(cfg.agent.vision_backbone.use_separate_rgb_encoder_per_camera)
|
||||
self.assertTrue(cfg.agent.vision_backbone.output_tokens_per_camera)
|
||||
self.assertEqual(cfg.agent.vision_backbone.vision_backbone_mode, 'resnet')
|
||||
self.assertEqual(cfg.agent.cond_projector.output_dim, 16)
|
||||
self.assertEqual(cfg.agent.head.cond_dim, 16)
|
||||
|
||||
with _stub_optional_modules(include_imf_head=True):
|
||||
agent = instantiate(cfg.agent)
|
||||
|
||||
self.assertEqual(agent.condition_tokens_per_step, 3)
|
||||
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon * 3)
|
||||
self.assertEqual(agent.per_step_cond_dim, 16)
|
||||
self.assertEqual(agent.global_cond_dim, agent.condition_sequence_length * 16)
|
||||
self.assertEqual(agent.vision_encoder.tokens_per_step, 3)
|
||||
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 16)
|
||||
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.condition_sequence_length)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user