import unittest from unittest import mock import numpy as np import torch from omegaconf import OmegaConf from roboimi.demos.vla_scripts import eval_vla from roboimi.vla.eval_utils import execute_policy_action class _FakeEnv: def __init__(self): self.calls = [] def step(self, action): self.calls.append(("step", action)) def step_jnt(self, action): self.calls.append(("step_jnt", action)) class _FakeQueue: def __init__(self, initial_items=None): self.items = list(initial_items or []) self.put_calls = [] def put(self, item): self.put_calls.append(item) self.items.append(item) def get(self, timeout=None): del timeout if not self.items: raise AssertionError("queue unexpectedly empty") return self.items.pop(0) def _make_parallel_cfg(**eval_overrides): eval_cfg = { "ckpt_path": "checkpoints/vla_model_best.pt", "num_episodes": 5, "num_workers": 2, "max_timesteps": 1, "device": "cpu", "task_name": "sim_transfer", "camera_names": ["front"], "use_smoothing": False, "smooth_alpha": 0.3, "verbose_action": False, "headless": True, "artifact_dir": None, "save_artifacts": False, "save_summary_json": False, "save_timing": False, "save_trajectory": False, "save_trajectory_npz": False, "record_video": False, "save_trajectory_image": False, } eval_cfg.update(eval_overrides) return OmegaConf.create({"agent": {}, "eval": eval_cfg}) class EvalVLAExecutionTest(unittest.TestCase): def test_execute_policy_action_uses_ee_step(self): env = _FakeEnv() action = [1, 2, 3] execute_policy_action(env, action) self.assertEqual(env.calls, [("step", action)]) def test_split_episode_indices_balances_workers(self): self.assertEqual( eval_vla._split_episode_indices(num_episodes=10, num_workers=3), [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]], ) def test_normalize_num_workers_caps_worker_count_to_episode_count(self): self.assertEqual(eval_vla._normalize_num_workers(num_workers=5, num_episodes=2), 2) def test_plan_episode_box_poses_uses_global_episode_order(self): planned_poses = [ np.array([0.1, 0.2, 0.3], dtype=np.float32), np.array([1.1, 1.2, 1.3], dtype=np.float32), np.array([2.1, 2.2, 2.3], dtype=np.float32), ] sampler = mock.Mock(side_effect=planned_poses) result = eval_vla._plan_episode_box_poses(num_episodes=3, sampler=sampler) self.assertEqual(sampler.call_count, 3) self.assertEqual(len(result), 3) for expected, actual in zip(planned_poses, result): np.testing.assert_array_equal(actual, expected) def test_resolve_policy_camera_names_matches_vlaagent_fallback_sorting(self): cfg = OmegaConf.create( { "agent": { "_target_": "roboimi.vla.agent.VLAAgent", }, "eval": { "camera_names": ["r_vis", "top", "front"], }, } ) self.assertEqual( eval_vla._resolve_policy_camera_names(cfg), ["front", "r_vis", "top"], ) def test_resolve_policy_camera_names_matches_gr00t_fallback_input_order(self): cfg = OmegaConf.create( { "agent": { "_target_": "roboimi.vla.agent_gr00t_dit.VLAAgentGr00tDiT", }, "eval": { "camera_names": ["r_vis", "top", "front"], }, } ) self.assertEqual( eval_vla._resolve_policy_camera_names(cfg), ["r_vis", "top", "front"], ) def test_build_episode_plans_without_box_poses_keeps_serial_sampling_lazy(self): plans = eval_vla._build_episode_plans(num_episodes=3) self.assertEqual( plans, [ {"episode_index": 0}, {"episode_index": 1}, {"episode_index": 2}, ], ) def test_prepare_local_policy_batch_pads_latest_observation_to_obs_horizon(self): queues = eval_vla._new_local_policy_queues(obs_horizon=3) observation = { "qpos": torch.tensor([1.0, 2.0], dtype=torch.float32), "images": { "front": torch.tensor([[[1.0]]], dtype=torch.float32), }, } eval_vla._populate_local_policy_queues(queues, observation) batch = eval_vla._prepare_local_policy_batch( queues, obs_horizon=3, camera_names=["front"], ) self.assertEqual(tuple(batch["qpos"].shape), (1, 3, 2)) self.assertEqual(tuple(batch["images"]["front"].shape), (1, 3, 1, 1, 1)) np.testing.assert_array_equal( batch["qpos"][0].cpu().numpy(), np.array([[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]], dtype=np.float32), ) np.testing.assert_array_equal( batch["images"]["front"][0].cpu().numpy(), np.array([[[[1.0]]], [[[1.0]]], [[[1.0]]]], dtype=np.float32), ) def test_enqueue_predicted_actions_uses_executable_slice(self): queues = eval_vla._new_local_policy_queues(obs_horizon=2) predicted_actions = torch.tensor( [[[10.0], [20.0], [30.0], [40.0]]], dtype=torch.float32, ) eval_vla._enqueue_predicted_actions( queues, predicted_actions=predicted_actions, obs_horizon=2, num_action_steps=2, ) self.assertEqual(len(queues["action"]), 2) np.testing.assert_array_equal(queues["action"].popleft().numpy(), np.array([20.0], dtype=np.float32)) np.testing.assert_array_equal(queues["action"].popleft().numpy(), np.array([30.0], dtype=np.float32)) def test_remote_policy_runner_only_requests_server_inference_when_local_action_queue_is_empty(self): request_queue = _FakeQueue() response_queue = _FakeQueue( [ { "type": "predict_chunk_result", "actions": np.asarray([[[10.0], [20.0], [30.0]]], dtype=np.float32), } ] ) runner = eval_vla._RemotePolicyRunner( worker_index=3, server_index=1, request_queue=request_queue, response_queue=response_queue, camera_names=["front"], obs_horizon=2, num_action_steps=2, ) first_observation = { "qpos": torch.tensor([1.0, 2.0], dtype=torch.float32), "images": {"front": torch.tensor([[[1.0]]], dtype=torch.float32)}, } second_observation = { "qpos": torch.tensor([3.0, 4.0], dtype=torch.float32), "images": {"front": torch.tensor([[[2.0]]], dtype=torch.float32)}, } first_action, first_forward = runner.select_action( first_observation, episode_index=7, timestep=0, ) second_action, second_forward = runner.select_action( second_observation, episode_index=7, timestep=1, ) self.assertTrue(first_forward) self.assertFalse(second_forward) self.assertEqual(len(request_queue.put_calls), 1) self.assertEqual(request_queue.put_calls[0]["type"], "predict_chunk") self.assertEqual(request_queue.put_calls[0]["worker_index"], 3) self.assertEqual(request_queue.put_calls[0]["server_index"], 1) np.testing.assert_array_equal(first_action.numpy(), np.array([20.0], dtype=np.float32)) np.testing.assert_array_equal(second_action.numpy(), np.array([30.0], dtype=np.float32)) def test_merge_worker_summaries_sorts_episodes_and_recomputes_aggregates(self): worker_summaries = [ { "avg_inference_fps": 999.0, "avg_control_fps": 999.0, "avg_obs_read_time_ms": 999.0, "avg_total_time_ms": 999.0, "timing_summary": {"count": 999, "model_forward_count": 999}, "episodes": [ { "episode_index": 2, "episode_reward": 9.0, "episode_max_reward": 4.0, "inference_fps": 30.0, "control_fps": 15.0, } ], "_merge_state": { "obs_read_time_ms": [9.0], "preprocess_time_ms": [1.0], "inference_time_ms": [3.0], "env_step_time_ms": [4.0], "total_time_ms": [10.0], "model_forward_flags": [False], }, }, { "avg_inference_fps": 888.0, "avg_control_fps": 888.0, "avg_obs_read_time_ms": 888.0, "avg_total_time_ms": 888.0, "timing_summary": {"count": 888, "model_forward_count": 888}, "episodes": [ { "episode_index": 1, "episode_reward": 6.0, "episode_max_reward": 3.0, "inference_fps": 20.0, "control_fps": 10.0, }, { "episode_index": 0, "episode_reward": 5.0, "episode_max_reward": 2.0, "inference_fps": 10.0, "control_fps": 5.0, }, ], "_merge_state": { "obs_read_time_ms": [1.0, 2.0, 12.0], "preprocess_time_ms": [2.0, 3.0, 4.0], "inference_time_ms": [4.0, 5.0, 6.0], "env_step_time_ms": [6.0, 7.0, 8.0], "total_time_ms": [8.0, 9.0, 20.0], "model_forward_flags": [True, False, True], }, }, ] artifact_paths = { "output_dir": "/tmp/merged", "summary_json": "/tmp/merged/rollout_summary.json", "timing_json": "/tmp/merged/timing.json", "trajectory_npz": None, "video_mp4": None, "video_camera_name": None, } merged = eval_vla._merge_worker_summaries(worker_summaries, artifact_paths) self.assertEqual([episode["episode_index"] for episode in merged["episodes"]], [0, 1, 2]) self.assertEqual(merged["episode_rewards"], [5.0, 6.0, 9.0]) self.assertEqual(merged["episode_max_rewards"], [2.0, 3.0, 4.0]) self.assertAlmostEqual(merged["avg_reward"], 20.0 / 3.0) self.assertAlmostEqual(merged["avg_max_reward"], 3.0) self.assertAlmostEqual(merged["avg_inference_fps"], 20.0) self.assertAlmostEqual(merged["avg_control_fps"], 10.0) self.assertAlmostEqual(merged["avg_obs_read_time_ms"], 6.0) self.assertAlmostEqual(merged["avg_total_time_ms"], 47.0 / 4.0) self.assertEqual(merged["timing_summary"]["count"], 4) self.assertEqual(merged["timing_summary"]["model_forward_count"], 2) self.assertEqual(merged["artifact_dir"], "/tmp/merged") self.assertEqual(merged["artifacts"], artifact_paths) def test_build_cuda_server_payloads_uses_round_robin_worker_assignment(self): cfg = _make_parallel_cfg(num_episodes=4, num_workers=4, device="cuda", cuda_devices=[0, 1]) artifact_paths = {"output_dir": None} with mock.patch.object( eval_vla, "sample_transfer_pose", side_effect=[ np.array([0.1, 0.2, 0.3], dtype=np.float32), np.array([0.4, 0.5, 0.6], dtype=np.float32), np.array([0.7, 0.8, 0.9], dtype=np.float32), np.array([1.0, 1.1, 1.2], dtype=np.float32), ], ): worker_payloads, _ = eval_vla._build_parallel_worker_payloads(cfg, artifact_paths) server_payloads, assigned_workers = eval_vla._build_cuda_server_payloads( cfg, worker_payloads=worker_payloads, cuda_devices=[0, 1], ) self.assertEqual([payload["device_index"] for payload in server_payloads], [0, 1]) self.assertEqual([payload["worker_index"] for payload in assigned_workers], [0, 1, 2, 3]) self.assertEqual([payload["server_index"] for payload in assigned_workers], [0, 1, 0, 1]) self.assertEqual(server_payloads[0]["worker_indices"], [0, 2]) self.assertEqual(server_payloads[1]["worker_indices"], [1, 3]) def test_run_eval_parallel_dispatches_episode_splits_and_box_poses(self): cfg = _make_parallel_cfg(num_episodes=5, num_workers=2, artifact_dir="/tmp/parallel-root") planned_poses = [ np.array([float(index), float(index) + 0.1, float(index) + 0.2], dtype=np.float32) for index in range(5) ] observed_payloads = [] def fake_run_spawn_jobs(payloads, max_workers, worker_fn): del worker_fn self.assertEqual(max_workers, 2) observed_payloads.extend(payloads) return [ { "episodes": [ { "episode_index": 4, "episode_reward": 5.0, "episode_max_reward": 5.0, "inference_fps": 50.0, "control_fps": 25.0, }, { "episode_index": 3, "episode_reward": 4.0, "episode_max_reward": 4.0, "inference_fps": 40.0, "control_fps": 20.0, }, ], "_merge_state": { "obs_read_time_ms": [4.0, 5.0], "preprocess_time_ms": [1.0, 1.0], "inference_time_ms": [2.0, 2.0], "env_step_time_ms": [3.0, 3.0], "total_time_ms": [4.0, 5.0], "model_forward_flags": [True, True], }, }, { "episodes": [ { "episode_index": 2, "episode_reward": 3.0, "episode_max_reward": 3.0, "inference_fps": 30.0, "control_fps": 15.0, }, { "episode_index": 1, "episode_reward": 2.0, "episode_max_reward": 2.0, "inference_fps": 20.0, "control_fps": 10.0, }, { "episode_index": 0, "episode_reward": 1.0, "episode_max_reward": 1.0, "inference_fps": 10.0, "control_fps": 5.0, }, ], "_merge_state": { "obs_read_time_ms": [1.0, 2.0, 3.0], "preprocess_time_ms": [1.0, 1.0, 1.0], "inference_time_ms": [2.0, 2.0, 2.0], "env_step_time_ms": [3.0, 3.0, 3.0], "total_time_ms": [1.0, 2.0, 3.0], "model_forward_flags": [False, True, False], }, }, ] with mock.patch.object( eval_vla, "sample_transfer_pose", side_effect=planned_poses, ), mock.patch.object( eval_vla, "_run_spawn_jobs", side_effect=fake_run_spawn_jobs, ): summary = eval_vla._run_eval_parallel(cfg) self.assertEqual(len(observed_payloads), 2) self.assertEqual( [[plan["episode_index"] for plan in payload["episode_plans"]] for payload in observed_payloads], [[0, 1, 2], [3, 4]], ) for payload in observed_payloads: for plan in payload["episode_plans"]: np.testing.assert_array_equal( np.asarray(plan["box_pos"], dtype=np.float32), planned_poses[plan["episode_index"]], ) self.assertEqual([episode["episode_index"] for episode in summary["episodes"]], [0, 1, 2, 3, 4]) self.assertEqual(summary["episode_rewards"], [1.0, 2.0, 3.0, 4.0, 5.0]) self.assertEqual(summary["num_episodes"], 5) def test_run_eval_parallel_allows_trajectory_images_and_keeps_worker_artifact_paths(self): cfg = _make_parallel_cfg( num_episodes=2, num_workers=2, artifact_dir="/tmp/parallel-images", save_summary_json=True, save_trajectory_image=True, ) observed_payloads = [] def fake_run_spawn_jobs(payloads, max_workers, worker_fn): del worker_fn self.assertEqual(max_workers, 2) observed_payloads.extend(payloads) return [ { "episodes": [ { "episode_index": 0, "episode_reward": 1.0, "episode_max_reward": 1.0, "inference_fps": 10.0, "control_fps": 5.0, "artifact_paths": { "trajectory_image": f"{payloads[0]['artifact_dir']}/rollout_front_ep01_trajectory.png", }, }, ], "_merge_state": { "obs_read_time_ms": [1.0], "preprocess_time_ms": [1.0], "inference_time_ms": [1.0], "env_step_time_ms": [1.0], "total_time_ms": [1.0], "model_forward_flags": [True], }, }, { "episodes": [ { "episode_index": 1, "episode_reward": 2.0, "episode_max_reward": 2.0, "inference_fps": 20.0, "control_fps": 10.0, "artifact_paths": { "trajectory_image": f"{payloads[1]['artifact_dir']}/rollout_front_ep02_trajectory.png", }, }, ], "_merge_state": { "obs_read_time_ms": [2.0], "preprocess_time_ms": [2.0], "inference_time_ms": [2.0], "env_step_time_ms": [2.0], "total_time_ms": [2.0], "model_forward_flags": [False], }, }, ] with mock.patch.object( eval_vla, "sample_transfer_pose", side_effect=[ np.array([0.1, 0.2, 0.3], dtype=np.float32), np.array([0.4, 0.5, 0.6], dtype=np.float32), ], ), mock.patch.object( eval_vla, "_run_spawn_jobs", side_effect=fake_run_spawn_jobs, ): summary = eval_vla._run_eval_parallel(cfg) self.assertEqual(len(observed_payloads), 2) self.assertTrue(observed_payloads[0]["artifact_dir"].endswith("workers/worker_00")) self.assertTrue(observed_payloads[1]["artifact_dir"].endswith("workers/worker_01")) self.assertTrue( summary["episodes"][0]["artifact_paths"]["trajectory_image"].endswith( "workers/worker_00/rollout_front_ep01_trajectory.png" ) ) self.assertTrue( summary["episodes"][1]["artifact_paths"]["trajectory_image"].endswith( "workers/worker_01/rollout_front_ep02_trajectory.png" ) ) def test_run_eval_parallel_surfaces_worker_failures(self): cfg = _make_parallel_cfg(num_episodes=2, num_workers=2) with mock.patch.object( eval_vla, "sample_transfer_pose", side_effect=[ np.array([0.1, 0.2, 0.3], dtype=np.float32), np.array([0.4, 0.5, 0.6], dtype=np.float32), ], ), mock.patch.object( eval_vla, "_run_spawn_jobs", side_effect=RuntimeError("boom"), ): with self.assertRaisesRegex(RuntimeError, "Parallel rollout worker failed"): eval_vla._run_eval_parallel(cfg) def test_run_eval_parallel_cuda_builds_server_payloads_and_merges_worker_results(self): cfg = _make_parallel_cfg( num_episodes=4, num_workers=4, device="cuda", cuda_devices=[0], artifact_dir="/tmp/cuda-root", ) observed_server_payloads = [] observed_worker_payloads = [] def fake_run_cuda_parallel_processes(server_payloads, worker_payloads): observed_server_payloads.extend(server_payloads) observed_worker_payloads.extend(worker_payloads) return [ { "episodes": [ { "episode_index": 2, "episode_reward": 3.0, "episode_max_reward": 3.0, "inference_fps": 30.0, "control_fps": 15.0, }, { "episode_index": 0, "episode_reward": 1.0, "episode_max_reward": 1.0, "inference_fps": 10.0, "control_fps": 5.0, }, ], "_merge_state": { "obs_read_time_ms": [1.0, 2.0], "preprocess_time_ms": [1.0, 1.0], "inference_time_ms": [2.0, 2.0], "env_step_time_ms": [3.0, 3.0], "total_time_ms": [4.0, 4.0], "model_forward_flags": [True, False], }, }, { "episodes": [ { "episode_index": 3, "episode_reward": 4.0, "episode_max_reward": 4.0, "inference_fps": 40.0, "control_fps": 20.0, }, { "episode_index": 1, "episode_reward": 2.0, "episode_max_reward": 2.0, "inference_fps": 20.0, "control_fps": 10.0, }, ], "_merge_state": { "obs_read_time_ms": [3.0, 4.0], "preprocess_time_ms": [1.0, 1.0], "inference_time_ms": [2.0, 2.0], "env_step_time_ms": [3.0, 3.0], "total_time_ms": [4.0, 4.0], "model_forward_flags": [True, True], }, }, ] with mock.patch.object( eval_vla, "sample_transfer_pose", side_effect=[ np.array([0.1, 0.2, 0.3], dtype=np.float32), np.array([0.4, 0.5, 0.6], dtype=np.float32), np.array([0.7, 0.8, 0.9], dtype=np.float32), np.array([1.0, 1.1, 1.2], dtype=np.float32), ], ), mock.patch.object( eval_vla, "_run_cuda_parallel_processes", side_effect=fake_run_cuda_parallel_processes, create=True, ): summary = eval_vla._run_eval_parallel_cuda(cfg) self.assertEqual(len(observed_server_payloads), 1) self.assertEqual(observed_server_payloads[0]["device_index"], 0) self.assertEqual(len(observed_worker_payloads), 4) self.assertTrue(all(payload["server_index"] == 0 for payload in observed_worker_payloads)) self.assertEqual([episode["episode_index"] for episode in summary["episodes"]], [0, 1, 2, 3]) self.assertEqual(summary["episode_rewards"], [1.0, 2.0, 3.0, 4.0]) self.assertEqual(summary["num_episodes"], 4) def test_run_eval_parallel_cuda_surfaces_server_failures(self): cfg = _make_parallel_cfg(num_episodes=2, num_workers=2, device="cuda", cuda_devices=[0]) with mock.patch.object( eval_vla, "sample_transfer_pose", side_effect=[ np.array([0.1, 0.2, 0.3], dtype=np.float32), np.array([0.4, 0.5, 0.6], dtype=np.float32), ], ), mock.patch.object( eval_vla, "_run_cuda_parallel_processes", side_effect=RuntimeError("server boom"), create=True, ): with self.assertRaisesRegex(RuntimeError, "Parallel CUDA rollout failed"): eval_vla._run_eval_parallel_cuda(cfg) def test_run_spawn_jobs_supports_real_spawn_with_actual_eval_worker_entry(self): payloads = [ {"_spawn_probe": True, "probe_value": 1, "worker_index": 0}, {"_spawn_probe": True, "probe_value": 2, "worker_index": 1}, ] results = eval_vla._run_spawn_jobs( payloads=payloads, max_workers=2, worker_fn=eval_vla._run_eval_worker_entry, ) self.assertEqual(sorted(result["probe_value"] for result in results), [1, 2]) self.assertEqual(sorted(result["worker_index"] for result in results), [0, 1]) def test_cuda_server_and_env_worker_entrypoints_support_real_spawn_probe(self): ctx = eval_vla.multiprocessing.get_context("spawn") request_queue = ctx.Queue() response_queue = ctx.Queue() result_queue = ctx.Queue() server = ctx.Process( target=eval_vla._inference_server_main, args=( { "_spawn_probe": True, "server_index": 0, "request_queue": request_queue, "response_queues": [response_queue], }, ), ) worker = ctx.Process( target=eval_vla._env_worker_main, args=( { "_spawn_probe": True, "worker_index": 0, "server_index": 0, "request_queue": request_queue, "response_queue": response_queue, "result_queue": result_queue, }, ), ) server.start() worker.start() result = result_queue.get(timeout=10.0) worker.join(timeout=10.0) request_queue.put({"type": "shutdown_server"}) server.join(timeout=10.0) self.assertEqual(result["kind"], "worker_result") self.assertEqual(result["summary"]["probe_worker_index"], 0) self.assertEqual(result["summary"]["probe_server_index"], 0) self.assertEqual(result["summary"]["probe_actions"], [[[11.0], [22.0], [33.0]]]) self.assertEqual(worker.exitcode, 0) self.assertEqual(server.exitcode, 0) if __name__ == "__main__": unittest.main()