feat(lewm): enable gpu parallel rollout validation
This commit is contained in:
@@ -158,6 +158,106 @@ class TrainVLARolloutValidationTest(unittest.TestCase):
|
||||
self.assertGreater(float(cfg.train.lr), 5e-5)
|
||||
self.assertGreater(cfg.train.num_workers, 8)
|
||||
self.assertEqual(cfg.train.rollout_val_freq_epochs, 50)
|
||||
self.assertEqual(cfg.train.rollout_device, cfg.train.device)
|
||||
self.assertIsNone(cfg.train.rollout_num_workers)
|
||||
self.assertIsNone(cfg.train.rollout_cuda_devices)
|
||||
|
||||
def test_run_training_rollout_validation_propagates_gpu_parallel_settings(self):
    """Rollout validation must forward the GPU-parallel knobs from cfg.train
    into the eval config that _run_training hands to eval_vla._run_eval."""
    # Hoist the big config literal into named sections for readability;
    # every key/value matches what _run_training expects to read.
    train_section = {
        'device': 'cpu',
        'batch_size': 1,
        'num_workers': 0,
        'val_split': 0.0,
        'seed': 0,
        'lr': 1e-3,
        'max_steps': 2,
        'log_freq': 1,
        'save_freq': 1000,
        'warmup_steps': 1,
        'scheduler_type': 'constant',
        'min_lr': 0.0,
        'grad_clip': 1.0,
        'weight_decay': 0.0,
        'pretrained_ckpt': None,
        'resume_ckpt': None,
        'use_swanlab': False,
        'rollout_val_freq_epochs': 2,
        'rollout_num_episodes': 5,
        'rollout_device': 'cuda',
        'rollout_num_workers': 4,
        'rollout_cuda_devices': [0, 1],
        'rollout_response_timeout_s': 123.0,
        'rollout_server_startup_timeout_s': 456.0,
    }
    eval_section = {
        'ckpt_path': 'unused.pt',
        'num_episodes': 99,
        'max_timesteps': 1,
        'device': 'cpu',
        'task_name': 'sim_transfer',
        'camera_names': ['front'],
        'use_smoothing': False,
        'smooth_alpha': 0.3,
        'verbose_action': False,
        'headless': False,
    }
    cfg = OmegaConf.create({
        'train': train_section,
        'data': {'camera_names': ['front']},
        'agent': {'_target_': 'fake.agent'},
        'eval': eval_section,
    })
    # Spy standing in for eval_vla._run_eval; captures the config it receives.
    rollout_spy = mock.Mock(return_value={'avg_reward': 1.0})

    def _instantiate_stub(config_node, **_kwargs):
        # Identity (`is`) comparison: _run_training is expected to pass the
        # exact cfg sub-nodes through to instantiate().
        if config_node is cfg.data:
            return _FakeDataset()
        if config_node is cfg.agent:
            return _FakeAgent()
        raise AssertionError(f'unexpected instantiate config: {config_node!r}')

    def _dataloader_stub(_dataset, *, shuffle, **_kwargs):
        del shuffle, _kwargs
        # One minimal batch is enough to drive the two training steps.
        batch = {
            'observation.front': torch.zeros(1, 3, 2, 2),
            'observation.state': torch.zeros(1, 4),
            'action': torch.zeros(1, 2),
            'action_is_pad': torch.zeros(1, 1, dtype=torch.bool),
        }
        return _FakeLoader(batch, length=1)

    with tempfile.TemporaryDirectory() as workdir:
        original_cwd = os.getcwd()
        try:
            # Run inside a scratch directory so checkpoint/summary files
            # written by _run_training do not pollute the repository.
            os.chdir(workdir)
            with mock.patch.object(train_vla, 'instantiate', side_effect=_instantiate_stub), \
                    mock.patch.object(train_vla, 'DataLoader', side_effect=_dataloader_stub), \
                    mock.patch.object(train_vla, 'build_training_optimizer', return_value=_FakeOptimizer(cfg.train.lr)), \
                    mock.patch.object(train_vla, 'get_lr_schedule_with_warmup', return_value=_FakeScheduler()), \
                    mock.patch.object(train_vla, 'tqdm', side_effect=lambda iterable, **kwargs: _FakeProgressBar(iterable)), \
                    mock.patch.object(train_vla.torch, 'save', return_value=None), \
                    mock.patch.object(eval_vla, '_run_eval', rollout_spy, create=True):
                train_vla._run_training(cfg)
        finally:
            os.chdir(original_cwd)

    # First positional argument of the rollout call is the eval config.
    rollout_cfg = rollout_spy.call_args.args[0]
    # The rollout_* settings from cfg.train must land on the eval config.
    self.assertEqual(rollout_cfg.eval.device, 'cuda')
    self.assertEqual(rollout_cfg.eval.num_workers, 4)
    self.assertEqual(list(rollout_cfg.eval.cuda_devices), [0, 1])
    self.assertEqual(float(rollout_cfg.eval.response_timeout_s), 123.0)
    self.assertEqual(float(rollout_cfg.eval.server_startup_timeout_s), 456.0)
    # Rollout validation is expected to force headless mode on, use the
    # training-side episode count, and skip video in favour of summaries.
    self.assertTrue(rollout_cfg.eval.headless)
    self.assertEqual(rollout_cfg.eval.num_episodes, 5)
    self.assertFalse(rollout_cfg.eval.record_video)
    self.assertTrue(rollout_cfg.eval.save_summary_json)
    self.assertTrue(rollout_cfg.eval.save_trajectory_image)
|
||||
|
||||
def test_training_passes_backbone_image_resize_override_to_dataset_instantiation(self):
|
||||
cfg = OmegaConf.create(
|
||||
|
||||
Reference in New Issue
Block a user