feat: add lewm-conditioned imf training and sigreg loss

Author: Logic
Date: 2026-04-17 18:46:02 +08:00
parent ff7c9c1f2a
commit 74f4963613
14 changed files with 1634 additions and 24 deletions


@@ -114,6 +114,22 @@ class FakeAgent(nn.Module):
        return {}


class RecordingAgent(FakeAgent):
    def __init__(self):
        super().__init__()
        self.seen_inputs = []

    def compute_loss(self, agent_input):
        self.seen_inputs.append(agent_input)
        return super().compute_loss(agent_input)


class ShapeMixedFakeAgent(FakeAgent):
    def __init__(self):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(2))


class FakeSwanLab:
    def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
        self.init_error = init_error
@@ -388,6 +404,18 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
            'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
        }

    def _make_lewm_batch(self):
        batch = self._make_batch()
        batch.update(
            {
                'lewm.observation.front': torch.ones(1, 3, 2, 2),
                'lewm.observation.state': torch.ones(1, 4),
                'lewm.future.front': torch.full((1, 3, 2, 2), 2.0),
                'lewm.future.state': torch.full((1, 4), 2.0),
            }
        )
        return batch

    def _loader_factory(self):
        train_batch = self._make_batch()
        val_batch = self._make_batch()
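Note: the distinct fill values in _make_lewm_batch are deliberate — history tensors are all 1.0 and future tensors all 2.0 — so a later assertion could tell the two streams apart by value alone. A standalone illustration (not from this diff):

    import torch

    history = torch.ones(1, 4)        # mirrors 'lewm.observation.state'
    future = torch.full((1, 4), 2.0)  # mirrors 'lewm.future.state'
    assert history.mean().item() == 1.0
    assert future.mean().item() == 2.0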
@@ -397,6 +425,15 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
        return factory

    def _lewm_loader_factory(self):
        train_batch = self._make_lewm_batch()
        val_batch = self._make_lewm_batch()

        def factory(_dataset, *, shuffle, **_kwargs):
            return FakeLoader(train_batch if shuffle else val_batch)

        return factory

    def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
@@ -487,6 +524,43 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
        self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
        self.assertEqual(fake_swanlab.finish_calls, 1)

    def test_run_training_passes_lewm_history_and_future_batches_into_agent_input(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg(use_swanlab=False)
        cfg.train.max_steps = 1
        cfg.train.save_freq = 100
        agent = RecordingAgent()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._lewm_loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertGreaterEqual(len(agent.seen_inputs), 1)
        first_input = agent.seen_inputs[0]
        self.assertIn('lewm_images', first_input)
        self.assertIn('lewm_qpos', first_input)
        self.assertIn('lewm_future_images', first_input)
        self.assertIn('lewm_future_qpos', first_input)
        self.assertIn('front', first_input['lewm_images'])
        self.assertIn('front', first_input['lewm_future_images'])

    def test_run_training_skips_swanlab_when_disabled(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
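Note: a minimal sketch of the remapping the test above pins down. The helper name split_lewm_batch is hypothetical (the real logic lives in train_vla.py and may be shaped differently); it assumes flat 'lewm.observation.*' / 'lewm.future.*' batch keys and regroups them into the nested agent_input layout that RecordingAgent observes:

    def split_lewm_batch(batch):
        # Regroup flat 'lewm.*' keys into the nested structure asserted above:
        # per-camera image dicts plus flat qpos tensors.
        agent_input = {'lewm_images': {}, 'lewm_future_images': {}}
        for key, value in batch.items():
            if key == 'lewm.observation.state':
                agent_input['lewm_qpos'] = value
            elif key == 'lewm.future.state':
                agent_input['lewm_future_qpos'] = value
            elif key.startswith('lewm.observation.'):
                camera = key.split('.', 2)[2]  # e.g. 'front'
                agent_input['lewm_images'][camera] = value
            elif key.startswith('lewm.future.'):
                camera = key.split('.', 2)[2]
                agent_input['lewm_future_images'][camera] = value
        return agent_input

Under this assumption, 'lewm.observation.front' ends up at agent_input['lewm_images']['front'] and 'lewm.future.front' at agent_input['lewm_future_images']['front'], which is exactly what the assertions check.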
@@ -668,6 +742,52 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
        self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
        self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))

    def test_run_training_pretrained_ckpt_loads_matching_keys_even_if_some_shapes_mismatch(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg(use_swanlab=False)
        cfg.train.max_steps = 0
        cfg.train.save_freq = 100
        cfg.train.pretrained_ckpt = 'pretrained.pt'
        agent = ShapeMixedFakeAgent()

        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return FakeDataset()
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_torch_load(path, map_location=None):
            del map_location
            if Path(path).name != 'pretrained.pt':
                raise AssertionError(f'unexpected load path: {path}')
            return {
                'model_state_dict': {
                    'weight': torch.tensor(3.0),
                    'bias': torch.tensor([1.0, 2.0, 3.0]),
                },
                'step': 123,
                'loss': 0.5,
            }

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                Path('pretrained.pt').write_bytes(b'pretend')
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=self._loader_factory()), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.torch, 'load', side_effect=fake_torch_load):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        self.assertAlmostEqual(agent.weight.item(), 3.0, places=6)

    def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
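Note: a plausible sketch of the shape-tolerant loading the shape-mismatch test above exercises, not the literal code from train_vla.py (the function name load_matching_keys is an assumption). The behavior the test pins down: keep checkpoint entries whose name and shape both match the model (here 'weight'), skip the rest (here 'bias', shape [3] against the model's [2]), and load non-strictly so skipped parameters keep their initial values:

    import torch

    def load_matching_keys(model, ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        model_state = model.state_dict()
        # Keep only entries that exist in the model with an identical shape.
        filtered = {
            name: tensor
            for name, tensor in checkpoint['model_state_dict'].items()
            if name in model_state and model_state[name].shape == tensor.shape
        }
        # strict=False tolerates the keys we filtered out.
        model.load_state_dict(filtered, strict=False)
        return checkpoint.get('step'), checkpoint.get('loss')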