"""Unit tests for the headless and parallel evaluation paths of ``eval_vla``."""

import unittest
from pathlib import Path
from unittest import mock

import numpy as np
import torch
from omegaconf import OmegaConf

from roboimi.demos.vla_scripts import eval_vla
from roboimi.envs.double_base import DualDianaMed
from roboimi.envs.double_pos_ctrl_env import make_sim_env

# Location of the eval defaults, expressed relative to the repository root.
_EVAL_CONFIG_PATH = Path("roboimi/vla/conf/eval/eval.yaml")


def _load_eval_config():
    """Load the eval defaults YAML regardless of the current working directory.

    The repo-root-relative path is tried first (the historical behaviour).
    When it does not exist, the repository root is derived from the location
    of ``eval_vla``; this assumes the module file sits at
    ``roboimi/demos/vla_scripts/eval_vla.py`` as its import path suggests.
    If neither candidate exists the original relative path is loaded so the
    resulting error still names the expected location.
    """
    config_path = _EVAL_CONFIG_PATH
    if not config_path.exists():
        repo_root = Path(eval_vla.__file__).resolve().parents[3]
        candidate = repo_root / _EVAL_CONFIG_PATH
        if candidate.exists():
            config_path = candidate
    return OmegaConf.load(config_path)


def _make_eval_cfg(**eval_overrides):
    """Build a config with an empty ``agent`` node and a populated ``eval`` node.

    The defaults describe a single-episode, single-step, headless CPU run;
    keyword arguments replace or extend entries of the ``eval`` node.
    """
    eval_section = {
        "ckpt_path": "checkpoints/vla_model_best.pt",
        "num_episodes": 1,
        "max_timesteps": 1,
        "device": "cpu",
        "task_name": "sim_transfer",
        "camera_names": ["front"],
        "use_smoothing": False,
        "smooth_alpha": 0.3,
        "verbose_action": False,
        "headless": True,
    }
    eval_section.update(eval_overrides)
    return OmegaConf.create({"agent": {}, "eval": eval_section})


def _make_headless_camera_env():
    """Create a ``DualDianaMed`` without running ``__init__`` and set headless camera state."""
    env = DualDianaMed.__new__(DualDianaMed)
    env.mj_model = object()
    env.mj_data = object()
    env.exit_flag = False
    env.is_render = False
    env.cam = "angle"
    env.r_vis = None
    env.l_vis = None
    env.top = None
    env.angle = None
    env.front = None
    return env


class _FakeAgent:
    """Policy stand-in that records calls and always returns a zero action."""

    def __init__(self):
        self.reset_calls = 0
        self.last_observation = None

    def eval(self):
        return self

    def to(self, _device):
        return self

    def reset(self):
        self.reset_calls += 1

    def select_action(self, observation):
        """Remember ``observation`` and return a 16-element zero tensor."""
        self.last_observation = observation
        return torch.zeros(16)


class _FakeEnv:
    """Environment stand-in that counts observation calls and forbids rendering."""

    def __init__(self):
        self.image_obs_calls = 0
        self.render_calls = 0
        self.reset_calls = []

    def reset(self, box_pos):
        self.reset_calls.append(np.array(box_pos))

    def _get_image_obs(self):
        self.image_obs_calls += 1
        return {
            "images": {
                "front": np.zeros((8, 8, 3), dtype=np.uint8),
            }
        }

    def _get_qpos_obs(self):
        return {"qpos": np.zeros(16, dtype=np.float32)}

    def render(self):
        # Count the call before failing so tests can also assert on the counter.
        self.render_calls += 1
        raise AssertionError("env.render() should be skipped when eval.headless=true")


class _RewardTrackingEnv(_FakeEnv):
    """Fake env carrying per-episode reward sequences and episode/step cursors."""

    def __init__(self, reward_sequences):
        super().__init__()
        self.reward_sequences = reward_sequences
        # Starts at -1 so the first reset() moves the cursor to episode 0.
        self.episode_index = -1
        self.step_index = 0
        self.rew = 0.0

    def reset(self, box_pos):
        super().reset(box_pos)
        self.episode_index += 1
        self.step_index = 0


class _FakeRenderer:
    """Renderer stand-in that hands out five constant-valued 4x4 RGB frames in order."""

    def __init__(self, env):
        self._env = env
        self._frames = [
            np.full((4, 4, 3), fill_value=index, dtype=np.uint8)
            for index in range(5)
        ]
        self._index = 0
        self._camera = None

    def update_scene(self, _mj_data, camera=None):
        self._camera = camera

    def render(self):
        """Return the next frame; set ``env.exit_flag`` once the last one is served."""
        frame = self._frames[self._index]
        self._index += 1
        if self._index >= len(self._frames):
            self._env.exit_flag = True
        return frame


class EvalVLAHeadlessTest(unittest.TestCase):
    """Tests for headless rendering, config defaults and eval dispatch in ``eval_vla``."""

    def _assert_camera_images_populated(self, env):
        """Assert that every camera image attribute on ``env`` has been filled in."""
        self.assertIsNotNone(env.r_vis)
        self.assertIsNotNone(env.l_vis)
        self.assertIsNotNone(env.top)
        self.assertIsNotNone(env.angle)
        self.assertIsNotNone(env.front)

    def _assert_run_eval_uses_serial_path(self, *, num_workers, num_episodes):
        """Assert ``_run_eval`` delegates to ``_run_eval_serial`` for the given settings."""
        cfg = OmegaConf.create(
            {
                "eval": {
                    "num_workers": num_workers,
                    "num_episodes": num_episodes,
                }
            }
        )

        with mock.patch.object(
            eval_vla,
            "_run_eval_serial",
            return_value={"mode": "serial"},
        ) as run_eval_serial, mock.patch.object(
            eval_vla,
            "_run_eval_parallel",
        ) as run_eval_parallel:
            result = eval_vla._run_eval(cfg)

        self.assertEqual(result, {"mode": "serial"})
        run_eval_serial.assert_called_once_with(cfg)
        run_eval_parallel.assert_not_called()

    def test_prepare_observation_skips_resize_when_image_resize_shape_is_none(self):
        obs = {
            "images": {
                "front": np.arange(8 * 8 * 3, dtype=np.uint8).reshape(8, 8, 3),
            },
            "qpos": np.zeros(16, dtype=np.float32),
        }

        # Any call to cv2.resize raises, so a resize attempt fails the test.
        with mock.patch("cv2.resize", side_effect=AssertionError("resize should be skipped")):
            prepared = eval_vla.prepare_observation(
                obs,
                ["front"],
                image_resize_shape=None,
            )

        self.assertEqual(tuple(prepared["images"]["front"].shape), (3, 8, 8))
        self.assertEqual(tuple(prepared["qpos"].shape), (16,))

    def test_headless_eval_sets_mujoco_gl_to_egl_when_display_missing(self):
        cfg = OmegaConf.create({"eval": {"headless": True}})

        with mock.patch.dict(eval_vla.os.environ, {}, clear=True):
            eval_vla._configure_headless_mujoco_gl(cfg.eval)
            # Must be checked inside the block: patch.dict restores the
            # original environment when the context exits.
            self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "egl")

    def test_headless_eval_preserves_existing_mujoco_gl(self):
        cfg = OmegaConf.create({"eval": {"headless": True}})

        with mock.patch.dict(eval_vla.os.environ, {"MUJOCO_GL": "osmesa"}, clear=True):
            eval_vla._configure_headless_mujoco_gl(cfg.eval)
            # Checked inside the block for the same reason as above.
            self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "osmesa")

    def test_eval_config_exposes_headless_default(self):
        eval_cfg = _load_eval_config()

        self.assertIn("headless", eval_cfg)
        self.assertFalse(eval_cfg.headless)

    def test_eval_config_exposes_num_workers_default(self):
        eval_cfg = _load_eval_config()

        self.assertIn("num_workers", eval_cfg)
        self.assertEqual(eval_cfg.num_workers, 1)

    def test_eval_config_exposes_cuda_devices_default(self):
        eval_cfg = _load_eval_config()

        self.assertIn("cuda_devices", eval_cfg)
        self.assertIsNone(eval_cfg.cuda_devices)

    def test_eval_config_exposes_parallel_timeout_defaults(self):
        eval_cfg = _load_eval_config()

        self.assertIn("response_timeout_s", eval_cfg)
        self.assertIn("server_startup_timeout_s", eval_cfg)
        self.assertEqual(eval_cfg.response_timeout_s, 300.0)
        self.assertEqual(eval_cfg.server_startup_timeout_s, 300.0)

    def test_make_sim_env_accepts_headless_and_disables_render(self):
        fake_env = object()

        with mock.patch(
            "roboimi.assets.robots.diana_med.BiDianaMed",
            return_value="robot",
        ), mock.patch(
            "roboimi.envs.double_pos_ctrl_env.DualDianaMed_Pos_Ctrl",
            return_value=fake_env,
        ) as env_cls:
            env = make_sim_env("sim_transfer", headless=True)

        self.assertIs(env, fake_env)
        env_cls.assert_called_once_with(
            robot="robot",
            is_render=False,
            control_freq=30,
            is_interpolate=True,
            cam_view="angle",
        )

    def test_headless_sync_camera_capture_populates_images_without_gui_calls(self):
        env = _make_headless_camera_env()
        env._offscreen_renderer = None

        with mock.patch(
            "roboimi.envs.double_base.mj.Renderer",
            side_effect=lambda *args, **kwargs: _FakeRenderer(env),
        ) as renderer_cls, mock.patch(
            "roboimi.envs.double_base.cv2.namedWindow"
        ) as named_window, mock.patch(
            "roboimi.envs.double_base.cv2.imshow"
        ) as imshow, mock.patch(
            "roboimi.envs.double_base.cv2.waitKey"
        ) as wait_key:
            env._update_camera_images_sync()

        renderer_cls.assert_called_once()
        named_window.assert_not_called()
        imshow.assert_not_called()
        wait_key.assert_not_called()
        self._assert_camera_images_populated(env)

    def test_cam_start_skips_background_thread_when_headless(self):
        env = DualDianaMed.__new__(DualDianaMed)
        env.is_render = False
        env.cam_thread = None

        with mock.patch("roboimi.envs.double_base.threading.Thread") as thread_cls:
            env.cam_start()

        thread_cls.assert_not_called()
        self.assertIsNone(env.cam_thread)

    def test_camera_viewer_headless_updates_images_without_gui_calls(self):
        env = _make_headless_camera_env()

        with mock.patch(
            "roboimi.envs.double_base.mj.Renderer",
            side_effect=lambda *args, **kwargs: _FakeRenderer(env),
        ), mock.patch(
            "roboimi.envs.double_base.cv2.namedWindow"
        ) as named_window, mock.patch(
            "roboimi.envs.double_base.cv2.imshow"
        ) as imshow, mock.patch(
            "roboimi.envs.double_base.cv2.waitKey"
        ) as wait_key:
            env.camera_viewer()

        named_window.assert_not_called()
        imshow.assert_not_called()
        wait_key.assert_not_called()
        self._assert_camera_images_populated(env)

    def test_eval_main_headless_skips_render_and_still_executes_policy(self):
        fake_env = _FakeEnv()
        fake_agent = _FakeAgent()
        cfg = _make_eval_cfg()

        with mock.patch.object(
            eval_vla,
            "load_checkpoint",
            return_value=(fake_agent, None),
        ), mock.patch.object(
            eval_vla,
            "make_sim_env",
            return_value=fake_env,
        ) as make_env, mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            return_value=np.array([0.1, 0.2, 0.3]),
        ), mock.patch.object(
            eval_vla,
            "execute_policy_action",
        ) as execute_policy_action, mock.patch.object(
            eval_vla,
            "tqdm",
            side_effect=lambda iterable, **kwargs: iterable,
        ):
            # Call the undecorated entry point directly with the test config.
            eval_vla.main.__wrapped__(cfg)

        make_env.assert_called_once_with("sim_transfer", headless=True)
        execute_policy_action.assert_called_once()
        self.assertEqual(fake_env.image_obs_calls, 1)
        self.assertEqual(fake_env.render_calls, 0)
        self.assertIsNotNone(fake_agent.last_observation)
        self.assertIn("front", fake_agent.last_observation["images"])

    def test_run_eval_returns_average_reward_summary(self):
        reward_sequences = [
            [1.0, 2.0],
            [0.5, 4.0],
        ]
        fake_env = _RewardTrackingEnv(reward_sequences)
        fake_agent = _FakeAgent()
        cfg = _make_eval_cfg(num_episodes=2, max_timesteps=2)

        def fake_execute_policy_action(env, action):
            # Publish the scripted reward for the current step, then advance.
            del action
            env.rew = env.reward_sequences[env.episode_index][env.step_index]
            env.step_index += 1

        with mock.patch.object(
            eval_vla,
            "load_checkpoint",
            return_value=(fake_agent, None),
        ), mock.patch.object(
            eval_vla,
            "make_sim_env",
            return_value=fake_env,
        ), mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            return_value=np.array([0.1, 0.2, 0.3]),
        ), mock.patch.object(
            eval_vla,
            "execute_policy_action",
            side_effect=fake_execute_policy_action,
        ), mock.patch.object(
            eval_vla,
            "tqdm",
            side_effect=lambda iterable, **kwargs: iterable,
        ):
            summary = eval_vla._run_eval(cfg)

        self.assertEqual(summary["episode_rewards"], [3.0, 4.5])
        self.assertAlmostEqual(summary["avg_reward"], 3.75)
        self.assertEqual(summary["num_episodes"], 2)

    def test_run_eval_uses_serial_path_when_num_workers_is_one(self):
        self._assert_run_eval_uses_serial_path(num_workers=1, num_episodes=3)

    def test_run_eval_uses_serial_path_when_requested_workers_collapse_to_one(self):
        self._assert_run_eval_uses_serial_path(num_workers=8, num_episodes=1)

    def test_run_eval_parallel_requires_headless_true(self):
        cfg = _make_eval_cfg(num_episodes=2, num_workers=2, headless=False)

        with self.assertRaisesRegex(ValueError, "headless=true"):
            eval_vla._run_eval_parallel(cfg)

    def test_run_eval_parallel_dispatches_to_cpu_workers_when_device_is_cpu(self):
        cfg = _make_eval_cfg(num_episodes=2, num_workers=2, cuda_devices=None)

        with mock.patch.object(
            eval_vla,
            "_run_eval_parallel_cpu",
            return_value={"mode": "cpu"},
            create=True,
        ) as run_cpu_parallel, mock.patch.object(
            eval_vla,
            "_run_eval_parallel_cuda",
            create=True,
        ) as run_cuda_parallel:
            result = eval_vla._run_eval_parallel(cfg)

        self.assertEqual(result, {"mode": "cpu"})
        run_cpu_parallel.assert_called_once_with(cfg)
        run_cuda_parallel.assert_not_called()

    def test_run_eval_parallel_dispatches_to_cuda_servers_when_device_is_cuda(self):
        cfg = _make_eval_cfg(
            num_episodes=2,
            num_workers=2,
            device="cuda",
            cuda_devices=[0],
        )

        with mock.patch.object(
            eval_vla,
            "_run_eval_parallel_cpu",
            create=True,
        ) as run_cpu_parallel, mock.patch.object(
            eval_vla,
            "_run_eval_parallel_cuda",
            return_value={"mode": "cuda"},
            create=True,
        ) as run_cuda_parallel:
            result = eval_vla._run_eval_parallel(cfg)

        self.assertEqual(result, {"mode": "cuda"})
        run_cpu_parallel.assert_not_called()
        run_cuda_parallel.assert_called_once_with(cfg)

    def test_resolve_cuda_devices_defaults_to_single_logical_gpu(self):
        cfg = OmegaConf.create(
            {
                "device": "cuda",
                "cuda_devices": None,
            }
        )

        self.assertEqual(eval_vla._resolve_cuda_devices(cfg), [0])

    def test_resolve_cuda_devices_rejects_empty_selection(self):
        cfg = OmegaConf.create(
            {
                "device": "cuda",
                "cuda_devices": [],
            }
        )

        with self.assertRaisesRegex(ValueError, "cuda_devices"):
            eval_vla._resolve_cuda_devices(cfg)


if __name__ == "__main__":
    unittest.main()