feat: add lewm-conditioned imf training and sigreg loss
This commit is contained in:
@@ -114,6 +114,22 @@ class FakeAgent(nn.Module):
|
||||
return {}
|
||||
|
||||
|
||||
class RecordingAgent(FakeAgent):
    """FakeAgent variant that captures every loss-computation input for later inspection."""

    def __init__(self):
        super().__init__()
        # Inputs handed to compute_loss, in call order.
        self.seen_inputs = []

    def compute_loss(self, agent_input):
        """Log *agent_input*, then delegate to the parent implementation."""
        self.seen_inputs.append(agent_input)
        return super().compute_loss(agent_input)
|
||||
|
||||
|
||||
class ShapeMixedFakeAgent(FakeAgent):
    """FakeAgent with an extra ``bias`` parameter.

    The (2,)-shaped ``bias`` is meant to mismatch a checkpoint entry of a
    different shape, exercising partial state-dict loading.
    """

    def __init__(self):
        super().__init__()
        # Shape chosen to conflict with a differently-shaped 'bias' in a checkpoint.
        self.bias = nn.Parameter(torch.zeros(2))
|
||||
|
||||
|
||||
class FakeSwanLab:
|
||||
def __init__(self, init_error=None, log_errors=None, finish_error=None, image_errors=None):
|
||||
self.init_error = init_error
|
||||
@@ -388,6 +404,18 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||
}
|
||||
|
||||
def _make_lewm_batch(self):
|
||||
batch = self._make_batch()
|
||||
batch.update(
|
||||
{
|
||||
'lewm.observation.front': torch.ones(1, 3, 2, 2),
|
||||
'lewm.observation.state': torch.ones(1, 4),
|
||||
'lewm.future.front': torch.full((1, 3, 2, 2), 2.0),
|
||||
'lewm.future.state': torch.full((1, 4), 2.0),
|
||||
}
|
||||
)
|
||||
return batch
|
||||
|
||||
def _loader_factory(self):
|
||||
train_batch = self._make_batch()
|
||||
val_batch = self._make_batch()
|
||||
@@ -397,6 +425,15 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
|
||||
return factory
|
||||
|
||||
def _lewm_loader_factory(self):
|
||||
train_batch = self._make_lewm_batch()
|
||||
val_batch = self._make_lewm_batch()
|
||||
|
||||
def factory(_dataset, *, shuffle, **_kwargs):
|
||||
return FakeLoader(train_batch if shuffle else val_batch)
|
||||
|
||||
return factory
|
||||
|
||||
def test_run_training_logs_metrics_and_checkpoint_paths_to_swanlab(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
@@ -487,6 +524,43 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
|
||||
self.assertEqual(fake_swanlab.finish_calls, 1)
|
||||
|
||||
def test_run_training_passes_lewm_history_and_future_batches_into_agent_input(self):
    """run_training should surface 'lewm.*' batch keys as lewm_* agent-input entries."""
    train_module = self._load_train_vla_module()
    run_training = self._get_run_training(train_module)
    cfg = self._make_cfg(use_swanlab=False)
    cfg.train.max_steps = 1
    cfg.train.save_freq = 100
    recording_agent = RecordingAgent()

    def fake_instantiate(config_node, **_kwargs):
        # Route the trainer's two instantiate() calls to our test doubles.
        if config_node is cfg.data:
            return FakeDataset()
        if config_node is cfg.agent:
            return recording_agent
        raise AssertionError(f'unexpected instantiate config: {config_node!r}')

    with tempfile.TemporaryDirectory() as tempdir:
        previous_cwd = os.getcwd()
        try:
            os.chdir(tempdir)
            with mock.patch.object(train_module, 'instantiate', side_effect=fake_instantiate), \
                    mock.patch.object(train_module, 'DataLoader', side_effect=self._lewm_loader_factory()), \
                    mock.patch.object(train_module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                    mock.patch.object(train_module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                    mock.patch.object(train_module.torch, 'save', return_value=None):
                run_training(cfg)
        finally:
            os.chdir(previous_cwd)

    self.assertGreaterEqual(len(recording_agent.seen_inputs), 1)
    first_input = recording_agent.seen_inputs[0]
    # History and future payloads must both be present in the agent input.
    for expected_key in ('lewm_images', 'lewm_qpos', 'lewm_future_images', 'lewm_future_qpos'):
        self.assertIn(expected_key, first_input)
    # The 'front' camera from the batch keys must survive the translation.
    self.assertIn('front', first_input['lewm_images'])
    self.assertIn('front', first_input['lewm_future_images'])
|
||||
|
||||
def test_run_training_skips_swanlab_when_disabled(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
@@ -668,6 +742,52 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
self.assertTrue(final_payload['final/best_checkpoint_path'].endswith('checkpoints/vla_model_best.pt'))
|
||||
self.assertFalse(any(path.endswith('checkpoints/vla_model_best.pt') for path in saved_paths))
|
||||
|
||||
def test_run_training_pretrained_ckpt_loads_matching_keys_even_if_some_shapes_mismatch(self):
    """A pretrained ckpt should apply compatible keys and tolerate shape-mismatched ones."""
    train_module = self._load_train_vla_module()
    run_training = self._get_run_training(train_module)
    cfg = self._make_cfg(use_swanlab=False)
    cfg.train.max_steps = 0
    cfg.train.save_freq = 100
    cfg.train.pretrained_ckpt = 'pretrained.pt'
    mixed_agent = ShapeMixedFakeAgent()

    def fake_instantiate(config_node, **_kwargs):
        # Route the trainer's two instantiate() calls to our test doubles.
        if config_node is cfg.data:
            return FakeDataset()
        if config_node is cfg.agent:
            return mixed_agent
        raise AssertionError(f'unexpected instantiate config: {config_node!r}')

    def fake_torch_load(path, map_location=None):
        del map_location
        if Path(path).name != 'pretrained.pt':
            raise AssertionError(f'unexpected load path: {path}')
        # 'bias' is (3,) while the agent's parameter is (2,), exercising the
        # shape-mismatch path; 'weight' is loadable.
        return {
            'model_state_dict': {
                'weight': torch.tensor(3.0),
                'bias': torch.tensor([1.0, 2.0, 3.0]),
            },
            'step': 123,
            'loss': 0.5,
        }

    with tempfile.TemporaryDirectory() as tempdir:
        previous_cwd = os.getcwd()
        try:
            os.chdir(tempdir)
            # The file only needs to exist; torch.load is mocked below.
            Path('pretrained.pt').write_bytes(b'pretend')
            with mock.patch.object(train_module, 'instantiate', side_effect=fake_instantiate), \
                    mock.patch.object(train_module, 'DataLoader', side_effect=self._loader_factory()), \
                    mock.patch.object(train_module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                    mock.patch.object(train_module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                    mock.patch.object(train_module.torch, 'save', return_value=None), \
                    mock.patch.object(train_module.torch, 'load', side_effect=fake_torch_load):
                run_training(cfg)
        finally:
            os.chdir(previous_cwd)

    # The shape-compatible 'weight' must have been loaded from the checkpoint.
    self.assertAlmostEqual(mixed_agent.weight.item(), 3.0, places=6)
|
||||
|
||||
def test_run_training_resume_ignores_best_checkpoint_without_rollout_metric(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
|
||||
Reference in New Issue
Block a user