feat(lewm): enable gpu parallel rollout validation

Author: Logic
Date:   2026-04-23 15:11:41 +08:00
Parent: 4cd33258d2
Commit: 61522d9ae5
8 changed files with 2561 additions and 121 deletions
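The new test below pins down how training-side rollout settings reach the eval entry point. As a minimal sketch only (build_rollout_cfg is a hypothetical name, not this repository's code), the mapping implied by the test's assertions looks like this:

from omegaconf import OmegaConf

def build_rollout_cfg(cfg):
    # Hypothetical helper; the field names below mirror the test's assertions only.
    rollout_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    rollout_cfg.eval.device = cfg.train.rollout_device
    rollout_cfg.eval.num_workers = cfg.train.rollout_num_workers
    rollout_cfg.eval.cuda_devices = cfg.train.rollout_cuda_devices
    rollout_cfg.eval.response_timeout_s = cfg.train.rollout_response_timeout_s
    rollout_cfg.eval.server_startup_timeout_s = cfg.train.rollout_server_startup_timeout_s
    rollout_cfg.eval.num_episodes = cfg.train.rollout_num_episodes
    rollout_cfg.eval.headless = True           # rollout validation runs without a viewer
    rollout_cfg.eval.record_video = False      # keep periodic validation cheap
    rollout_cfg.eval.save_summary_json = True
    rollout_cfg.eval.save_trajectory_image = True
    return rollout_cfg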


@@ -158,6 +158,106 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
        self.assertGreater(float(cfg.train.lr), 5e-5)
        self.assertGreater(cfg.train.num_workers, 8)
        self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)
        self.assertEqual(cfg.train.rollout_device, cfg.train.device)
        self.assertIsNone(cfg.train.rollout_num_workers)
        self.assertIsNone(cfg.train.rollout_cuda_devices)

    def test_run_training_rollout_validation_propagates_gpu_parallel_settings(self):
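        # The train node carries explicit GPU rollout settings; _run_training is
        # expected to surface them on the eval node it hands to _run_eval.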
        cfg = OmegaConf.create(
            {
                'train': {
                    'device': 'cpu',
                    'batch_size': 1,
                    'num_workers': 0,
                    'val_split': 0.0,
                    'seed': 0,
                    'lr': 1e-3,
                    'max_steps': 2,
                    'log_freq': 1,
                    'save_freq': 1000,
                    'warmup_steps': 1,
                    'scheduler_type': 'constant',
                    'min_lr': 0.0,
                    'grad_clip': 1.0,
                    'weight_decay': 0.0,
                    'pretrained_ckpt': None,
                    'resume_ckpt': None,
                    'use_swanlab': False,
                    'rollout_val_freq_epochs': 2,
                    'rollout_num_episodes': 5,
                    'rollout_device': 'cuda',
                    'rollout_num_workers': 4,
                    'rollout_cuda_devices': [0, 1],
                    'rollout_response_timeout_s': 123.0,
                    'rollout_server_startup_timeout_s': 456.0,
                },
                'data': {
                    'camera_names': ['front'],
                },
                'agent': {
                    '_target_': 'fake.agent',
                },
                'eval': {
                    'ckpt_path': 'unused.pt',
                    'num_episodes': 99,
                    'max_timesteps': 1,
                    'device': 'cpu',
                    'task_name': 'sim_transfer',
                    'camera_names': ['front'],
                    'use_smoothing': False,
                    'smooth_alpha': 0.3,
                    'verbose_action': False,
                    'headless': False,
                },
            }
        )
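
        # Stand-in for eval_vla._run_eval; mock.call_args later exposes the cfg it received.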
        rollout_mock = mock.Mock(return_value={'avg_reward': 1.0})
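
        # Route hydra-style instantiate() calls for the data/agent nodes to lightweight fakes.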
        def fake_instantiate(config_node, **_kwargs):
            if config_node is cfg.data:
                return _FakeDataset()
            if config_node is cfg.agent:
                return _FakeAgent()
            raise AssertionError(f'unexpected instantiate config: {config_node!r}')
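
        # Serve one dummy batch so a single optimizer step can run end to end.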
        def fake_dataloader(_dataset, *, shuffle, **_kwargs):
            del shuffle, _kwargs
            return _FakeLoader(
                {
                    'observation.front': torch.zeros(1, 3, 2, 2),
                    'observation.state': torch.zeros(1, 4),
                    'action': torch.zeros(1, 2),
                    'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
                },
                length=1,
            )
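
        # Run training inside a scratch cwd with the heavyweight collaborators patched out.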
        with tempfile.TemporaryDirectory() as tempdir:
            previous_cwd = os.getcwd()
            try:
                os.chdir(tempdir)
                with mock.patch.object(train_vla, 'instantiate', side_effect=fake_instantiate), \
                        mock.patch.object(train_vla, 'DataLoader', side_effect=fake_dataloader), \
                        mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
                        mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
                        mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
                        mock.patch.object(train_vla.torch, 'save', return_value=None), \
                        mock.patch.object(eval_vla, '_run_eval', rollout_mock, create=True):
                    train_vla._run_training(cfg)
            finally:
                os.chdir(previous_cwd)
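
        # The eval cfg handed to _run_eval must mirror the train-side rollout settings.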
        rollout_cfg = rollout_mock.call_args.args[0]
        self.assertEqual(rollout_cfg.eval.device, 'cuda')
        self.assertEqual(rollout_cfg.eval.num_workers, 4)
        self.assertEqual(list(rollout_cfg.eval.cuda_devices), [0, 1])
        self.assertEqual(float(rollout_cfg.eval.response_timeout_s), 123.0)
        self.assertEqual(float(rollout_cfg.eval.server_startup_timeout_s), 456.0)
        self.assertTrue(rollout_cfg.eval.headless)
        self.assertEqual(rollout_cfg.eval.num_episodes, 5)
        self.assertFalse(rollout_cfg.eval.record_video)
        self.assertTrue(rollout_cfg.eval.save_summary_json)
        self.assertTrue(rollout_cfg.eval.save_trajectory_image)

    def test_training_passes_backbone_image_resize_override_to_dataset_instantiation(self):
        cfg = OmegaConf.create(