feat: add lewm-conditioned imf training and sigreg loss
This commit is contained in:
@@ -376,6 +376,29 @@ class _ForbiddenScheduler:
|
||||
raise AssertionError('IMF inference should not use DDIM scheduler step')
|
||||
|
||||
|
||||
class _StubFutureTokenPredictor(nn.Module):
    """Test double standing in for a future-token predictor.

    Each forward call is recorded, and the returned "future" tokens are the
    mean of the input over dim 1, repeated ``num_future_tokens`` times along
    that same dimension.
    """

    def __init__(self, num_future_tokens=1):
        """Store the number of tokens to emit and start an empty call log."""
        super().__init__()
        self.num_future_tokens = int(num_future_tokens)
        # One detached copy of the input per forward() invocation.
        self.calls = []

    def forward(self, history_tokens):
        """Record ``history_tokens`` and return their dim-1 mean, repeated."""
        # Snapshot the input so later mutation by the caller cannot alter the log.
        snapshot = history_tokens.detach().clone()
        self.calls.append(snapshot)

        # keepdim=True keeps dim 1 (size 1) so it can be repeated below.
        pooled = history_tokens.mean(dim=1, keepdim=True)
        repeats = (1, self.num_future_tokens, 1)
        return pooled.repeat(*repeats)
|
||||
|
||||
|
||||
class _RecordingSigReg(nn.Module):
    """Test double for a SIGReg-style regularizer.

    It logs every embeddings tensor it is called with and always returns the
    same constant scalar, so tests can assert on both the inputs and the
    contribution to the total loss.
    """

    def __init__(self, value=0.5):
        """Remember the constant to return and start an empty call log."""
        super().__init__()
        self.value = float(value)
        # One detached copy of the embeddings per forward() invocation.
        self.calls = []

    def forward(self, embeddings):
        """Record ``embeddings`` and return the configured constant as a 0-d tensor."""
        recorded = embeddings.detach().clone()
        self.calls.append(recorded)

        # new_tensor builds the scalar with the dtype/device of the input.
        constant = embeddings.new_tensor(self.value)
        return constant
|
||||
|
||||
|
||||
def _make_images(batch_size, obs_horizon, per_camera_fill):
|
||||
return {
|
||||
name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
|
||||
@@ -501,6 +524,169 @@ class IMFVLAAgentTest(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
|
||||
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
|
||||
|
||||
def test_predict_action_appends_lewm_future_tokens_to_history_conditioning(self):
    """predict_action should condition the head on history tokens plus one future token.

    The agent is built with a recording head and a stub future-token
    predictor. After one ``predict_action`` call the test checks that the
    ``cond`` tensor seen by the head is the two observation-history tokens
    followed by a single predicted future token, and that the predictor was
    invoked exactly once.
    """
    agent_cls, agent_module = _load_imf_agent_class()
    recording_head = _RecordingLinearIMFHead()
    stub_predictor = _StubFutureTokenPredictor(num_future_tokens=1)

    agent_kwargs = dict(
        vision_backbone=_StubVisionBackbone(),
        state_encoder=nn.Identity(),
        action_encoder=nn.Identity(),
        head=recording_head,
        action_dim=2,
        obs_dim=1,
        pred_horizon=3,
        obs_horizon=2,
        diffusion_steps=10,
        inference_steps=1,
        num_cams=len(_CAMERA_NAMES),
        camera_names=_CAMERA_NAMES,
        num_action_steps=2,
        head_type='transformer',
        extra_condition_tokens=1,
        lewm_history_horizon=3,
        lewm_query_offsets=[8],
        lewm_predictor=stub_predictor,
        lewm_pred_projector=nn.Identity(),
        lewm_loss_weight=0.5,
    )
    agent = agent_cls(**agent_kwargs)
    # Any use of the DDIM-style scheduler during IMF inference must fail loudly.
    agent.infer_scheduler = _ForbiddenScheduler()

    # Regular observation window: two steps per camera plus proprioception.
    obs_images = _make_images(
        batch_size=1,
        obs_horizon=2,
        per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
    )
    obs_qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)

    # LeWM history window: three steps, matching lewm_history_horizon above.
    history_images = _make_images(
        batch_size=1,
        obs_horizon=3,
        per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
    )
    history_qpos = torch.tensor([[[0.5], [1.5], [2.5]]], dtype=torch.float32)

    # Fixed starting noise so the sampling path is deterministic.
    start_noise = torch.tensor(
        [[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
        dtype=torch.float32,
    )
    randn_patch = mock.patch.object(agent_module.torch, 'randn', return_value=start_noise)
    with randn_patch:
        agent.predict_action(
            obs_images,
            obs_qpos,
            lewm_images=history_images,
            lewm_proprioception=history_qpos,
        )

    history_tokens = torch.tensor(
        [[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
        dtype=torch.float32,
    )
    # The stub predictor averages its input over dim 1; the qpos channel 1.5
    # is mean(0.5, 1.5, 2.5) — presumably it is fed the LeWM history tokens.
    future_token = torch.tensor([[[10.0, 20.0, 30.0, 1.5]]], dtype=torch.float32)
    expected_cond = torch.cat([history_tokens, future_token], dim=1)

    self.assertEqual(agent.condition_sequence_length, 3)
    self.assertEqual(agent.per_step_cond_dim, 4)
    self.assertEqual(len(recording_head.calls), 1)
    first_call = recording_head.calls[0]
    self.assertTrue(torch.allclose(first_call['cond'], expected_cond))
    self.assertEqual(len(stub_predictor.calls), 1)
|
||||
|
||||
def test_compute_loss_tracks_action_and_lewm_loss_breakdown(self):
    """compute_loss should expose a breakdown consistent with its weighting.

    With a stub predictor and a recording SIGReg returning 0.75, the test
    checks that:
      * the returned loss equals ``metrics['loss']``;
      * ``lewm_loss == lewm_pred_loss + 0.09 * lewm_sigreg_loss``;
      * ``loss == action_loss + 0.25 * lewm_loss``;
      * SIGReg is called once, with the LeWM history tokens transposed on
        their first two dimensions.
    """
    sigreg_weight = 0.09
    lewm_weight = 0.25

    agent_cls, agent_module = _load_imf_agent_class()
    recording_head = _RecordingLinearIMFHead()
    stub_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
    recording_sigreg = _RecordingSigReg(value=0.75)

    agent_kwargs = dict(
        vision_backbone=_StubVisionBackbone(),
        state_encoder=nn.Identity(),
        action_encoder=nn.Identity(),
        head=recording_head,
        action_dim=2,
        obs_dim=1,
        pred_horizon=3,
        obs_horizon=2,
        diffusion_steps=10,
        inference_steps=1,
        num_cams=len(_CAMERA_NAMES),
        camera_names=_CAMERA_NAMES,
        num_action_steps=2,
        head_type='transformer',
        extra_condition_tokens=1,
        lewm_history_horizon=3,
        lewm_query_offsets=[8],
        lewm_predictor=stub_predictor,
        lewm_pred_projector=nn.Identity(),
        lewm_sigreg=recording_sigreg,
        lewm_sigreg_weight=sigreg_weight,
        lewm_loss_weight=lewm_weight,
    )
    agent = agent_cls(**agent_kwargs)

    # Policy inputs: two observation steps and a three-step action target.
    obs_images = _make_images(
        batch_size=1,
        obs_horizon=2,
        per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
    )
    obs_qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
    target_actions = torch.tensor(
        [[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
        dtype=torch.float32,
    )

    # LeWM inputs: a three-step history and a single future step.
    history_images = _make_images(
        batch_size=1,
        obs_horizon=3,
        per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
    )
    history_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
    future_images = _make_images(
        batch_size=1,
        obs_horizon=1,
        per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
    )
    future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)

    # Pin every source of randomness the loss computation draws from.
    fixed_noise = torch.tensor(
        [[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
        dtype=torch.float32,
    )
    t_sample = torch.tensor([0.8], dtype=torch.float32)
    r_sample = torch.tensor([0.25], dtype=torch.float32)

    batch = {
        'images': obs_images,
        'qpos': obs_qpos,
        'action': target_actions,
        'lewm_images': history_images,
        'lewm_qpos': history_qpos,
        'lewm_future_images': future_images,
        'lewm_future_qpos': future_qpos,
    }

    with mock.patch.object(agent_module.torch, 'randn_like', return_value=fixed_noise):
        # torch.rand is consumed twice: first call yields t, second yields r.
        with mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
            loss = agent.compute_loss(batch)

    metrics = agent.get_last_loss_breakdown()
    self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
    self.assertIn('action_loss', metrics)
    self.assertIn('lewm_pred_loss', metrics)
    self.assertIn('lewm_sigreg_loss', metrics)
    self.assertIn('lewm_loss', metrics)
    self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)

    expected_lewm_loss = metrics['lewm_pred_loss'] + sigreg_weight * metrics['lewm_sigreg_loss']
    self.assertAlmostEqual(metrics['lewm_loss'], expected_lewm_loss, places=5)

    expected_total = metrics['action_loss'] + lewm_weight * metrics['lewm_loss']
    self.assertAlmostEqual(metrics['loss'], expected_total, places=5)

    self.assertEqual(len(recording_sigreg.calls), 1)
    history_tokens = torch.tensor(
        [[[1.0, 2.0, 3.0, 0.1], [1.0, 2.0, 3.0, 0.2], [1.0, 2.0, 3.0, 0.3]]],
        dtype=torch.float32,
    )
    # SIGReg is expected to see the history with dims 0 and 1 swapped.
    torch.testing.assert_close(recording_sigreg.calls[0], history_tokens.transpose(0, 1))
|
||||
|
||||
def test_select_action_only_regenerates_when_action_queue_is_empty(self):
|
||||
agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
|
||||
observation = {
|
||||
@@ -851,6 +1037,46 @@ class IMFVLAAgentTest(unittest.TestCase):
|
||||
self.assertEqual(agent.vision_encoder.output_dim, 96)
|
||||
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
|
||||
|
||||
def test_hydra_config_instantiates_lewm_resnet_query_imf_attnres_with_future_tokens(self):
    """The lewm_resnet_query_imf_attnres config should compose and instantiate.

    First checks the composed config (targets, conditioning dims, the extra
    condition token and SIGReg settings), then instantiates the agent with
    optional modules stubbed and checks the conditioning sequence is one
    token longer than the observation horizon.
    """
    hydra_overrides = [
        'agent=lewm_resnet_query_imf_attnres',
        'agent.head.n_layer=1',
        'agent.head.n_emb=16',
        'agent.lewm_query_offsets=[8]',
    ]
    cfg = _compose_cfg(overrides=hydra_overrides)

    # --- Composed configuration -------------------------------------------
    self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')

    backbone_target = 'roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone'
    self.assertEqual(cfg.agent.vision_backbone._target_, backbone_target)

    state_encoder_target = 'roboimi.vla.modules.encoders.LeWMStateEncoder'
    self.assertEqual(cfg.agent.state_encoder._target_, state_encoder_target)

    self.assertEqual(cfg.agent.head.cond_dim, 288)
    self.assertEqual(cfg.agent.cond_projector.output_dim, 288)
    self.assertEqual(cfg.agent.extra_condition_tokens, 1)

    sigreg_target = 'roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg'
    self.assertEqual(cfg.agent.lewm_sigreg._target_, sigreg_target)
    self.assertAlmostEqual(cfg.agent.lewm_sigreg_weight, 0.09)

    # --- Instantiated agent -----------------------------------------------
    with _stub_optional_modules(include_imf_head=True):
        agent = instantiate(cfg.agent)

    self.assertEqual(agent.per_step_cond_dim, 288)
    # One extra condition token is appended to the observation history.
    self.assertEqual(agent.condition_sequence_length, agent.obs_horizon + 1)

    head_kwargs = agent.noise_pred_net.constructor_kwargs
    self.assertEqual(head_kwargs['cond_dim'], 288)
    self.assertEqual(head_kwargs['n_obs_steps'], agent.condition_sequence_length)
    self.assertIsNotNone(agent.lewm_sigreg)
|
||||
|
||||
|
||||
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
|
||||
cfg = _compose_cfg(
|
||||
|
||||
Reference in New Issue
Block a user