feat: add lewm-conditioned imf training and sigreg loss

This commit is contained in:
Logic
2026-04-17 18:46:02 +08:00
parent ff7c9c1f2a
commit 74f4963613
14 changed files with 1634 additions and 24 deletions

View File

@@ -376,6 +376,29 @@ class _ForbiddenScheduler:
raise AssertionError('IMF inference should not use DDIM scheduler step')
class _StubFutureTokenPredictor(nn.Module):
def __init__(self, num_future_tokens=1):
super().__init__()
self.num_future_tokens = int(num_future_tokens)
self.calls = []
def forward(self, history_tokens):
self.calls.append(history_tokens.detach().clone())
summary = history_tokens.mean(dim=1, keepdim=True)
return summary.repeat(1, self.num_future_tokens, 1)
class _RecordingSigReg(nn.Module):
def __init__(self, value=0.5):
super().__init__()
self.value = float(value)
self.calls = []
def forward(self, embeddings):
self.calls.append(embeddings.detach().clone())
return embeddings.new_tensor(self.value)
def _make_images(batch_size, obs_horizon, per_camera_fill):
return {
name: torch.full((batch_size, obs_horizon, 1, 2, 2), fill_value=value, dtype=torch.float32)
@@ -501,6 +524,169 @@ class IMFVLAAgentTest(unittest.TestCase):
self.assertTrue(torch.allclose(head.calls[0]['t'], torch.ones(2)))
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
def test_predict_action_appends_lewm_future_tokens_to_history_conditioning(self):
agent_cls, agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
agent = agent_cls(
vision_backbone=_StubVisionBackbone(),
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
extra_condition_tokens=1,
lewm_history_horizon=3,
lewm_query_offsets=[8],
lewm_predictor=future_predictor,
lewm_pred_projector=nn.Identity(),
lewm_loss_weight=0.5,
)
agent.infer_scheduler = _ForbiddenScheduler()
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
)
qpos = torch.tensor([[[1.0], [2.0]]], dtype=torch.float32)
lewm_images = _make_images(
batch_size=1,
obs_horizon=3,
per_camera_fill={'r_vis': 10.0, 'top': 20.0, 'front': 30.0},
)
lewm_qpos = torch.tensor([[[0.5], [1.5], [2.5]]], dtype=torch.float32)
initial_noise = torch.tensor(
[[[1.0, -1.0], [0.0, 2.0], [3.0, -2.0]]],
dtype=torch.float32,
)
with mock.patch.object(agent_module.torch, 'randn', return_value=initial_noise):
_ = agent.predict_action(
images,
qpos,
lewm_images=lewm_images,
lewm_proprioception=lewm_qpos,
)
expected_history = torch.tensor(
[[[10.0, 20.0, 30.0, 1.0], [10.0, 20.0, 30.0, 2.0]]],
dtype=torch.float32,
)
expected_future = torch.tensor([[[10.0, 20.0, 30.0, 1.5]]], dtype=torch.float32)
expected_cond = torch.cat([expected_history, expected_future], dim=1)
self.assertEqual(agent.condition_sequence_length, 3)
self.assertEqual(agent.per_step_cond_dim, 4)
self.assertEqual(len(head.calls), 1)
self.assertTrue(torch.allclose(head.calls[0]['cond'], expected_cond))
self.assertEqual(len(future_predictor.calls), 1)
def test_compute_loss_tracks_action_and_lewm_loss_breakdown(self):
agent_cls, agent_module = _load_imf_agent_class()
head = _RecordingLinearIMFHead()
future_predictor = _StubFutureTokenPredictor(num_future_tokens=1)
sigreg = _RecordingSigReg(value=0.75)
agent = agent_cls(
vision_backbone=_StubVisionBackbone(),
state_encoder=nn.Identity(),
action_encoder=nn.Identity(),
head=head,
action_dim=2,
obs_dim=1,
pred_horizon=3,
obs_horizon=2,
diffusion_steps=10,
inference_steps=1,
num_cams=len(_CAMERA_NAMES),
camera_names=_CAMERA_NAMES,
num_action_steps=2,
head_type='transformer',
extra_condition_tokens=1,
lewm_history_horizon=3,
lewm_query_offsets=[8],
lewm_predictor=future_predictor,
lewm_pred_projector=nn.Identity(),
lewm_sigreg=sigreg,
lewm_sigreg_weight=0.09,
lewm_loss_weight=0.25,
)
images = _make_images(
batch_size=1,
obs_horizon=2,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
qpos = torch.tensor([[[0.25], [0.75]]], dtype=torch.float32)
actions = torch.tensor(
[[[1.0, -1.0], [0.5, 0.25], [-0.5, 1.5]]],
dtype=torch.float32,
)
lewm_images = _make_images(
batch_size=1,
obs_horizon=3,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
lewm_qpos = torch.tensor([[[0.1], [0.2], [0.3]]], dtype=torch.float32)
lewm_future_images = _make_images(
batch_size=1,
obs_horizon=1,
per_camera_fill={'r_vis': 1.0, 'top': 2.0, 'front': 3.0},
)
lewm_future_qpos = torch.tensor([[[0.4]]], dtype=torch.float32)
noise = torch.tensor(
[[[0.2, -0.4], [0.1, 0.3], [0.5, -0.2]]],
dtype=torch.float32,
)
t_sample = torch.tensor([0.8], dtype=torch.float32)
r_sample = torch.tensor([0.25], dtype=torch.float32)
with mock.patch.object(agent_module.torch, 'randn_like', return_value=noise), \
mock.patch.object(agent_module.torch, 'rand', side_effect=[t_sample, r_sample]):
loss = agent.compute_loss(
{
'images': images,
'qpos': qpos,
'action': actions,
'lewm_images': lewm_images,
'lewm_qpos': lewm_qpos,
'lewm_future_images': lewm_future_images,
'lewm_future_qpos': lewm_future_qpos,
}
)
metrics = agent.get_last_loss_breakdown()
self.assertAlmostEqual(loss.item(), metrics['loss'], places=6)
self.assertIn('action_loss', metrics)
self.assertIn('lewm_pred_loss', metrics)
self.assertIn('lewm_sigreg_loss', metrics)
self.assertIn('lewm_loss', metrics)
self.assertAlmostEqual(metrics['lewm_sigreg_loss'], 0.75, places=6)
self.assertAlmostEqual(
metrics['lewm_loss'],
metrics['lewm_pred_loss'] + 0.09 * metrics['lewm_sigreg_loss'],
places=5,
)
self.assertAlmostEqual(
metrics['loss'],
metrics['action_loss'] + 0.25 * metrics['lewm_loss'],
places=5,
)
self.assertEqual(len(sigreg.calls), 1)
expected_lewm_history = torch.tensor(
[[[1.0, 2.0, 3.0, 0.1], [1.0, 2.0, 3.0, 0.2], [1.0, 2.0, 3.0, 0.3]]],
dtype=torch.float32,
)
torch.testing.assert_close(sigreg.calls[0], expected_lewm_history.transpose(0, 1))
def test_select_action_only_regenerates_when_action_queue_is_empty(self):
agent, _head, _agent_module = self._make_agent(pred_horizon=4, obs_horizon=2, num_action_steps=2)
observation = {
@@ -851,6 +1037,46 @@ class IMFVLAAgentTest(unittest.TestCase):
self.assertEqual(agent.vision_encoder.output_dim, 96)
self.assertEqual(agent.vision_encoder.eval_image_resize_shape, (256, 256))
def test_hydra_config_instantiates_lewm_resnet_query_imf_attnres_with_future_tokens(self):
cfg = _compose_cfg(
overrides=[
'agent=lewm_resnet_query_imf_attnres',
'agent.head.n_layer=1',
'agent.head.n_emb=16',
'agent.lewm_query_offsets=[8]',
]
)
self.assertEqual(cfg.agent._target_, 'roboimi.vla.agent_imf.IMFVLAAgent')
self.assertEqual(
cfg.agent.vision_backbone._target_,
'roboimi.vla.models.backbones.lewm_resnet_query_fusion.LeWMMultiViewResNetBackbone',
)
self.assertEqual(
cfg.agent.state_encoder._target_,
'roboimi.vla.modules.encoders.LeWMStateEncoder',
)
self.assertEqual(cfg.agent.head.cond_dim, 288)
self.assertEqual(cfg.agent.cond_projector.output_dim, 288)
self.assertEqual(cfg.agent.extra_condition_tokens, 1)
self.assertEqual(
cfg.agent.lewm_sigreg._target_,
'roboimi.vla.models.backbones.lewm_resnet_query_fusion.SIGReg',
)
self.assertAlmostEqual(cfg.agent.lewm_sigreg_weight, 0.09)
with _stub_optional_modules(include_imf_head=True):
agent = instantiate(cfg.agent)
self.assertEqual(agent.per_step_cond_dim, 288)
self.assertEqual(agent.condition_sequence_length, agent.obs_horizon + 1)
self.assertEqual(agent.noise_pred_net.constructor_kwargs['cond_dim'], 288)
self.assertEqual(
agent.noise_pred_net.constructor_kwargs['n_obs_steps'],
agent.condition_sequence_length,
)
self.assertIsNotNone(agent.lewm_sigreg)
def test_hydra_config_instantiates_resnet_imf_attnres_multitoken_with_sequence_length_three_times_obs_horizon(self):
cfg = _compose_cfg(

View File

@@ -79,3 +79,24 @@ class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
fake_cv2.resize.assert_not_called()
self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
def test_getitem_can_emit_lewm_history_and_future_observations(self):
with tempfile.TemporaryDirectory() as tmpdir:
dataset_dir = Path(tmpdir)
self._write_episode(dataset_dir)
dataset = SimpleRobotDataset(
dataset_dir,
obs_horizon=2,
pred_horizon=3,
camera_names=["front"],
image_resize_shape=None,
lewm_history_horizon=3,
lewm_query_offsets=[1, 2],
)
sample = dataset[1]
self.assertEqual(tuple(sample["lewm.observation.state"].shape), (3, 4))
self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))

View File

@@ -114,6 +114,22 @@ class FakeAgent(nn.Module):
return {}
class RecordingAgent(FakeAgent):
def __init__(self):
super().__init__()
self.seen_inputs = []
def compute_loss(self, agent_input):
self.seen_inputs.append(agent_input)
return super().compute_loss(agent_input)
class ShapeMixedFakeAgent(FakeAgent):
def __init__(self):
super().__init__()
self.bias = nn.Parameter(torch.zeros(2))
class FakeSwanLab:
def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
self.init_error = init_error
@@ -388,6 +404,18 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
}
def _make_lewm_batch(self):
batch = self._make_batch()
batch.update(
{
'lewm.observation.front': torch.ones(1, 3, 2, 2),
'lewm.observation.state': torch.ones(1, 4),
'lewm.future.front': torch.full((1, 3, 2, 2), 2.0),
'lewm.future.state': torch.full((1, 4), 2.0),
}
)
return batch
def _loader_factory(self):
train_batch = self._make_batch()
val_batch = self._make_batch()
@@ -397,6 +425,15 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
return factory
def _lewm_loader_factory(self):
train_batch = self._make_lewm_batch()
val_batch = self._make_lewm_batch()
def factory(_dataset, *, shuffle, **_kwargs):
return FakeLoader(train_batch if shuffle else val_batch)
return factory
def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
@@ -487,6 +524,43 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
self.assertEqual(fake_swanlab.finish_calls, 1)
def test_run_training_passes_lewm_history_and_future_batches_into_agent_input(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg(use_swanlab=False)
cfg.train.max_steps = 1
cfg.train.save_freq = 100
agent = RecordingAgent()
def fake_instantiate(config_node, **_kwargs):
if config_node is cfg.data:
return FakeDataset()
if config_node is cfg.agent:
return agent
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=self._lewm_loader_factory()), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', return_value=None):
run_training(cfg)
finally:
os.chdir(previous_cwd)
self.assertGreaterEqual(len(agent.seen_inputs), 1)
first_input = agent.seen_inputs[0]
self.assertIn('lewm_images', first_input)
self.assertIn('lewm_qpos', first_input)
self.assertIn('lewm_future_images', first_input)
self.assertIn('lewm_future_qpos', first_input)
self.assertIn('front', first_input['lewm_images'])
self.assertIn('front', first_input['lewm_future_images'])
def test_run_training_skips_swanlab_when_disabled(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
@@ -668,6 +742,52 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))
def test_run_training_pretrained_ckpt_loads_matching_keys_even_if_some_shapes_mismatch(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)
cfg = self._make_cfg(use_swanlab=False)
cfg.train.max_steps = 0
cfg.train.save_freq = 100
cfg.train.pretrained_ckpt = 'pretrained.pt'
agent = ShapeMixedFakeAgent()
def fake_instantiate(config_node, **_kwargs):
if config_node is cfg.data:
return FakeDataset()
if config_node is cfg.agent:
return agent
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
def fake_torch_load(path, map_location=None):
del map_location
if Path(path).name != 'pretrained.pt':
raise AssertionError(f'unexpected load path: {path}')
return {
'model_state_dict': {
'weight': torch.tensor(3.0),
'bias': torch.tensor([1.0, 2.0, 3.0]),
},
'step': 123,
'loss': 0.5,
}
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
Path('pretrained.pt').write_bytes(b'pretend')
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
mock.patch.object(module.torch, 'save', return_value=None), \
mock.patch.object(module.torch, 'load', side_effect=fake_torch_load):
run_training(cfg)
finally:
os.chdir(previous_cwd)
self.assertAlmostEqual(agent.weight.item(), 3.0, places=6)
def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
module = self._load_train_vla_module()
run_training = self._get_run_training(module)