# Unit tests for the IMF VLA agent (roboimi.vla.agent_imf) using stubbed optional dependencies.
import contextlib
|
|
import importlib
|
|
import importlib.machinery
|
|
import sys
|
|
import types
|
|
import unittest
|
|
from pathlib import Path
|
|
from unittest import mock
|
|
|
|
import torch
|
|
from hydra import compose, initialize_config_dir
|
|
from hydra.core.global_hydra import GlobalHydra
|
|
from hydra.utils import instantiate
|
|
from omegaconf import OmegaConf
|
|
from torch import nn
|
|
|
|
|
|
# Repository root: this test file lives one directory below it.
_REPO_ROOT = Path(__file__).resolve().parents[1]
# Absolute path (as str, for Hydra) to the VLA package's config directory.
_CONFIG_DIR = str((_REPO_ROOT / 'roboimi/vla/conf').resolve())
# Sentinel distinguishing "key absent from sys.modules" from "key set to None".
_MISSING = object()
# Camera keys shared by the stub backbones and image fixtures below.
_CAMERA_NAMES = ('r_vis', 'top', 'front')
|
|
|
|
|
|
class _FakeScheduler:
|
|
def __init__(self, num_train_timesteps=100, **kwargs):
|
|
self.config = types.SimpleNamespace(num_train_timesteps=num_train_timesteps)
|
|
self.timesteps = []
|
|
|
|
def add_noise(self, sample, noise, timestep):
|
|
return sample + noise
|
|
|
|
def set_timesteps(self, num_inference_steps):
|
|
self.timesteps = list(range(num_inference_steps - 1, -1, -1))
|
|
|
|
def step(self, noise_pred, timestep, sample):
|
|
return types.SimpleNamespace(prev_sample=sample)
|
|
|
|
|
|
class _IdentityCrop:
|
|
def __init__(self, size):
|
|
self.size = size
|
|
|
|
def __call__(self, x):
|
|
return x
|
|
|
|
|
|
class _FakeResNet(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
|
|
self.relu1 = nn.ReLU()
|
|
self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2)
|
|
self.relu2 = nn.ReLU()
|
|
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
|
self.fc = nn.Linear(16, 16)
|
|
|
|
def forward(self, x):
|
|
x = self.relu1(self.conv1(x))
|
|
x = self.relu2(self.conv2(x))
|
|
x = self.avgpool(x)
|
|
x = torch.flatten(x, start_dim=1)
|
|
return self.fc(x)
|
|
|
|
|
|
class _FakeRearrange(nn.Module):
|
|
def __init__(self, *args, **kwargs):
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
|
|
class _FakeViTConfig:
|
|
def __init__(self, **kwargs):
|
|
for key, value in kwargs.items():
|
|
setattr(self, key, value)
|
|
|
|
|
|
class _FakeViTModel(nn.Module):
|
|
def __init__(self, config, add_pooling_layer=False):
|
|
super().__init__()
|
|
del add_pooling_layer
|
|
self.config = config
|
|
hidden_size = int(getattr(config, 'hidden_size', 192))
|
|
self.proj = nn.Linear(hidden_size, hidden_size)
|
|
|
|
def forward(self, pixel_values=None, interpolate_pos_encoding=False, **kwargs):
|
|
del interpolate_pos_encoding, kwargs
|
|
batch_size = pixel_values.shape[0]
|
|
hidden_size = int(getattr(self.config, 'hidden_size', 192))
|
|
seq_len = 2
|
|
last_hidden_state = torch.zeros(batch_size, seq_len, hidden_size, dtype=pixel_values.dtype, device=pixel_values.device)
|
|
return types.SimpleNamespace(last_hidden_state=last_hidden_state)
|
|
|
|
|
|
class _FakeSiglipVisionOutput:
|
|
def __init__(self, pooler_output):
|
|
self.pooler_output = pooler_output
|
|
|
|
|
|
class _FakeSiglipVisionConfig:
|
|
def __init__(self, hidden_size=768, image_size=256):
|
|
self.hidden_size = hidden_size
|
|
self.image_size = image_size
|
|
|
|
|
|
class _FakeSiglipVisionModel(nn.Module):
    """Recording stand-in for ``transformers.SiglipVisionModel``.

    ``load_calls`` is class-level (shared across instances) and captures every
    ``from_pretrained`` invocation; ``forward_calls`` captures per-instance
    forward inputs. ``forward`` pools pixel values over H/W and scales by a
    learnable scalar.
    """

    load_calls = []

    def __init__(self, hidden_size=768):
        super().__init__()
        self.config = _FakeSiglipVisionConfig(hidden_size=hidden_size)
        # Learnable scalar so optimizer-facing code sees a parameter.
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.forward_calls = []

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        # Record the checkpoint request, then hand back a fresh instance.
        cls.load_calls.append(
            {
                'pretrained_model_name_or_path': pretrained_model_name_or_path,
                'args': args,
                'kwargs': kwargs,
            }
        )
        return cls()

    def forward(self, pixel_values=None, **kwargs):
        self.forward_calls.append(
            {
                'pixel_values': pixel_values.detach().clone(),
                'kwargs': dict(kwargs),
            }
        )
        # Global average pool over the spatial dims, scaled by the parameter.
        pooled = pixel_values.mean(dim=(2, 3), keepdim=False) * self.scale
        return _FakeSiglipVisionOutput(pooler_output=pooled)
|
|
|
|
|
|
class _StubIMFHead(nn.Module):
|
|
def __init__(
|
|
self,
|
|
input_dim,
|
|
output_dim,
|
|
horizon,
|
|
n_obs_steps,
|
|
cond_dim,
|
|
**kwargs,
|
|
):
|
|
super().__init__()
|
|
self.constructor_kwargs = {
|
|
'input_dim': input_dim,
|
|
'output_dim': output_dim,
|
|
'horizon': horizon,
|
|
'n_obs_steps': n_obs_steps,
|
|
'cond_dim': cond_dim,
|
|
**kwargs,
|
|
}
|
|
self.proj = nn.Linear(input_dim, output_dim)
|
|
self.cond_obs_emb = nn.Linear(cond_dim, max(cond_dim, 1))
|
|
|
|
def forward(self, sample, r, t, cond=None):
|
|
return torch.zeros_like(sample)
|
|
|
|
def get_optim_groups(self, weight_decay):
|
|
return [
|
|
{'params': [self.proj.weight], 'weight_decay': weight_decay},
|
|
{'params': [self.proj.bias, self.cond_obs_emb.weight, self.cond_obs_emb.bias], 'weight_decay': 0.0},
|
|
]
|
|
|
|
|
|
@contextlib.contextmanager
def _stub_optional_modules(include_imf_head=False):
    """Temporarily replace heavy optional deps in ``sys.modules`` with fakes.

    Installs lightweight stand-ins for diffusers, torchvision, einops and
    transformers so the agent module can be imported without those packages
    installed. On exit, restores ``sys.modules`` to its prior state,
    re-inserting or removing entries as needed.

    Args:
        include_imf_head: also stub ``roboimi.vla.models.heads.imf_transformer1d``
            with ``_StubIMFHead``.
    """
    # Snapshot of touched sys.modules entries; _MISSING means "was absent".
    previous_modules = {}

    def remember_and_remove(name):
        # Record the current entry (first touch only), then evict it so a
        # later import re-executes against the stubbed dependencies.
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)
        sys.modules.pop(name, None)

    def inject(name, module):
        # Record the current entry (first touch only), then install the stub.
        if name not in previous_modules:
            previous_modules[name] = sys.modules.get(name, _MISSING)
        sys.modules[name] = module

    # --- diffusers: both scheduler entry points map to the same fake. ---
    diffusers_module = types.ModuleType('diffusers')
    schedulers_module = types.ModuleType('diffusers.schedulers')
    ddpm_module = types.ModuleType('diffusers.schedulers.scheduling_ddpm')
    ddim_module = types.ModuleType('diffusers.schedulers.scheduling_ddim')
    ddpm_module.DDPMScheduler = _FakeScheduler
    ddim_module.DDIMScheduler = _FakeScheduler
    diffusers_module.DDPMScheduler = _FakeScheduler
    diffusers_module.DDIMScheduler = _FakeScheduler
    diffusers_module.schedulers = schedulers_module
    schedulers_module.scheduling_ddpm = ddpm_module
    schedulers_module.scheduling_ddim = ddim_module

    # --- torchvision: identity crops and a tiny resnet18. ---
    # __spec__ is populated so find_spec-style availability checks succeed.
    torchvision_module = types.ModuleType('torchvision')
    models_module = types.ModuleType('torchvision.models')
    transforms_module = types.ModuleType('torchvision.transforms')
    torchvision_module.__spec__ = importlib.machinery.ModuleSpec('torchvision', loader=None)
    models_module.__spec__ = importlib.machinery.ModuleSpec('torchvision.models', loader=None)
    transforms_module.__spec__ = importlib.machinery.ModuleSpec('torchvision.transforms', loader=None)
    models_module.resnet18 = lambda weights=None: _FakeResNet()
    transforms_module.CenterCrop = _IdentityCrop
    transforms_module.RandomCrop = _IdentityCrop
    torchvision_module.models = models_module
    torchvision_module.transforms = transforms_module

    # --- einops: rearrange becomes identity. ---
    einops_module = types.ModuleType('einops')
    einops_module.rearrange = lambda x, *args, **kwargs: x
    einops_layers_module = types.ModuleType('einops.layers')
    einops_layers_torch_module = types.ModuleType('einops.layers.torch')
    einops_layers_torch_module.Rearrange = _FakeRearrange
    einops_module.layers = einops_layers_module
    einops_layers_module.torch = einops_layers_torch_module

    # --- transformers: fake ViT and SigLIP entry points. ---
    transformers_module = types.ModuleType('transformers')
    transformers_module.__spec__ = importlib.machinery.ModuleSpec('transformers', loader=None)
    transformers_module.ViTConfig = _FakeViTConfig
    transformers_module.ViTModel = _FakeViTModel
    transformers_module.SiglipVisionModel = _FakeSiglipVisionModel

    try:
        # Evict the backbone module so it re-imports against the stubs.
        remember_and_remove('roboimi.vla.models.backbones.siglip2_diffusion_backbone')
        inject('diffusers', diffusers_module)
        inject('diffusers.schedulers', schedulers_module)
        inject('diffusers.schedulers.scheduling_ddpm', ddpm_module)
        inject('diffusers.schedulers.scheduling_ddim', ddim_module)
        inject('torchvision', torchvision_module)
        inject('torchvision.models', models_module)
        inject('torchvision.transforms', transforms_module)
        inject('einops', einops_module)
        inject('einops.layers', einops_layers_module)
        inject('einops.layers.torch', einops_layers_torch_module)
        inject('transformers', transformers_module)

        if include_imf_head:
            import roboimi.vla.models.heads as heads_package

            imf_head_module = types.ModuleType('roboimi.vla.models.heads.imf_transformer1d')
            imf_head_module.IMFTransformer1D = _StubIMFHead
            inject('roboimi.vla.models.heads.imf_transformer1d', imf_head_module)
            # NOTE(review): this package attribute is not restored on exit, so
            # the heads package keeps pointing at the stub afterwards — confirm
            # that is intended.
            setattr(heads_package, 'imf_transformer1d', imf_head_module)

        yield
    finally:
        # Restore in reverse insertion order so parent packages return last.
        for name, previous in reversed(list(previous_modules.items())):
            if previous is _MISSING:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
|
|
|
|
|
|
def _compose_cfg(overrides=None):
    """Compose the Hydra config from the repo's conf dir, applying overrides."""
    # The configs rely on a custom ``len`` resolver; register it once.
    if not OmegaConf.has_resolver('len'):
        OmegaConf.register_new_resolver('len', lambda x: len(x))

    override_list = [] if overrides is None else list(overrides)
    # Hydra keeps global singleton state; clear anything earlier tests left.
    GlobalHydra.instance().clear()
    with initialize_config_dir(version_base=None, config_dir=_CONFIG_DIR):
        return compose(config_name='config', overrides=override_list)
|
|
|
|
|
|
def _load_imf_agent_class():
    """Import ``roboimi.vla.agent_imf`` fresh, under the stubbed optional deps.

    Returns:
        (IMFVLAAgent class, the freshly imported agent module).
    """
    with _stub_optional_modules():
        # Drop any cached copy so the import binds against the stub modules.
        sys.modules.pop('roboimi.vla.agent_imf', None)
        agent_module = importlib.import_module('roboimi.vla.agent_imf')
        return agent_module.IMFVLAAgent, agent_module
|
|
|
|
|
|
class _StubVisionBackbone(nn.Module):
    """Vision stub yielding one scalar feature per camera: the image mean.

    forward: dict of (B, T, C, H, W) tensors keyed by camera name ->
    (B, T, num_cameras) features, cameras ordered as ``camera_names``.
    """

    output_dim = 1

    def __init__(self, camera_names=_CAMERA_NAMES):
        super().__init__()
        self.camera_names = tuple(camera_names)
        self.num_cameras = len(self.camera_names)

    def forward(self, images):
        features = [
            images[name].mean(dim=(2, 3, 4)).unsqueeze(-1)
            for name in self.camera_names
        ]
        return torch.cat(features, dim=-1)
|
|
|
|
|
|
class _StubJointVisionBackbone(nn.Module):
    """Vision stub emitting a 5-dim joint feature per step.

    Per step: the per-camera means in (front, top, r_vis) order, the
    front+top sum, and (r_vis - front) offset by the step index so tests
    can detect time-axis mix-ups.
    """

    joint_output_dim = 5
    output_dim = 5

    def __init__(self, camera_names=_CAMERA_NAMES):
        super().__init__()
        self.camera_names = tuple(camera_names)
        self.num_cameras = len(self.camera_names)

    def forward(self, images):
        batch_size, obs_horizon = next(iter(images.values())).shape[:2]
        per_camera = [
            images[name].mean(dim=(2, 3, 4)).unsqueeze(-1)
            for name in ('front', 'top', 'r_vis')
        ]
        joint_features = torch.cat(per_camera, dim=-1)
        front_top_sum = joint_features[..., :2].sum(dim=-1, keepdim=True)
        r_vis_minus_front = joint_features[..., 2:] - joint_features[..., :1]
        # Per-step marker: 0, 1, ... along the observation horizon.
        step_marker = torch.arange(obs_horizon, dtype=joint_features.dtype).view(1, obs_horizon, 1)
        step_marker = step_marker.expand(batch_size, -1, -1)
        return torch.cat([joint_features, front_top_sum, r_vis_minus_front + step_marker], dim=-1)
|
|
|
|
|
|
class _StubMultiTokenVisionBackbone(nn.Module):
    # NOTE(review): dead code — this class is immediately redefined further
    # down in this file with a different forward(); Python keeps only the
    # later definition, so this body never runs. Remove or rename one of the
    # two definitions to avoid confusion.
    """Vision stub emitting two tokens per camera, the second offset by a
    per-step time marker; tokens are stacked on a new camera axis."""

    output_dim = 2

    tokens_per_step = 3

    def __init__(self, camera_names=_CAMERA_NAMES):
        super().__init__()
        self.camera_names = tuple(camera_names)
        self.num_cameras = len(self.camera_names)

    def forward(self, images):
        batch_size, obs_horizon = next(iter(images.values())).shape[:2]
        features = []
        # Per-step marker 0..T-1 broadcast over the batch.
        time_marker = torch.arange(obs_horizon, dtype=torch.float32).view(1, obs_horizon, 1).expand(batch_size, -1, -1)
        for camera_name in self.camera_names:
            image_batch = images[camera_name]
            camera_marker = image_batch.mean(dim=(2, 3, 4), keepdim=False).unsqueeze(-1)
            features.append(torch.cat([camera_marker, camera_marker + time_marker], dim=-1))
        return torch.stack(features, dim=2)
|
|
|
|
|
|
class _StubMultiTokenVisionBackbone(nn.Module):
    """Vision stub producing two tokens per camera: the image mean and mean+0.5.

    forward: dict of (B, T, C, H, W) tensors -> (B, T, num_cameras, 2).
    """

    output_dim = 2

    tokens_per_step = 3

    def __init__(self, camera_names=_CAMERA_NAMES):
        super().__init__()
        self.camera_names = tuple(camera_names)
        self.num_cameras = len(self.camera_names)

    def forward(self, images):
        camera_tokens = []
        for name in self.camera_names:
            mean_feature = images[name].mean(dim=(2, 3, 4))
            camera_tokens.append(torch.stack([mean_feature, mean_feature + 0.5], dim=-1))
        return torch.stack(camera_tokens, dim=2)
|
|
|
|
|
|
class _RecordingLinearIMFHead(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.scale = nn.Parameter(torch.tensor(0.5))
|
|
self.calls = []
|
|
|
|
@staticmethod
|
|
def _broadcast_batch_time(value, reference):
|
|
while value.ndim < reference.ndim:
|
|
value = value.unsqueeze(-1)
|
|
return value
|
|
|
|
def forward(self, sample, r, t, cond=None):
|
|
record = {
|
|
'sample': sample.detach().clone(),
|
|
'r': r.detach().clone(),
|
|
't': t.detach().clone(),
|
|
'cond': None if cond is None else cond.detach().clone(),
|
|
}
|
|
self.calls.append(record)
|
|
cond_term = 0.0
|
|
if cond is not None:
|
|
cond_term = cond.mean(dim=(1, 2), keepdim=True)
|
|
r_b = self._broadcast_batch_time(r, sample)
|
|
t_b = self._broadcast_batch_time(t, sample)
|
|
return self.scale * sample + r_b + 2.0 * t_b + cond_term
|
|
|
|
|
|
class _ForbiddenScheduler:
|
|
def set_timesteps(self, *args, **kwargs): # pragma: no cover - only runs on regression
|
|
raise AssertionError('IMF inference should not use DDIM scheduler set_timesteps')
|
|
|
|
def step(self, *args, **kwargs): # pragma: no cover - only runs on regression
|
|
raise AssertionError('IMF inference should not use DDIM scheduler step')
|
|
|
|
|
|
class _StubFutureTokenPredictor(nn.Module):
|
|
def __init__(self, num_future_tokens=1):
|
|
super().__init__()
|
|
self.num_future_tokens = int(num_future_tokens)
|
|
self.calls = []
|
|
|
|
def forward(self, history_tokens):
|
|
self.calls.append(history_tokens.detach().clone())
|
|
summary = history_tokens.mean(dim=1, keepdim=True)
|
|
return summary.repeat(1, self.num_future_tokens, 1)
|
|
|
|
|
|
class _RecordingDirectFutureDecoder(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.scale = nn.Parameter(torch.tensor(0.5))
|
|
self.calls = []
|
|
|
|
def forward(self, sample, r, t, cond=None):
|
|
record = {
|
|
'sample': sample.detach().clone(),
|
|
'r': r.detach().clone(),
|
|
't': t.detach().clone(),
|
|
'cond': None if cond is None else cond.detach().clone(),
|
|
}
|
|
self.calls.append(record)
|
|
cond_term = 0.0
|
|
if cond is not None:
|
|
cond_term = cond.mean(dim=1, keepdim=True)
|
|
return self.scale * sample + cond_term
|
|
|
|
|
|
class _RecordingSigReg(nn.Module):
|
|
def __init__(self, value=0.5):
|
|
super().__init__()
|
|
self.value = float(value)
|
|
self.calls = []
|
|
|
|
def forward(self, embeddings):
|
|
self.calls.append(embeddings.detach().clone())
|
|
return embeddings.new_tensor(self.value)
|
|
|
|
|
|
def _make_images(batch_size, obs_horizon, per_camera_fill):
|
|
return {
|
|
name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
|
|
for name, value in per_camera_fill.items()
|
|
}
|
|
|
|
|
|
class IMFVLAAgentTest(unittest.TestCase):
|
|
    def _make_agent(self, pred_horizon=3, obs_horizon=2, num_action_steps=2):
        """Build an IMFVLAAgent wired with recording stubs.

        Returns (agent, recording head, freshly imported agent module).
        """
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=pred_horizon,
            obs_horizon=obs_horizon,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=num_action_steps,
            head_type='transformer',
        )
        return agent, head, agent_module
|
|
|
|
    def test_compute_loss_matches_imf_objective_and_masks_padded_actions(self):
        """compute_loss must reproduce the analytic IMF compound-velocity
        objective and exclude padded action steps via ``action_is_pad``."""
        agent, head, agent_module = self._make_agent(pred_horizon=3, obs_horizon=2)
        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
        actions = torch.tensor(
            [[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
            dtype=torch.float32,
        )
        action_is_pad = torch.tensor([[False, False, True]])
        noise = torch.tensor(
            [[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
            dtype=torch.float32,
        )
        t_sample = torch.tensor([0.8], dtype=torch.float32)
        r_sample = torch.tensor([0.25], dtype=torch.float32)

        # Pin the stochastic draws so the loss is fully deterministic.
        with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
                mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
            loss = agent.compute_loss(
                {
                    'images': images,
                    'qpos': qpos,
                    'action': actions,
                    'action_is_pad': action_is_pad,
                }
            )

        # Re-derive the expected loss from the recording head's closed form.
        cond = torch.tensor([[[1.0, 2.0, 3.0, 0.25], [1.0, 2.0, 3.0, 0.75]]], dtype=torch.float32)
        cond_term = cond.mean(dim=(1, 2), keepdim=True)
        t = t_sample
        r = r_sample
        z_t = (1 - t.view(1, 1, 1)) * actions + t.view(1, 1, 1) * noise
        scale = head.scale.detach()
        u = scale * z_t + r.view(1, 1, 1) + 2.0 * t.view(1, 1, 1) + cond_term
        v = scale * z_t + 3.0 * t.view(1, 1, 1) + cond_term
        du_dt = scale * v + 2.0
        compound_velocity = u + (t - r).view(1, 1, 1) * du_dt
        target = noise - actions
        elementwise_loss = (compound_velocity - target) ** 2
        # Masked mean: padded steps must not contribute to the loss.
        mask = (~action_is_pad).unsqueeze(-1).to(elementwise_loss.dtype)
        expected_loss = (elementwise_loss * mask).sum() / (mask.sum() * elementwise_loss.shape[-1])

        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=6)
        self.assertEqual(len(head.calls), 2)
        self.assertTrue(torch.allclose(head.calls[0]['r'], t_sample))
        self.assertTrue(torch.allclose(head.calls[0]['t'], t_sample))
        self.assertTrue(torch.allclose(head.calls[0]['cond'], cond))
|
|
|
|
    def test_predict_action_uses_one_step_imf_sampling_and_image_conditioning(self):
        """predict_action must use single-step IMF sampling (r=0, t=1) with
        camera-mean + qpos conditioning, and never touch the DDIM scheduler."""
        agent, head, agent_module = self._make_agent(pred_horizon=3, obs_horizon=2)
        # Any scheduler use raises immediately.
        agent.infer_scheduler = _ForbiddenScheduler()

        images = _make_images(
            batch_size=2,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor(
            [
                [[1.0], [2.0]],
                [[3.0], [4.0]],
            ],
            dtype=torch.float32,
        )
        initial_noise = torch.tensor(
            [
                [[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]],
                [[-1.0, 1.0], [2.0, -3.0], [0.5, 0.25]],
            ],
            dtype=torch.float32,
        )

        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
            predicted_actions = agent.predict_action(images, qpos)

        expected_cond = torch.tensor(
            [
                [[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]],
                [[10.0, 20.0, 30.0, 3.0], [10.0, 20.0, 30.0, 4.0]],
            ],
            dtype=torch.float32,
        )
        # One-step IMF sample from the recording head's closed form.
        cond_term = expected_cond.mean(dim=(1, 2), keepdim=True)
        expected_actions = 0.5 * initial_noise - 2.0 - cond_term

        self.assertEqual(predicted_actions.shape, (2, 3, 2))
        self.assertTrue(torch.allclose(predicted_actions, expected_actions))
        self.assertEqual(len(head.calls), 1)
        self.assertTrue(torch.allclose(head.calls[0]['r'], torch.zeros(2)))
        self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
|
|
|
|
    def test_predict_action_appends_lewm_future_tokens_to_history_conditioning(self):
        """With a LEWM predictor configured, predict_action must append the
        predicted future token after the history tokens in the conditioning."""
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            extra_condition_tokens=1,
            lewm_history_horizon=3,
            lewm_query_offsets=[8],
            lewm_predictor=future_predictor,
            lewm_pred_projector=nn.Identity(),
            lewm_loss_weight=0.5,
        )
        agent.infer_scheduler = _ForbiddenScheduler()

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        lewm_images = _make_images(
            batch_size=1,
            obs_horizon=3,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        lewm_qpos = torch.tensor([[[0.5], [1.5], [2.5]]], dtype=torch.float32)
        initial_noise = torch.tensor(
            [[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
            dtype=torch.float32,
        )

        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
            _ = agent.predict_action(
                images,
                qpos,
                lewm_images=lewm_images,
                lewm_proprioception=lewm_qpos,
            )

        expected_history = torch.tensor(
            [[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
            dtype=torch.float32,
        )
        # The stub predictor returns the mean of the LEWM history tokens,
        # hence the 1.5 in the qpos slot.
        expected_future = torch.tensor([[[10.0, 20.0, 30.0, 1.5]]], dtype=torch.float32)
        expected_cond = torch.cat([expected_history, expected_future], dim=1)

        self.assertEqual(agent.condition_sequence_length, 3)
        self.assertEqual(agent.per_step_cond_dim, 4)
        self.assertEqual(len(head.calls), 1)
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
        self.assertEqual(len(future_predictor.calls), 1)
|
|
|
|
    def test_compute_loss_tracks_action_and_lewm_loss_breakdown(self):
        """With a LEWM predictor + sigreg, compute_loss must report a metrics
        breakdown where loss = action_loss + lewm_loss_weight * lewm_loss and
        lewm_loss = lewm_pred_loss + lewm_sigreg_weight * lewm_sigreg_loss."""
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
        sigreg = _RecordingSigReg(value=0.75)
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            extra_condition_tokens=1,
            lewm_history_horizon=3,
            lewm_query_offsets=[8],
            lewm_predictor=future_predictor,
            lewm_pred_projector=nn.Identity(),
            lewm_sigreg=sigreg,
            lewm_sigreg_weight=0.09,
            lewm_loss_weight=0.25,
        )

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
        actions = torch.tensor(
            [[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
            dtype=torch.float32,
        )
        lewm_images = _make_images(
            batch_size=1,
            obs_horizon=3,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
        lewm_future_images = _make_images(
            batch_size=1,
            obs_horizon=1,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
        noise = torch.tensor(
            [[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
            dtype=torch.float32,
        )
        t_sample = torch.tensor([0.8], dtype=torch.float32)
        r_sample = torch.tensor([0.25], dtype=torch.float32)

        with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
                mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
            loss = agent.compute_loss(
                {
                    'images': images,
                    'qpos': qpos,
                    'action': actions,
                    'lewm_images': lewm_images,
                    'lewm_qpos': lewm_qpos,
                    'lewm_future_images': lewm_future_images,
                    'lewm_future_qpos': lewm_future_qpos,
                }
            )

        metrics = agent.get_last_loss_breakdown()
        self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
        self.assertIn('action_loss', metrics)
        self.assertIn('lewm_pred_loss', metrics)
        self.assertIn('lewm_sigreg_loss', metrics)
        self.assertIn('lewm_loss', metrics)
        # The stub sigreg always reports its fixed value.
        self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)
        self.assertAlmostEqual(
            metrics['lewm_loss'],
            metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
            places=5,
        )
        self.assertAlmostEqual(
            metrics['loss'],
            metrics['action_loss'] + 0.25 * metrics['lewm_loss'],
            places=5,
        )
        self.assertEqual(len(sigreg.calls), 1)
        expected_lewm_history = torch.tensor(
            [[[1.0, 2.0, 3.0, 0.1], [1.0, 2.0, 3.0, 0.2], [1.0, 2.0, 3.0, 0.3]]],
            dtype=torch.float32,
        )
        # Sigreg receives the history tokens with batch/time axes swapped.
        torch.testing.assert_close(sigreg.calls[0], expected_lewm_history.transpose(0, 1))
|
|
|
|
    def test_predict_action_with_dual_decoder_keeps_action_condition_history_only(self):
        """With a separate future decoder configured, inference must condition
        the action head on history tokens only and never call the decoder."""
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        future_decoder = _RecordingDirectFutureDecoder()
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            future_decoder=future_decoder,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            lewm_history_horizon=3,
            lewm_query_offsets=[8],
            lewm_loss_weight=1.0,
        )
        agent.infer_scheduler = _ForbiddenScheduler()
        # Pin the learnable query tokens to a known value.
        with torch.no_grad():
            agent.future_query_tokens.copy_(torch.tensor([[[0.1, 0.2, 0.3, 0.4]]], dtype=torch.float32))

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        initial_noise = torch.tensor(
            [[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
            dtype=torch.float32,
        )

        with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
            _ = agent.predict_action(images, qpos)

        expected_history = torch.tensor(
            [[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
            dtype=torch.float32,
        )
        self.assertEqual(len(head.calls), 1)
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_history))
        # The future decoder is a training-only component.
        self.assertEqual(len(future_decoder.calls), 0)
|
|
|
|
    def test_compute_loss_with_dual_decoder_tracks_lewm_loss_breakdown(self):
        """With the dual-decoder setup, training must invoke the future decoder
        once on the LEWM history and fold its losses into the breakdown as
        loss = action_loss + lewm_loss (weight 1.0)."""
        agent_cls, agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        future_decoder = _RecordingDirectFutureDecoder()
        sigreg = _RecordingSigReg(value=0.75)
        agent = agent_cls(
            vision_backbone=_StubVisionBackbone(),
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            future_decoder=future_decoder,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
            lewm_history_horizon=3,
            lewm_query_offsets=[8],
            lewm_sigreg=sigreg,
            lewm_sigreg_weight=0.09,
            lewm_loss_weight=1.0,
        )
        with torch.no_grad():
            agent.future_query_tokens.copy_(torch.tensor([[[0.2, 0.4, 0.6, 0.8]]], dtype=torch.float32))

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
        actions = torch.tensor(
            [[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
            dtype=torch.float32,
        )
        lewm_images = _make_images(
            batch_size=1,
            obs_horizon=3,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
        lewm_future_images = _make_images(
            batch_size=1,
            obs_horizon=1,
            per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
        )
        lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
        noise = torch.tensor(
            [[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
            dtype=torch.float32,
        )
        t_sample = torch.tensor([0.8], dtype=torch.float32)
        r_sample = torch.tensor([0.25], dtype=torch.float32)

        with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
                mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
            loss = agent.compute_loss(
                {
                    'images': images,
                    'qpos': qpos,
                    'action': actions,
                    'lewm_images': lewm_images,
                    'lewm_qpos': lewm_qpos,
                    'lewm_future_images': lewm_future_images,
                    'lewm_future_qpos': lewm_future_qpos,
                }
            )

        metrics = agent.get_last_loss_breakdown()
        self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
        self.assertEqual(len(head.calls), 2)
        # Action head sees only the 2-step observation history.
        self.assertEqual(head.calls[0]['cond'].shape, (1, 2, 4))
        # Future decoder runs once on the 3-step LEWM history.
        self.assertEqual(len(future_decoder.calls), 1)
        self.assertEqual(future_decoder.calls[0]['cond'].shape, (1, 3, 4))
        self.assertAlmostEqual(
            metrics['loss'],
            metrics['action_loss'] + metrics['lewm_loss'],
            places=5,
        )
        self.assertAlmostEqual(
            metrics['lewm_loss'],
            metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
            places=5,
        )
        self.assertGreater(metrics['lewm_pred_loss'], 0.0)
        self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)
|
|
|
|
    def test_select_action_only_regenerates_when_action_queue_is_empty(self):
        """select_action must drain its queued chunk (num_action_steps actions,
        skipping the first step) before requesting a new chunk."""
        agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
        observation = {
            'qpos': torch.tensor([0.25], dtype=torch.float32),
            'images': {
                'front': torch.full((1, 2, 2), 3.0, dtype=torch.float32),
                'top': torch.full((1, 2, 2), 2.0, dtype=torch.float32),
                'r_vis': torch.full((1, 2, 2), 1.0, dtype=torch.float32),
            },
        }
        first_chunk = torch.tensor(
            [[[10.0, 11.0], [12.0, 13.0], [14.0, 15.0], [16.0, 17.0]]],
            dtype=torch.float32,
        )
        second_chunk = torch.tensor(
            [[[20.0, 21.0], [22.0, 23.0], [24.0, 25.0], [26.0, 27.0]]],
            dtype=torch.float32,
        )

        with mock.patch.object(agent, 'predict_action_chunk', side_effect=[first_chunk, second_chunk]) as mock_predict_chunk:
            first_action = agent.select_action(observation)
            second_action = agent.select_action(observation)
            third_action = agent.select_action(observation)

        # Chunk indices 1 and 2 are served from the first chunk; the third
        # call exhausts the queue and triggers a second prediction.
        self.assertTrue(torch.equal(first_action, first_chunk[0, 1]))
        self.assertTrue(torch.equal(second_action, first_chunk[0, 2]))
        self.assertTrue(torch.equal(third_action, second_chunk[0, 1]))
        self.assertEqual(mock_predict_chunk.call_count, 2)
|
|
|
|
    def test_joint_visual_backbone_uses_joint_output_dim_for_conditioning(self):
        """A backbone exposing ``joint_output_dim`` must size the conditioning
        from that value and pass its joint features through unchanged."""
        agent_cls, _agent_module = _load_imf_agent_class()
        head = _RecordingLinearIMFHead()
        vision_backbone = _StubJointVisionBackbone()
        agent = agent_cls(
            vision_backbone=vision_backbone,
            state_encoder=nn.Identity(),
            action_encoder=nn.Identity(),
            head=head,
            action_dim=2,
            obs_dim=1,
            pred_horizon=3,
            obs_horizon=2,
            diffusion_steps=10,
            inference_steps=1,
            num_cams=len(_CAMERA_NAMES),
            camera_names=_CAMERA_NAMES,
            num_action_steps=2,
            head_type='transformer',
        )

        self.assertEqual(agent.per_step_cond_dim, vision_backbone.joint_output_dim + agent.obs_dim)
        self.assertEqual(
            agent.global_cond_dim,
            vision_backbone.joint_output_dim * agent.obs_horizon + agent.obs_dim * agent.obs_horizon,
        )

        images = _make_images(
            batch_size=1,
            obs_horizon=2,
            per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
        )
        qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
        initial_noise = torch.tensor(
            [[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
            dtype=torch.float32,
        )

        # NOTE(review): patches the global ``torch.randn`` rather than
        # ``agent_module.torch`` as the other tests do — confirm both resolve
        # to the same module object.
        with mock.patch.object(torch, 'randn', return_value=initial_noise):
            predicted_actions = agent.predict_action(images, qpos)

        self.assertEqual(predicted_actions.shape, (1, 3, 2))
        self.assertEqual(len(head.calls), 1)
        # Per step: (front, top, r_vis, front+top, r_vis-front+step, qpos).
        expected_cond = torch.tensor(
            [[[30.0, 20.0, 10.0, 50.0, -20.0, 1.0], [30.0, 20.0, 10.0, 50.0, -19.0, 2.0]]],
            dtype=torch.float32,
        )
        self.assertEqual(head.calls[0]['cond'].shape[-1], 6)
        self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
|
|
|
|
def test_multitoken_visual_backbone_flattens_camera_tokens_and_projects_each_with_state(self):
|
|
agent_cls, _agent_module = _load_imf_agent_class()
|
|
head = _RecordingLinearIMFHead()
|
|
projector = nn.Linear(3, 4, bias=False)
|
|
with torch.no_grad():
|
|
projector.weight.copy_(
|
|
torch.tensor(
|
|
[
|
|
[1.0, 0.0, 0.0],
|
|
[0.0, 1.0, 0.0],
|
|
[0.0, 0.0, 1.0],
|
|
[1.0, 0.0, 1.0],
|
|
],
|
|
dtype=torch.float32,
|
|
)
|
|
)
|
|
agent = agent_cls(
|
|
vision_backbone=_StubMultiTokenVisionBackbone(),
|
|
state_encoder=nn.Identity(),
|
|
action_encoder=nn.Identity(),
|
|
head=head,
|
|
action_dim=2,
|
|
obs_dim=1,
|
|
pred_horizon=3,
|
|
obs_horizon=2,
|
|
diffusion_steps=10,
|
|
inference_steps=1,
|
|
num_cams=len(_CAMERA_NAMES),
|
|
camera_names=_CAMERA_NAMES,
|
|
num_action_steps=2,
|
|
head_type='transformer',
|
|
cond_projector=projector,
|
|
)
|
|
|
|
self.assertEqual(agent.condition_tokens_per_step, 3)
|
|
self.assertEqual(agent.condition_sequence_length, 6)
|
|
self.assertEqual(agent.per_step_cond_dim, 4)
|
|
self.assertEqual(agent.global_cond_dim, 24)
|
|
|
|
images = _make_images(
|
|
batch_size=1,
|
|
obs_horizon=2,
|
|
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
|
|
)
|
|
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
|
|
cond = agent._build_cond(images, qpos)
|
|
|
|
expected = torch.tensor(
|
|
[
|
|
[
|
|
[10.0, 10.5, 1.0, 11.0],
|
|
[20.0, 20.5, 1.0, 21.0],
|
|
[30.0, 30.5, 1.0, 31.0],
|
|
[10.0, 10.5, 2.0, 12.0],
|
|
[20.0, 20.5, 2.0, 22.0],
|
|
[30.0, 30.5, 2.0, 32.0],
|
|
]
|
|
],
|
|
dtype=torch.float32,
|
|
)
|
|
self.assertEqual(cond.shape, (1, 6, 4))
|
|
self.assertTrue(torch.allclose(cond, expected))
|
|
|
|
def test_multi_token_visual_backbone_pairs_state_per_camera_and_flattens_condition_sequence(self):
|
|
agent_cls, agent_module = _load_imf_agent_class()
|
|
head = _RecordingLinearIMFHead()
|
|
cond_projector = nn.Linear(3, 4, bias=False)
|
|
with torch.no_grad():
|
|
cond_projector.weight.copy_(torch.tensor([
|
|
[1.0, 0.0, 0.0],
|
|
[0.0, 1.0, 0.0],
|
|
[0.0, 0.0, 1.0],
|
|
[1.0, 0.0, 1.0],
|
|
], dtype=torch.float32))
|
|
|
|
agent = agent_cls(
|
|
vision_backbone=_StubMultiTokenVisionBackbone(),
|
|
state_encoder=nn.Identity(),
|
|
action_encoder=nn.Identity(),
|
|
head=head,
|
|
action_dim=2,
|
|
obs_dim=1,
|
|
pred_horizon=3,
|
|
obs_horizon=2,
|
|
diffusion_steps=10,
|
|
inference_steps=1,
|
|
num_cams=len(_CAMERA_NAMES),
|
|
camera_names=_CAMERA_NAMES,
|
|
num_action_steps=2,
|
|
head_type='transformer',
|
|
cond_projector=cond_projector,
|
|
)
|
|
agent.infer_scheduler = _ForbiddenScheduler()
|
|
|
|
images = _make_images(
|
|
batch_size=1,
|
|
obs_horizon=2,
|
|
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
|
|
)
|
|
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
|
|
initial_noise = torch.tensor([[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]], dtype=torch.float32)
|
|
|
|
with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
|
|
predicted_actions = agent.predict_action(images, qpos)
|
|
|
|
expected_cond = torch.tensor([[[10.0, 10.5, 1.0, 11.0],
|
|
[20.0, 20.5, 1.0, 21.0],
|
|
[30.0, 30.5, 1.0, 31.0],
|
|
[10.0, 10.5, 2.0, 12.0],
|
|
[20.0, 20.5, 2.0, 22.0],
|
|
[30.0, 30.5, 2.0, 32.0]]], dtype=torch.float32)
|
|
|
|
self.assertEqual(agent.condition_tokens_per_step, 3)
|
|
self.assertEqual(agent.condition_sequence_length, 6)
|
|
self.assertEqual(agent.raw_per_step_cond_dim, 3)
|
|
self.assertEqual(agent.per_step_cond_dim, 4)
|
|
self.assertEqual(agent.global_cond_dim, 24)
|
|
self.assertEqual(predicted_actions.shape, (1, 3, 2))
|
|
self.assertEqual(len(head.calls), 1)
|
|
self.assertEqual(head.calls[0]['cond'].shape, (1, 6, 4))
|
|
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
|
|
|
|
def test_hydra_config_instantiates_resnet_imf_attnres_with_stub_head(self):
|
|
cfg = _compose_cfg(
|
|
overrides=[
|
|
'agent=resnet_imf_attnres',
|
|
'agent.vision_backbone.pretrained_backbone_weights=null',
|
|
'agent.vision_backbone.input_shape=[3,16,16]',
|
|
'agent.vision_backbone.freeze_backbone=false',
|
|
'agent.head.n_layer=1',
|
|
'agent.head.n_emb=16',
|
|
]
|
|
)
|
|
|
|
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
|
self.assertEqual(cfg.agent.head._target_, 'roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D')
|
|
self.assertEqual(cfg.agent.head.backbone_type, 'attnres_full')
|
|
self.assertEqual(cfg.agent.head.n_head, 1)
|
|
self.assertEqual(cfg.agent.head.n_kv_head, 1)
|
|
self.assertEqual(cfg.agent.head.n_cond_layers, 0)
|
|
self.assertTrue(cfg.agent.head.time_as_cond)
|
|
self.assertFalse(cfg.agent.head.causal_attn)
|
|
self.assertEqual(cfg.agent.inference_steps, 1)
|
|
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
|
|
|
|
with _stub_optional_modules(include_imf_head=True):
|
|
agent = instantiate(cfg.agent)
|
|
|
|
self.assertEqual(agent.head_type, 'transformer')
|
|
self.assertEqual(agent.per_step_cond_dim, agent.vision_encoder.output_dim * agent.num_cams + agent.obs_dim)
|
|
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
|
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
|
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['backbone_type'], 'attnres_full')
|
|
|
|
def test_hydra_config_instantiates_resnet_imf_attnres_with_full_attnres_vision_backbone(self):
|
|
cfg = _compose_cfg(
|
|
overrides=[
|
|
'agent=resnet_imf_attnres',
|
|
'agent.vision_backbone.vision_backbone_mode=attnres_resnet',
|
|
'agent.vision_backbone.pretrained_backbone_weights=null',
|
|
'agent.vision_backbone.input_shape=[3,56,56]',
|
|
'agent.vision_backbone.freeze_backbone=false',
|
|
'agent.vision_backbone.attnres_stem_dim=16',
|
|
'agent.vision_backbone.attnres_stage_dims=[16,32,64,128]',
|
|
'agent.vision_backbone.attnres_stage_depths=[1,1,1,1]',
|
|
'agent.vision_backbone.attnres_stage_heads=[2,4,4,8]',
|
|
'agent.vision_backbone.attnres_stage_kv_heads=[1,1,1,1]',
|
|
'agent.vision_backbone.attnres_stage_window_sizes=[7,7,7,7]',
|
|
'agent.head.n_layer=1',
|
|
'agent.head.n_emb=16',
|
|
]
|
|
)
|
|
|
|
with _stub_optional_modules(include_imf_head=True):
|
|
agent = instantiate(cfg.agent)
|
|
|
|
self.assertEqual(agent.vision_encoder.output_dim, 64)
|
|
self.assertEqual(agent.per_step_cond_dim, 64 * agent.num_cams + agent.obs_dim)
|
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], agent.per_step_cond_dim)
|
|
|
|
def test_hydra_config_instantiates_lewm_imf_attnres_with_joint_visual_condition_dim(self):
|
|
cfg = _compose_cfg(
|
|
overrides=[
|
|
'agent=lewm_imf_attnres',
|
|
'agent.vision_backbone.checkpoint_path=null',
|
|
'agent.head.n_layer=1',
|
|
'agent.head.n_emb=16',
|
|
]
|
|
)
|
|
|
|
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
|
self.assertEqual(cfg.agent.vision_backbone._target_, 'roboimi.vla.models.backbones.lewm_vit_backbone.LEWMViTBackbone')
|
|
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
|
|
self.assertEqual(list(cfg.agent.vision_backbone.camera_names), list(_CAMERA_NAMES))
|
|
self.assertEqual(list(cfg.agent.vision_backbone.fused_camera_names), ['front', 'top', 'r_vis'])
|
|
self.assertIsNone(cfg.agent.vision_backbone.dataset_image_resize_shape)
|
|
self.assertEqual(list(cfg.agent.vision_backbone.eval_image_resize_shape), [256, 256])
|
|
self.assertEqual(cfg.agent.head.cond_dim, 208)
|
|
|
|
with _stub_optional_modules(include_imf_head=True):
|
|
agent = instantiate(cfg.agent)
|
|
|
|
self.assertEqual(agent.per_step_cond_dim, agent.vision_encoder.joint_output_dim + agent.obs_dim)
|
|
self.assertEqual(agent.per_step_cond_dim, 208)
|
|
self.assertEqual(agent.global_cond_dim, agent.obs_horizon * 208)
|
|
self.assertIsNone(agent.vision_encoder.dataset_image_resize_shape)
|
|
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
|
|
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
|
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 208)
|
|
|
|
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_projected_camera_tokens(self):
|
|
cfg = _compose_cfg(
|
|
overrides=[
|
|
'agent=resnet_imf_attnres_multitoken',
|
|
'agent.vision_backbone.pretrained_backbone_weights=null',
|
|
'agent.vision_backbone.input_shape=[3,16,16]',
|
|
'agent.head.n_layer=1',
|
|
'agent.head.n_emb=32',
|
|
]
|
|
)
|
|
|
|
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
|
self.assertEqual(cfg.agent.vision_backbone.vision_backbone_mode, 'resnet')
|
|
self.assertTrue(cfg.agent.vision_backbone.use_separate_rgb_encoder_per_camera)
|
|
self.assertTrue(cfg.agent.vision_backbone.output_tokens_per_camera)
|
|
self.assertEqual(cfg.agent.cond_projector.output_dim, 32)
|
|
self.assertEqual(cfg.agent.head.cond_dim, 32)
|
|
|
|
with _stub_optional_modules(include_imf_head=True):
|
|
agent = instantiate(cfg.agent)
|
|
|
|
self.assertEqual(agent.condition_tokens_per_step, 3)
|
|
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon * 3)
|
|
self.assertEqual(agent.per_step_cond_dim, 32)
|
|
self.assertEqual(agent.global_cond_dim, agent.condition_sequence_length * 32)
|
|
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
|
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 32)
|
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], 6)
|
|
|
|
|
|
def test_hydra_config_instantiates_siglip2_imf_attnres_with_condition_projection(self):
|
|
cfg = _compose_cfg(
|
|
overrides=[
|
|
'agent=siglip2_imf_attnres',
|
|
'agent.vision_backbone.per_view_output_dim=96',
|
|
'agent.head.n_layer=1',
|
|
'agent.head.n_emb=16',
|
|
'agent.cond_projector.output_dim=384',
|
|
]
|
|
)
|
|
|
|
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
|
self.assertEqual(
|
|
cfg.agent.vision_backbone._target_,
|
|
'roboimi.vla.models.backbones.siglip2_diffusion_backbone.SigLIP2DiffusionBackbone',
|
|
)
|
|
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
|
|
self.assertIsNone(cfg.agent.vision_backbone.dataset_image_resize_shape)
|
|
self.assertEqual(list(cfg.agent.vision_backbone.eval_image_resize_shape), [256, 256])
|
|
self.assertEqual(cfg.agent.head.cond_dim, 384)
|
|
|
|
with _stub_optional_modules(include_imf_head=True):
|
|
agent = instantiate(cfg.agent)
|
|
|
|
self.assertEqual(agent.raw_per_step_cond_dim, 3 * 96 + agent.obs_dim)
|
|
self.assertEqual(agent.per_step_cond_dim, 384)
|
|
self.assertEqual(agent.global_cond_dim, agent.obs_horizon * 384)
|
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 384)
|
|
self.assertEqual(agent.vision_encoder.output_dim, 96)
|
|
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
|
|
|
|
def test_hydra_config_instantiates_lewm_resnet_query_imf_attnres_with_future_tokens(self):
|
|
cfg = _compose_cfg(
|
|
overrides=[
|
|
'agent=lewm_resnet_query_imf_attnres',
|
|
'agent.head.n_layer=1',
|
|
'agent.head.n_emb=16',
|
|
'agent.lewm_query_offsets=[8]',
|
|
]
|
|
)
|
|
|
|
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
|
self.assertEqual(
|
|
cfg.agent.vision_backbone._target_,
|
|
'roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone',
|
|
)
|
|
self.assertEqual(
|
|
cfg.agent.state_encoder._target_,
|
|
'roboimi.vla.modules.encoders.LeWMStateEncoder',
|
|
)
|
|
self.assertEqual(cfg.agent.head.cond_dim, 288)
|
|
self.assertEqual(cfg.agent.cond_projector.output_dim, 288)
|
|
self.assertEqual(cfg.agent.extra_condition_tokens, 1)
|
|
self.assertEqual(
|
|
cfg.agent.lewm_sigreg._target_,
|
|
'roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg',
|
|
)
|
|
self.assertAlmostEqual(cfg.agent.lewm_sigreg_weight, 0.09)
|
|
|
|
with _stub_optional_modules(include_imf_head=True):
|
|
agent = instantiate(cfg.agent)
|
|
|
|
self.assertEqual(agent.per_step_cond_dim, 288)
|
|
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon + 1)
|
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 288)
|
|
self.assertEqual(
|
|
agent.noise_pred_net.constructor_kwargs['n_obs_steps'],
|
|
agent.condition_sequence_length,
|
|
)
|
|
self.assertIsNotNone(agent.lewm_sigreg)
|
|
|
|
def test_hydra_config_instantiates_lewm_resnet_dual_decoder_imf_attnres(self):
|
|
cfg = _compose_cfg(
|
|
overrides=[
|
|
'agent=lewm_resnet_dual_decoder_imf_attnres',
|
|
'agent.head.n_layer=1',
|
|
'agent.head.n_emb=16',
|
|
'agent.future_decoder.n_layer=1',
|
|
'agent.future_decoder.n_emb=16',
|
|
'agent.lewm_query_offsets=[8]',
|
|
]
|
|
)
|
|
|
|
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
|
self.assertEqual(cfg.agent.extra_condition_tokens, 0)
|
|
self.assertEqual(
|
|
cfg.agent.future_decoder._target_,
|
|
'roboimi.vla.models.heads.imf_transformer1d.IMFTransformer1D',
|
|
)
|
|
self.assertEqual(cfg.agent.head.cond_dim, 288)
|
|
self.assertEqual(cfg.agent.future_decoder.cond_dim, 288)
|
|
|
|
with _stub_optional_modules(include_imf_head=True):
|
|
agent = instantiate(cfg.agent)
|
|
|
|
self.assertEqual(agent.per_step_cond_dim, 288)
|
|
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon)
|
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.obs_horizon)
|
|
self.assertEqual(agent.future_decoder.constructor_kwargs['cond_dim'], 288)
|
|
self.assertEqual(
|
|
agent.future_decoder.constructor_kwargs['n_obs_steps'],
|
|
agent.lewm_history_horizon,
|
|
)
|
|
self.assertEqual(agent.future_query_tokens.shape, (1, 1, 288))
|
|
|
|
|
|
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
|
|
cfg = _compose_cfg(
|
|
overrides=[
|
|
'agent=resnet_imf_attnres_multitoken',
|
|
'agent.vision_backbone.pretrained_backbone_weights=null',
|
|
'agent.vision_backbone.input_shape=[3,16,16]',
|
|
'agent.vision_backbone.freeze_backbone=false',
|
|
'agent.head.n_layer=1',
|
|
'agent.head.n_emb=16',
|
|
]
|
|
)
|
|
|
|
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
|
|
self.assertEqual(list(cfg.agent.camera_names), list(_CAMERA_NAMES))
|
|
self.assertTrue(cfg.agent.vision_backbone.use_separate_rgb_encoder_per_camera)
|
|
self.assertTrue(cfg.agent.vision_backbone.output_tokens_per_camera)
|
|
self.assertEqual(cfg.agent.vision_backbone.vision_backbone_mode, 'resnet')
|
|
self.assertEqual(cfg.agent.cond_projector.output_dim, 16)
|
|
self.assertEqual(cfg.agent.head.cond_dim, 16)
|
|
|
|
with _stub_optional_modules(include_imf_head=True):
|
|
agent = instantiate(cfg.agent)
|
|
|
|
self.assertEqual(agent.condition_tokens_per_step, 3)
|
|
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon * 3)
|
|
self.assertEqual(agent.per_step_cond_dim, 16)
|
|
self.assertEqual(agent.global_cond_dim, agent.condition_sequence_length * 16)
|
|
self.assertEqual(agent.vision_encoder.tokens_per_step, 3)
|
|
self.assertIsInstance(agent.noise_pred_net, _StubIMFHead)
|
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 16)
|
|
self.assertEqual(agent.noise_pred_net.constructor_kwargs['n_obs_steps'], agent.condition_sequence_length)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run this test module's suite when executed directly.
    unittest.main()
|