feat: add held-out validation and dual-decoder lewm imf

Author: Logic
Date: 2026-04-17 19:26:56 +08:00
parent 74f4963613
commit 395f5a1645
8 changed files with 693 additions and 86 deletions

@@ -41,6 +41,19 @@ class FakeDataset:
        return 4


class SplitAwareFakeDataset(FakeDataset):
    def __init__(self, episode_indices=None):
        self.episode_indices = None if episode_indices is None else list(episode_indices)
        if self.episode_indices is None:
            self.episodes = {0: [0], 1: [1], 2: [2]}
        else:
            self.episodes = {idx: [idx] for idx in self.episode_indices}

    @property
    def available_episode_indices(self):
        return sorted(self.episodes.keys())


class FakeLoader:
    def __init__(self, batch):
        self.batch = batch
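For orientation, SplitAwareFakeDataset stands in for a dataset that can be restricted to a subset of episodes. A minimal usage sketch (class and property names from the hunk above; expected values follow directly from __init__):

    # Default: three one-frame episodes, indices 0..2.
    SplitAwareFakeDataset().available_episode_indices                     # [0, 1, 2]
    # Restricted: only the episodes the trainer asked for.
    SplitAwareFakeDataset(episode_indices=[2]).available_episode_indices  # [2]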
@@ -123,6 +136,10 @@ class RecordingAgent(FakeAgent):
        self.seen_inputs.append(agent_input)
        return super().compute_loss(agent_input)

    def predict_action_chunk(self, agent_input):
        self.seen_inputs.append({'predict_action_chunk': agent_input})
        return torch.ones_like(agent_input['action'])


class ShapeMixedFakeAgent(FakeAgent):
    def __init__(self):
@@ -355,6 +372,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
            batch_size=2,
            num_workers=0,
            val_split=0.25,
            val_episode_indices=None,
            action_mse_val_freq_epochs=0,
            seed=0,
            lr=1e-3,
            max_steps=2,
@@ -479,6 +498,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
            'batch_size': 2,
            'num_workers': 0,
            'val_split': 0.25,
            'val_episode_indices': None,
            'action_mse_val_freq_epochs': 0,
            'seed': 0,
            'lr': 1e-3,
            'max_steps': 2,
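Both config shapes now carry the two new knobs: val_episode_indices (explicit hold-out episodes, default None) and action_mse_val_freq_epochs (epoch interval for the action-MSE pass, 0 disabling it). How run_training consumes them is not shown in this diff; below is a hedged sketch of one plausible resolution, inferred from the test further down (the helper name resolve_val_episodes is hypothetical):

    # Hypothetical helper; behaviour inferred from the tests, not from run_training itself.
    def resolve_val_episodes(train_cfg, available_episode_indices):
        if train_cfg.val_episode_indices is not None:
            # An explicit hold-out list overrides the ratio split (val_split may be 0.0).
            return list(train_cfg.val_episode_indices)
        if train_cfg.val_split > 0:
            # Otherwise hold out the last val_split fraction of episodes.
            n_val = max(1, int(len(available_episode_indices) * train_cfg.val_split))
            return available_episode_indices[-n_val:]
        return []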
@@ -561,6 +582,58 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
        self.assertIn('front', first_input['lewm_images'])
        self.assertIn('front', first_input['lewm_future_images'])

    def test_run_training_logs_epoch_action_mse_for_held_out_val_episode(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)
        cfg = self._make_cfg()
        cfg.train.max_steps = 1
        cfg.train.save_freq = 100
        cfg.train.val_split = 0.0
        cfg.train.val_episode_indices = [2]
        cfg.train.action_mse_val_freq_epochs = 1
        agent = RecordingAgent()
        fake_swanlab = FakeSwanLab()
        real_import_module = importlib.import_module

        def fake_instantiate(config_node, **kwargs):
            if config_node is cfg.data:
                return SplitAwareFakeDataset(kwargs.get('episode_indices'))
            if config_node is cfg.agent:
                return agent
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')

        def fake_loader_factory(dataset, *, shuffle, **_kwargs):
            action_value = 0.0 if shuffle else 2.0
            batch = {
                'observation.front': torch.zeros(1, 3, 2, 2),
                'observation.state': torch.zeros(1, 4),
                'action': torch.full((1, 1, 2), action_value),
                'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
            }
            return FakeLoader(batch)

        def fake_import_module(name, package=None):
            if name == 'swanlab':
                return fake_swanlab
            return real_import_module(name, package)

        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(module, 'DataLoader', side_effect=fake_loader_factory), \
                        mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
                        mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
                        mock.patch.object(module.torch, 'save', return_value=None), \
                        mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
                    run_training(cfg)
            finally:
                os.chdir(previous_cwd)

        logged_keys = set().union(*(payload.keys() for payload, _ in fake_swanlab.log_calls))
        self.assertIn('val/action_mse', logged_keys)
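The test only asserts that a 'val/action_mse' key reaches swanlab; the pass it exercises could look roughly like the sketch below. The helper name and loop shape are assumptions; the masking mirrors the 'action_is_pad' field the fakes emit. With the fakes above, the unshuffled val loader yields actions of 2.0 against a constant prediction of 1.0, so each unmasked element contributes (1.0 - 2.0)^2 = 1.0.

    # Hedged sketch of an epoch-end held-out MSE pass; _epoch_action_mse is hypothetical.
    import torch

    def _epoch_action_mse(agent, val_loader):
        errors = []
        with torch.no_grad():
            for batch in val_loader:
                pred = agent.predict_action_chunk(batch)        # (B, T, D) action chunk
                mask = (~batch['action_is_pad']).unsqueeze(-1)  # drop padded timesteps
                sq_err = ((pred - batch['action']) ** 2) * mask
                errors.append(sq_err.sum() / mask.sum().clamp(min=1))
        return torch.stack(errors).mean().item()

    # Invoked every cfg.train.action_mse_val_freq_epochs epochs (0 disables the pass):
    # swanlab.log({'val/action_mse': _epoch_action_mse(agent, val_loader)}, step=global_step)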
    def test_run_training_skips_swanlab_when_disabled(self):
        module = self._load_train_vla_module()
        run_training = self._get_run_training(module)