feat: add held-out validation and dual-decoder lewm imf
This commit is contained in:
@@ -41,6 +41,19 @@ class FakeDataset:
|
||||
return 4
|
||||
|
||||
|
||||
class SplitAwareFakeDataset(FakeDataset):
|
||||
def __init__(self, episode_indices=None):
|
||||
self.episode_indices = None if episode_indices is None else list(episode_indices)
|
||||
if self.episode_indices is None:
|
||||
self.episodes = {0: [0], 1: [1], 2: [2]}
|
||||
else:
|
||||
self.episodes = {idx: [idx] for idx in self.episode_indices}
|
||||
|
||||
@property
|
||||
def available_episode_indices(self):
|
||||
return sorted(self.episodes.keys())
|
||||
|
||||
|
||||
class FakeLoader:
|
||||
def __init__(self, batch):
|
||||
self.batch = batch
|
||||
@@ -123,6 +136,10 @@ class RecordingAgent(FakeAgent):
|
||||
self.seen_inputs.append(agent_input)
|
||||
return super().compute_loss(agent_input)
|
||||
|
||||
def predict_action_chunk(self, agent_input):
|
||||
self.seen_inputs.append({'predict_action_chunk': agent_input})
|
||||
return torch.ones_like(agent_input['action'])
|
||||
|
||||
|
||||
class ShapeMixedFakeAgent(FakeAgent):
|
||||
def __init__(self):
|
||||
@@ -355,6 +372,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
batch_size=2,
|
||||
num_workers=0,
|
||||
val_split=0.25,
|
||||
val_episode_indices=None,
|
||||
action_mse_val_freq_epochs=0,
|
||||
seed=0,
|
||||
lr=1e-3,
|
||||
max_steps=2,
|
||||
@@ -479,6 +498,8 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
'batch_size': 2,
|
||||
'num_workers': 0,
|
||||
'val_split': 0.25,
|
||||
'val_episode_indices': None,
|
||||
'action_mse_val_freq_epochs': 0,
|
||||
'seed': 0,
|
||||
'lr': 1e-3,
|
||||
'max_steps': 2,
|
||||
@@ -561,6 +582,58 @@ class TrainVLASwanLabLoggingTest(unittest.TestCase):
|
||||
self.assertIn('front', first_input['lewm_images'])
|
||||
self.assertIn('front', first_input['lewm_future_images'])
|
||||
|
||||
def test_run_training_logs_epoch_action_mse_for_held_out_val_episode(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
cfg = self._make_cfg()
|
||||
cfg.train.max_steps = 1
|
||||
cfg.train.save_freq = 100
|
||||
cfg.train.val_split = 0.0
|
||||
cfg.train.val_episode_indices = [2]
|
||||
cfg.train.action_mse_val_freq_epochs = 1
|
||||
agent = RecordingAgent()
|
||||
fake_swanlab = FakeSwanLab()
|
||||
real_import_module = importlib.import_module
|
||||
|
||||
def fake_instantiate(config_node, **kwargs):
|
||||
if config_node is cfg.data:
|
||||
return SplitAwareFakeDataset(kwargs.get('episode_indices'))
|
||||
if config_node is cfg.agent:
|
||||
return agent
|
||||
raise AssertionError(f'unexpected instantiate config: {config_node!r}')
|
||||
|
||||
def fake_loader_factory(dataset, *, shuffle, **_kwargs):
|
||||
action_value = 0.0 if shuffle else 2.0
|
||||
batch = {
|
||||
'observation.front': torch.zeros(1, 3, 2, 2),
|
||||
'observation.state': torch.zeros(1, 4),
|
||||
'action': torch.full((1, 1, 2), action_value),
|
||||
'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
|
||||
}
|
||||
return FakeLoader(batch)
|
||||
|
||||
def fake_import_module(name, package=None):
|
||||
if name == 'swanlab':
|
||||
return fake_swanlab
|
||||
return real_import_module(name, package)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch.object(module, 'instantiate', side_effect=fake_instantiate), \
|
||||
mock.patch.object(module, 'DataLoader', side_effect=fake_loader_factory), \
|
||||
mock.patch.object(module, 'get_lr_schedule_with_warmup', return_value=FakeScheduler()), \
|
||||
mock.patch.object(module, 'tqdm', side_effect=lambda iterable, **kwargs: FakeProgressBar(iterable)), \
|
||||
mock.patch.object(module.torch, 'save', return_value=None), \
|
||||
mock.patch.object(module.importlib, 'import_module', side_effect=fake_import_module):
|
||||
run_training(cfg)
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
logged_keys = set().union(*(payload.keys() for payload, _ in fake_swanlab.log_calls))
|
||||
self.assertIn('val/action_mse', logged_keys)
|
||||
|
||||
def test_run_training_skips_swanlab_when_disabled(self):
|
||||
module = self._load_train_vla_module()
|
||||
run_training = self._get_run_training(module)
|
||||
|
||||
Reference in New Issue
Block a user