feat(lewm): enable gpu parallel rollout validation
This commit is contained in:
@@ -126,6 +126,26 @@ class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
self.assertIn("headless", eval_cfg)
|
||||
self.assertFalse(eval_cfg.headless)
|
||||
|
||||
def test_eval_config_exposes_num_workers_default(self):
|
||||
eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))
|
||||
|
||||
self.assertIn("num_workers", eval_cfg)
|
||||
self.assertEqual(eval_cfg.num_workers, 1)
|
||||
|
||||
def test_eval_config_exposes_cuda_devices_default(self):
|
||||
eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))
|
||||
|
||||
self.assertIn("cuda_devices", eval_cfg)
|
||||
self.assertIsNone(eval_cfg.cuda_devices)
|
||||
|
||||
def test_eval_config_exposes_parallel_timeout_defaults(self):
|
||||
eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))
|
||||
|
||||
self.assertIn("response_timeout_s", eval_cfg)
|
||||
self.assertIn("server_startup_timeout_s", eval_cfg)
|
||||
self.assertEqual(eval_cfg.response_timeout_s, 300.0)
|
||||
self.assertEqual(eval_cfg.server_startup_timeout_s, 300.0)
|
||||
|
||||
def test_make_sim_env_accepts_headless_and_disables_render(self):
|
||||
fake_env = object()
|
||||
|
||||
@@ -327,6 +347,172 @@ class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
self.assertAlmostEqual(summary["avg_reward"], 3.75)
|
||||
self.assertEqual(summary["num_episodes"], 2)
|
||||
|
||||
def test_run_eval_uses_serial_path_when_num_workers_is_one(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
"eval": {
|
||||
"num_workers": 1,
|
||||
"num_episodes": 3,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
eval_vla,
|
||||
"_run_eval_serial",
|
||||
return_value={"mode": "serial"},
|
||||
) as run_eval_serial, mock.patch.object(
|
||||
eval_vla,
|
||||
"_run_eval_parallel",
|
||||
) as run_eval_parallel:
|
||||
result = eval_vla._run_eval(cfg)
|
||||
|
||||
self.assertEqual(result, {"mode": "serial"})
|
||||
run_eval_serial.assert_called_once_with(cfg)
|
||||
run_eval_parallel.assert_not_called()
|
||||
|
||||
def test_run_eval_uses_serial_path_when_requested_workers_collapse_to_one(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
"eval": {
|
||||
"num_workers": 8,
|
||||
"num_episodes": 1,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
eval_vla,
|
||||
"_run_eval_serial",
|
||||
return_value={"mode": "serial"},
|
||||
) as run_eval_serial, mock.patch.object(
|
||||
eval_vla,
|
||||
"_run_eval_parallel",
|
||||
) as run_eval_parallel:
|
||||
result = eval_vla._run_eval(cfg)
|
||||
|
||||
self.assertEqual(result, {"mode": "serial"})
|
||||
run_eval_serial.assert_called_once_with(cfg)
|
||||
run_eval_parallel.assert_not_called()
|
||||
|
||||
def test_run_eval_parallel_requires_headless_true(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
"agent": {},
|
||||
"eval": {
|
||||
"ckpt_path": "checkpoints/vla_model_best.pt",
|
||||
"num_episodes": 2,
|
||||
"num_workers": 2,
|
||||
"max_timesteps": 1,
|
||||
"device": "cpu",
|
||||
"task_name": "sim_transfer",
|
||||
"camera_names": ["front"],
|
||||
"use_smoothing": False,
|
||||
"smooth_alpha": 0.3,
|
||||
"verbose_action": False,
|
||||
"headless": False,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "headless=true"):
|
||||
eval_vla._run_eval_parallel(cfg)
|
||||
|
||||
def test_run_eval_parallel_dispatches_to_cpu_workers_when_device_is_cpu(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
"agent": {},
|
||||
"eval": {
|
||||
"ckpt_path": "checkpoints/vla_model_best.pt",
|
||||
"num_episodes": 2,
|
||||
"num_workers": 2,
|
||||
"max_timesteps": 1,
|
||||
"device": "cpu",
|
||||
"task_name": "sim_transfer",
|
||||
"camera_names": ["front"],
|
||||
"use_smoothing": False,
|
||||
"smooth_alpha": 0.3,
|
||||
"verbose_action": False,
|
||||
"headless": True,
|
||||
"cuda_devices": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
eval_vla,
|
||||
"_run_eval_parallel_cpu",
|
||||
return_value={"mode": "cpu"},
|
||||
create=True,
|
||||
) as run_cpu_parallel, mock.patch.object(
|
||||
eval_vla,
|
||||
"_run_eval_parallel_cuda",
|
||||
create=True,
|
||||
) as run_cuda_parallel:
|
||||
result = eval_vla._run_eval_parallel(cfg)
|
||||
|
||||
self.assertEqual(result, {"mode": "cpu"})
|
||||
run_cpu_parallel.assert_called_once_with(cfg)
|
||||
run_cuda_parallel.assert_not_called()
|
||||
|
||||
def test_run_eval_parallel_dispatches_to_cuda_servers_when_device_is_cuda(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
"agent": {},
|
||||
"eval": {
|
||||
"ckpt_path": "checkpoints/vla_model_best.pt",
|
||||
"num_episodes": 2,
|
||||
"num_workers": 2,
|
||||
"max_timesteps": 1,
|
||||
"device": "cuda",
|
||||
"task_name": "sim_transfer",
|
||||
"camera_names": ["front"],
|
||||
"use_smoothing": False,
|
||||
"smooth_alpha": 0.3,
|
||||
"verbose_action": False,
|
||||
"headless": True,
|
||||
"cuda_devices": [0],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
eval_vla,
|
||||
"_run_eval_parallel_cpu",
|
||||
create=True,
|
||||
) as run_cpu_parallel, mock.patch.object(
|
||||
eval_vla,
|
||||
"_run_eval_parallel_cuda",
|
||||
return_value={"mode": "cuda"},
|
||||
create=True,
|
||||
) as run_cuda_parallel:
|
||||
result = eval_vla._run_eval_parallel(cfg)
|
||||
|
||||
self.assertEqual(result, {"mode": "cuda"})
|
||||
run_cpu_parallel.assert_not_called()
|
||||
run_cuda_parallel.assert_called_once_with(cfg)
|
||||
|
||||
def test_resolve_cuda_devices_defaults_to_single_logical_gpu(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
"device": "cuda",
|
||||
"cuda_devices": None,
|
||||
}
|
||||
)
|
||||
|
||||
self.assertEqual(eval_vla._resolve_cuda_devices(cfg), [0])
|
||||
|
||||
def test_resolve_cuda_devices_rejects_empty_selection(self):
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
"device": "cuda",
|
||||
"cuda_devices": [],
|
||||
}
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "cuda_devices"):
|
||||
eval_vla._resolve_cuda_devices(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user