feat(vla): align transformer training stack and rollout validation
This commit is contained in:
259
tests/test_eval_vla_headless.py
Normal file
259
tests/test_eval_vla_headless.py
Normal file
@@ -0,0 +1,259 @@
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from roboimi.demos.vla_scripts import eval_vla
|
||||
from roboimi.envs.double_base import DualDianaMed
|
||||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||
|
||||
|
||||
class _FakeAgent:
|
||||
def __init__(self):
|
||||
self.reset_calls = 0
|
||||
self.last_observation = None
|
||||
|
||||
def eval(self):
|
||||
return self
|
||||
|
||||
def to(self, _device):
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
self.reset_calls += 1
|
||||
|
||||
def select_action(self, observation):
|
||||
self.last_observation = observation
|
||||
return torch.zeros(16)
|
||||
|
||||
|
||||
class _FakeEnv:
|
||||
def __init__(self):
|
||||
self.image_obs_calls = 0
|
||||
self.render_calls = 0
|
||||
self.reset_calls = []
|
||||
|
||||
def reset(self, box_pos):
|
||||
self.reset_calls.append(np.array(box_pos))
|
||||
|
||||
def _get_image_obs(self):
|
||||
self.image_obs_calls += 1
|
||||
return {
|
||||
"images": {
|
||||
"front": np.zeros((8, 8, 3), dtype=np.uint8),
|
||||
}
|
||||
}
|
||||
|
||||
def _get_qpos_obs(self):
|
||||
return {"qpos": np.zeros(16, dtype=np.float32)}
|
||||
|
||||
def render(self):
|
||||
self.render_calls += 1
|
||||
raise AssertionError("env.render() should be skipped when eval.headless=true")
|
||||
|
||||
|
||||
class _RewardTrackingEnv(_FakeEnv):
|
||||
def __init__(self, reward_sequences):
|
||||
super().__init__()
|
||||
self.reward_sequences = reward_sequences
|
||||
self.episode_index = -1
|
||||
self.step_index = 0
|
||||
self.rew = 0.0
|
||||
|
||||
def reset(self, box_pos):
|
||||
super().reset(box_pos)
|
||||
self.episode_index += 1
|
||||
self.step_index = 0
|
||||
|
||||
|
||||
class _FakeRenderer:
|
||||
def __init__(self, env):
|
||||
self._env = env
|
||||
self._frames = [
|
||||
np.full((4, 4, 3), fill_value=index, dtype=np.uint8)
|
||||
for index in range(5)
|
||||
]
|
||||
self._index = 0
|
||||
|
||||
def update_scene(self, _mj_data, camera=None):
|
||||
self._camera = camera
|
||||
|
||||
def render(self):
|
||||
frame = self._frames[self._index]
|
||||
self._index += 1
|
||||
if self._index >= len(self._frames):
|
||||
self._env.exit_flag = True
|
||||
return frame
|
||||
|
||||
|
||||
class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
def test_eval_config_exposes_headless_default(self):
|
||||
eval_cfg = OmegaConf.load(Path("roboimi/vla/conf/eval/eval.yaml"))
|
||||
|
||||
self.assertIn("headless", eval_cfg)
|
||||
self.assertFalse(eval_cfg.headless)
|
||||
|
||||
def test_make_sim_env_accepts_headless_and_disables_render(self):
|
||||
fake_env = object()
|
||||
|
||||
with mock.patch(
|
||||
"roboimi.assets.robots.diana_med.BiDianaMed",
|
||||
return_value="robot",
|
||||
), mock.patch(
|
||||
"roboimi.envs.double_pos_ctrl_env.DualDianaMed_Pos_Ctrl",
|
||||
return_value=fake_env,
|
||||
) as env_cls:
|
||||
env = make_sim_env("sim_transfer", headless=True)
|
||||
|
||||
self.assertIs(env, fake_env)
|
||||
env_cls.assert_called_once_with(
|
||||
robot="robot",
|
||||
is_render=False,
|
||||
control_freq=30,
|
||||
is_interpolate=True,
|
||||
cam_view="angle",
|
||||
)
|
||||
|
||||
def test_camera_viewer_headless_updates_images_without_gui_calls(self):
|
||||
env = DualDianaMed.__new__(DualDianaMed)
|
||||
env.mj_model = object()
|
||||
env.mj_data = object()
|
||||
env.exit_flag = False
|
||||
env.is_render = False
|
||||
env.cam = "angle"
|
||||
env.r_vis = None
|
||||
env.l_vis = None
|
||||
env.top = None
|
||||
env.angle = None
|
||||
env.front = None
|
||||
|
||||
with mock.patch(
|
||||
"roboimi.envs.double_base.mj.Renderer",
|
||||
side_effect=lambda *args, **kwargs: _FakeRenderer(env),
|
||||
), mock.patch("roboimi.envs.double_base.cv2.namedWindow") as named_window, mock.patch(
|
||||
"roboimi.envs.double_base.cv2.imshow"
|
||||
) as imshow, mock.patch("roboimi.envs.double_base.cv2.waitKey") as wait_key:
|
||||
env.camera_viewer()
|
||||
|
||||
named_window.assert_not_called()
|
||||
imshow.assert_not_called()
|
||||
wait_key.assert_not_called()
|
||||
self.assertIsNotNone(env.r_vis)
|
||||
self.assertIsNotNone(env.l_vis)
|
||||
self.assertIsNotNone(env.top)
|
||||
self.assertIsNotNone(env.angle)
|
||||
self.assertIsNotNone(env.front)
|
||||
|
||||
def test_eval_main_headless_skips_render_and_still_executes_policy(self):
|
||||
fake_env = _FakeEnv()
|
||||
fake_agent = _FakeAgent()
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
"agent": {},
|
||||
"eval": {
|
||||
"ckpt_path": "checkpoints/vla_model_best.pt",
|
||||
"num_episodes": 1,
|
||||
"max_timesteps": 1,
|
||||
"device": "cpu",
|
||||
"task_name": "sim_transfer",
|
||||
"camera_names": ["front"],
|
||||
"use_smoothing": False,
|
||||
"smooth_alpha": 0.3,
|
||||
"verbose_action": False,
|
||||
"headless": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
eval_vla,
|
||||
"load_checkpoint",
|
||||
return_value=(fake_agent, None),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"make_sim_env",
|
||||
return_value=fake_env,
|
||||
) as make_env, mock.patch.object(
|
||||
eval_vla,
|
||||
"sample_transfer_pose",
|
||||
return_value=np.array([0.1, 0.2, 0.3]),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"execute_policy_action",
|
||||
) as execute_policy_action, mock.patch.object(
|
||||
eval_vla,
|
||||
"tqdm",
|
||||
side_effect=lambda iterable, **kwargs: iterable,
|
||||
):
|
||||
eval_vla.main.__wrapped__(cfg)
|
||||
|
||||
make_env.assert_called_once_with("sim_transfer", headless=True)
|
||||
execute_policy_action.assert_called_once()
|
||||
self.assertEqual(fake_env.image_obs_calls, 1)
|
||||
self.assertEqual(fake_env.render_calls, 0)
|
||||
self.assertIsNotNone(fake_agent.last_observation)
|
||||
self.assertIn("front", fake_agent.last_observation["images"])
|
||||
|
||||
def test_run_eval_returns_average_reward_summary(self):
|
||||
reward_sequences = [
|
||||
[1.0, 2.0],
|
||||
[0.5, 4.0],
|
||||
]
|
||||
fake_env = _RewardTrackingEnv(reward_sequences)
|
||||
fake_agent = _FakeAgent()
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
"agent": {},
|
||||
"eval": {
|
||||
"ckpt_path": "checkpoints/vla_model_best.pt",
|
||||
"num_episodes": 2,
|
||||
"max_timesteps": 2,
|
||||
"device": "cpu",
|
||||
"task_name": "sim_transfer",
|
||||
"camera_names": ["front"],
|
||||
"use_smoothing": False,
|
||||
"smooth_alpha": 0.3,
|
||||
"verbose_action": False,
|
||||
"headless": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
def fake_execute_policy_action(env, action):
|
||||
del action
|
||||
env.rew = env.reward_sequences[env.episode_index][env.step_index]
|
||||
env.step_index += 1
|
||||
|
||||
with mock.patch.object(
|
||||
eval_vla,
|
||||
"load_checkpoint",
|
||||
return_value=(fake_agent, None),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"make_sim_env",
|
||||
return_value=fake_env,
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"sample_transfer_pose",
|
||||
return_value=np.array([0.1, 0.2, 0.3]),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"execute_policy_action",
|
||||
side_effect=fake_execute_policy_action,
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"tqdm",
|
||||
side_effect=lambda iterable, **kwargs: iterable,
|
||||
):
|
||||
summary = eval_vla._run_eval(cfg)
|
||||
|
||||
self.assertEqual(summary["episode_rewards"], [3.0, 4.5])
|
||||
self.assertAlmostEqual(summary["avg_reward"], 3.75)
|
||||
self.assertEqual(summary["num_episodes"], 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user