# File: roboimi/tests/test_eval_vla_headless.py
# Viewer metadata: 315 lines, 10 KiB, Python

import unittest
from pathlib import Path
from unittest import mock
import numpy as np
import torch
from omegaconf import OmegaConf
from roboimi.demos.vla_scripts import eval_vla
from roboimi.envs.double_base import DualDianaMed
from roboimi.envs.double_pos_ctrl_env import make_sim_env
class _FakeAgent:
    """Minimal stand-in for the policy agent handed to the eval loop.

    Records how it was used (``reset_calls``, ``last_observation``) so tests
    can assert on the interaction without loading a real checkpoint.
    """

    def __init__(self):
        self.reset_calls = 0  # number of times reset() has been invoked
        self.last_observation = None  # most recent argument to select_action()

    def eval(self):
        """Return ``self`` so the call can be chained like a real module's."""
        return self

    def to(self, _device):
        """Ignore the requested device and return ``self`` for chaining."""
        return self

    def reset(self):
        """Count the reset; there is no internal policy state to clear."""
        self.reset_calls += 1

    def select_action(self, observation):
        """Remember ``observation`` and return an all-zero action of length 16."""
        self.last_observation = observation
        return torch.zeros(16)
class _FakeEnv:
    """Fake simulation environment that records calls instead of simulating.

    ``render()`` always fails, so any test using this env verifies that the
    code under test never renders when ``eval.headless`` is true.
    """

    def __init__(self):
        self.image_obs_calls = 0  # number of _get_image_obs() calls
        self.render_calls = 0  # number of render() calls (expected to stay 0)
        self.reset_calls = []  # one array per reset(), holding the box pose

    def reset(self, box_pos):
        """Record the requested box pose as a new NumPy array."""
        self.reset_calls.append(np.array(box_pos))

    def _get_image_obs(self):
        """Count the call and return one blank 8x8 RGB ``front`` image."""
        self.image_obs_calls += 1
        return {
            "images": {
                "front": np.zeros((8, 8, 3), dtype=np.uint8),
            }
        }

    def _get_qpos_obs(self):
        """Return an all-zero float32 ``qpos`` vector of length 16."""
        return {"qpos": np.zeros(16, dtype=np.float32)}

    def render(self):
        """Count the call, then fail: headless evaluation must not render."""
        self.render_calls += 1
        raise AssertionError("env.render() should be skipped when eval.headless=true")
class _RewardTrackingEnv(_FakeEnv):
    """``_FakeEnv`` variant that tracks episode/step position for scripted rewards.

    ``reward_sequences`` is stored as given; this class never reads it itself.
    The reward-summary test's patched ``execute_policy_action`` indexes it as
    ``reward_sequences[episode_index][step_index]`` and writes the value to
    ``rew``.
    """

    def __init__(self, reward_sequences):
        super().__init__()
        self.reward_sequences = reward_sequences
        # Starts at -1 so the first reset() moves to episode 0.
        self.episode_index = -1
        self.step_index = 0
        self.rew = 0.0

    def reset(self, box_pos):
        """Record the reset, advance to the next episode and rewind the step."""
        super().reset(box_pos)
        self.episode_index += 1
        self.step_index = 0
class _FakeRenderer:
    """Fake offscreen renderer that hands out five canned 4x4 RGB frames.

    Frame ``i`` is filled with the value ``i``. When the last frame is handed
    out the renderer sets ``exit_flag = True`` on the env it was built with.
    NOTE(review): presumably this is what stops a polling loop such as
    ``DualDianaMed.camera_viewer`` -- confirm against that implementation.
    """

    def __init__(self, env):
        self._env = env
        self._frames = [
            np.full((4, 4, 3), fill_value=index, dtype=np.uint8)
            for index in range(5)
        ]
        self._index = 0  # position of the next frame to return
        # Defined up front so the attribute exists before update_scene() runs.
        self._camera = None

    def update_scene(self, _mj_data, camera=None):
        """Remember which camera was requested; the scene data is ignored."""
        self._camera = camera

    def render(self):
        """Return the next canned frame, flagging env exit after the last one.

        Raises:
            IndexError: if called again after all five frames were returned.
        """
        frame = self._frames[self._index]
        self._index += 1
        if self._index >= len(self._frames):
            self._env.exit_flag = True
        return frame
class EvalVLAHeadlessTest(unittest.TestCase):
    """Tests for the headless (``eval.headless=true``) evaluation path."""

    # Image attributes on DualDianaMed that the camera tests expect to be filled.
    _CAMERA_IMAGE_ATTRS = ("r_vis", "l_vis", "top", "angle", "front")

    @staticmethod
    def _make_eval_cfg(num_episodes, max_timesteps):
        """Build the headless eval config shared by the eval-loop tests."""
        return OmegaConf.create(
            {
                "agent": {},
                "eval": {
                    "ckpt_path": "checkpoints/vla_model_best.pt",
                    "num_episodes": num_episodes,
                    "max_timesteps": max_timesteps,
                    "device": "cpu",
                    "task_name": "sim_transfer",
                    "camera_names": ["front"],
                    "use_smoothing": False,
                    "smooth_alpha": 0.3,
                    "verbose_action": False,
                    "headless": True,
                },
            }
        )

    @classmethod
    def _make_bare_headless_env(cls):
        """Return a ``DualDianaMed`` created without ``__init__``, rendering off.

        ``__new__`` is used so no simulator or GUI setup runs; only the
        attributes the camera tests set explicitly are populated.
        """
        env = DualDianaMed.__new__(DualDianaMed)
        env.mj_model = object()
        env.mj_data = object()
        env.exit_flag = False
        env.is_render = False
        env.cam = "angle"
        for attr in cls._CAMERA_IMAGE_ATTRS:
            setattr(env, attr, None)
        return env

    def _assert_camera_images_populated(self, env):
        """Assert that every camera image attribute on ``env`` was filled in."""
        for attr in self._CAMERA_IMAGE_ATTRS:
            self.assertIsNotNone(getattr(env, attr), attr)

    def test_headless_eval_sets_mujoco_gl_to_egl_when_display_missing(self):
        cfg = OmegaConf.create({"eval": {"headless": True}})
        with mock.patch.dict(eval_vla.os.environ, {}, clear=True):
            eval_vla._configure_headless_mujoco_gl(cfg.eval)
            # Must be checked inside the patch: os.environ is restored on exit.
            self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "egl")

    def test_headless_eval_preserves_existing_mujoco_gl(self):
        cfg = OmegaConf.create({"eval": {"headless": True}})
        with mock.patch.dict(eval_vla.os.environ, {"MUJOCO_GL": "osmesa"}, clear=True):
            eval_vla._configure_headless_mujoco_gl(cfg.eval)
            # Must be checked inside the patch: os.environ is restored on exit.
            self.assertEqual(eval_vla.os.environ.get("MUJOCO_GL"), "osmesa")

    def test_eval_config_exposes_headless_default(self):
        # Resolve relative to this file (roboimi/tests/...) so the test does not
        # depend on the directory the test runner was started from.
        package_root = Path(__file__).resolve().parents[1]
        eval_cfg = OmegaConf.load(package_root / "vla" / "conf" / "eval" / "eval.yaml")
        self.assertIn("headless", eval_cfg)
        self.assertFalse(eval_cfg.headless)

    def test_make_sim_env_accepts_headless_and_disables_render(self):
        fake_env = object()
        with mock.patch(
            "roboimi.assets.robots.diana_med.BiDianaMed",
            return_value="robot",
        ), mock.patch(
            "roboimi.envs.double_pos_ctrl_env.DualDianaMed_Pos_Ctrl",
            return_value=fake_env,
        ) as env_cls:
            env = make_sim_env("sim_transfer", headless=True)

        self.assertIs(env, fake_env)
        env_cls.assert_called_once_with(
            robot="robot",
            is_render=False,
            control_freq=30,
            is_interpolate=True,
            cam_view="angle",
        )

    def test_headless_sync_camera_capture_populates_images_without_gui_calls(self):
        env = self._make_bare_headless_env()
        env._offscreen_renderer = None

        with mock.patch(
            "roboimi.envs.double_base.mj.Renderer",
            side_effect=lambda *args, **kwargs: _FakeRenderer(env),
        ) as renderer_cls, mock.patch(
            "roboimi.envs.double_base.cv2.namedWindow"
        ) as named_window, mock.patch(
            "roboimi.envs.double_base.cv2.imshow"
        ) as imshow, mock.patch(
            "roboimi.envs.double_base.cv2.waitKey"
        ) as wait_key:
            env._update_camera_images_sync()

        renderer_cls.assert_called_once()
        named_window.assert_not_called()
        imshow.assert_not_called()
        wait_key.assert_not_called()
        self._assert_camera_images_populated(env)

    def test_cam_start_skips_background_thread_when_headless(self):
        env = DualDianaMed.__new__(DualDianaMed)
        env.is_render = False
        env.cam_thread = None

        with mock.patch("roboimi.envs.double_base.threading.Thread") as thread_cls:
            env.cam_start()

        thread_cls.assert_not_called()
        self.assertIsNone(env.cam_thread)

    def test_camera_viewer_headless_updates_images_without_gui_calls(self):
        env = self._make_bare_headless_env()

        with mock.patch(
            "roboimi.envs.double_base.mj.Renderer",
            side_effect=lambda *args, **kwargs: _FakeRenderer(env),
        ), mock.patch(
            "roboimi.envs.double_base.cv2.namedWindow"
        ) as named_window, mock.patch(
            "roboimi.envs.double_base.cv2.imshow"
        ) as imshow, mock.patch(
            "roboimi.envs.double_base.cv2.waitKey"
        ) as wait_key:
            env.camera_viewer()

        named_window.assert_not_called()
        imshow.assert_not_called()
        wait_key.assert_not_called()
        self._assert_camera_images_populated(env)

    def test_eval_main_headless_skips_render_and_still_executes_policy(self):
        fake_env = _FakeEnv()
        fake_agent = _FakeAgent()
        cfg = self._make_eval_cfg(num_episodes=1, max_timesteps=1)

        with mock.patch.object(
            eval_vla,
            "load_checkpoint",
            return_value=(fake_agent, None),
        ), mock.patch.object(
            eval_vla,
            "make_sim_env",
            return_value=fake_env,
        ) as make_env, mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            return_value=np.array([0.1, 0.2, 0.3]),
        ), mock.patch.object(
            eval_vla,
            "execute_policy_action",
        ) as execute_policy_action, mock.patch.object(
            eval_vla,
            "tqdm",
            side_effect=lambda iterable, **kwargs: iterable,
        ):
            # __wrapped__ is called instead of main itself; presumably this
            # bypasses a config-loading decorator -- confirm in eval_vla.
            eval_vla.main.__wrapped__(cfg)

        make_env.assert_called_once_with("sim_transfer", headless=True)
        execute_policy_action.assert_called_once()
        self.assertEqual(fake_env.image_obs_calls, 1)
        self.assertEqual(fake_env.render_calls, 0)
        self.assertIsNotNone(fake_agent.last_observation)
        self.assertIn("front", fake_agent.last_observation["images"])

    def test_run_eval_returns_average_reward_summary(self):
        reward_sequences = [
            [1.0, 2.0],
            [0.5, 4.0],
        ]
        fake_env = _RewardTrackingEnv(reward_sequences)
        fake_agent = _FakeAgent()
        cfg = self._make_eval_cfg(num_episodes=2, max_timesteps=2)

        def fake_execute_policy_action(env, action):
            # Replay the scripted reward for the current episode/step.
            del action
            env.rew = env.reward_sequences[env.episode_index][env.step_index]
            env.step_index += 1

        with mock.patch.object(
            eval_vla,
            "load_checkpoint",
            return_value=(fake_agent, None),
        ), mock.patch.object(
            eval_vla,
            "make_sim_env",
            return_value=fake_env,
        ), mock.patch.object(
            eval_vla,
            "sample_transfer_pose",
            return_value=np.array([0.1, 0.2, 0.3]),
        ), mock.patch.object(
            eval_vla,
            "execute_policy_action",
            side_effect=fake_execute_policy_action,
        ), mock.patch.object(
            eval_vla,
            "tqdm",
            side_effect=lambda iterable, **kwargs: iterable,
        ):
            summary = eval_vla._run_eval(cfg)

        # Expected per-episode totals: 1.0 + 2.0 and 0.5 + 4.0.
        self.assertEqual(summary["episode_rewards"], [3.0, 4.5])
        self.assertAlmostEqual(summary["avg_reward"], 3.75)
        self.assertEqual(summary["num_episodes"], 2)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()