feat(eval): export rollout video timing and ee trajectory
This commit is contained in:
228
tests/test_eval_vla_rollout_artifacts.py
Normal file
228
tests/test_eval_vla_rollout_artifacts.py
Normal file
@@ -0,0 +1,228 @@
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from roboimi.demos.vla_scripts import eval_vla
|
||||
|
||||
|
||||
class _FakeAgent:
|
||||
def __init__(self, actions):
|
||||
self._actions = [torch.tensor(action, dtype=torch.float32) for action in actions]
|
||||
self.reset_calls = 0
|
||||
|
||||
def eval(self):
|
||||
return self
|
||||
|
||||
def to(self, _device):
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
self.reset_calls += 1
|
||||
|
||||
def select_action(self, observation):
|
||||
del observation
|
||||
return self._actions.pop(0)
|
||||
|
||||
|
||||
class _FakeEnv:
|
||||
def __init__(self):
|
||||
self.step_count = 0
|
||||
self.rew = 0.0
|
||||
self.render_calls = 0
|
||||
self.reset_calls = []
|
||||
|
||||
def reset(self, box_pos):
|
||||
self.reset_calls.append(np.array(box_pos, copy=True))
|
||||
self.step_count = 0
|
||||
self.rew = 0.0
|
||||
|
||||
def _get_image_obs(self):
|
||||
frame_value = self.step_count
|
||||
front = np.full((6, 8, 3), fill_value=frame_value, dtype=np.uint8)
|
||||
top = np.full((6, 8, 3), fill_value=frame_value + 20, dtype=np.uint8)
|
||||
return {"images": {"front": front, "top": top}}
|
||||
|
||||
def _get_qpos_obs(self):
|
||||
return {"qpos": np.arange(16, dtype=np.float32)}
|
||||
|
||||
def step(self, action):
|
||||
del action
|
||||
self.step_count += 1
|
||||
self.rew = float(self.step_count)
|
||||
|
||||
def render(self):
|
||||
self.render_calls += 1
|
||||
|
||||
def getBodyPos(self, name):
|
||||
base = float(self.step_count)
|
||||
if name == 'eef_left':
|
||||
return np.array([base, base + 0.1, base + 0.2], dtype=np.float32)
|
||||
if name == 'eef_right':
|
||||
return np.array([base + 1.0, base + 1.1, base + 1.2], dtype=np.float32)
|
||||
raise KeyError(name)
|
||||
|
||||
def getBodyQuat(self, name):
|
||||
base = float(self.step_count)
|
||||
if name == 'eef_left':
|
||||
return np.array([1.0, base, 0.0, 0.0], dtype=np.float32)
|
||||
if name == 'eef_right':
|
||||
return np.array([1.0, 0.0, base, 0.0], dtype=np.float32)
|
||||
raise KeyError(name)
|
||||
|
||||
|
||||
class _FakeVideoWriter:
|
||||
def __init__(self, output_path):
|
||||
self.output_path = Path(output_path)
|
||||
self.output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.output_path.write_bytes(b'')
|
||||
self.frames = []
|
||||
self.released = False
|
||||
|
||||
def isOpened(self):
|
||||
return True
|
||||
|
||||
def write(self, frame):
|
||||
self.frames.append(np.array(frame, copy=True))
|
||||
|
||||
def release(self):
|
||||
self.released = True
|
||||
self.output_path.write_bytes(b'fake-mp4')
|
||||
|
||||
|
||||
class EvalVLARolloutArtifactsTest(unittest.TestCase):
|
||||
def test_eval_config_exposes_rollout_artifact_defaults(self):
|
||||
eval_cfg = OmegaConf.load(Path('roboimi/vla/conf/eval/eval.yaml'))
|
||||
|
||||
self.assertIn('artifact_dir', eval_cfg)
|
||||
self.assertFalse(eval_cfg.save_summary_json)
|
||||
self.assertFalse(eval_cfg.save_trajectory_npz)
|
||||
self.assertFalse(eval_cfg.record_video)
|
||||
self.assertIsNone(eval_cfg.artifact_dir)
|
||||
self.assertIsNone(eval_cfg.video_camera_name)
|
||||
self.assertEqual(eval_cfg.video_fps, 30)
|
||||
|
||||
def test_run_eval_exports_npz_summary_and_video_artifacts(self):
|
||||
actions = [
|
||||
np.arange(16, dtype=np.float32),
|
||||
np.arange(16, dtype=np.float32) + 10.0,
|
||||
]
|
||||
fake_agent = _FakeAgent(actions)
|
||||
fake_env = _FakeEnv()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
'agent': {},
|
||||
'eval': {
|
||||
'ckpt_path': 'checkpoints/vla_model_best.pt',
|
||||
'num_episodes': 1,
|
||||
'max_timesteps': 2,
|
||||
'device': 'cpu',
|
||||
'task_name': 'sim_transfer',
|
||||
'camera_names': ['front', 'top'],
|
||||
'use_smoothing': True,
|
||||
'smooth_alpha': 0.5,
|
||||
'verbose_action': False,
|
||||
'headless': True,
|
||||
'artifact_dir': tmpdir,
|
||||
'save_summary_json': True,
|
||||
'save_trajectory_npz': True,
|
||||
'record_video': True,
|
||||
'video_camera_name': 'front',
|
||||
'video_fps': 12,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
writer_holder = {}
|
||||
|
||||
def fake_open_video_writer(output_path, frame_size, fps):
|
||||
self.assertEqual(frame_size, (8, 6))
|
||||
self.assertEqual(fps, 12)
|
||||
writer = _FakeVideoWriter(output_path)
|
||||
writer_holder['writer'] = writer
|
||||
return writer
|
||||
|
||||
with mock.patch.object(
|
||||
eval_vla,
|
||||
'load_checkpoint',
|
||||
return_value=(fake_agent, None),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
'make_sim_env',
|
||||
return_value=fake_env,
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
'sample_transfer_pose',
|
||||
return_value=np.array([0.1, 0.2, 0.3], dtype=np.float32),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
'tqdm',
|
||||
side_effect=lambda iterable, **kwargs: iterable,
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
'_open_video_writer',
|
||||
side_effect=fake_open_video_writer,
|
||||
):
|
||||
summary = eval_vla._run_eval(cfg)
|
||||
|
||||
artifacts = summary['artifacts']
|
||||
trajectory_path = Path(artifacts['trajectory_npz'])
|
||||
summary_path = Path(artifacts['summary_json'])
|
||||
video_path = Path(artifacts['video_mp4'])
|
||||
|
||||
self.assertEqual(Path(artifacts['output_dir']), Path(tmpdir))
|
||||
self.assertEqual(artifacts['video_camera_name'], 'front')
|
||||
self.assertTrue(trajectory_path.exists())
|
||||
self.assertTrue(summary_path.exists())
|
||||
self.assertTrue(video_path.exists())
|
||||
|
||||
rollout_npz = np.load(trajectory_path)
|
||||
np.testing.assert_array_equal(rollout_npz['episode_index'], np.array([0, 0]))
|
||||
np.testing.assert_array_equal(rollout_npz['timestep'], np.array([0, 1]))
|
||||
np.testing.assert_array_equal(rollout_npz['reward'], np.array([1.0, 2.0], dtype=np.float32))
|
||||
np.testing.assert_array_equal(rollout_npz['raw_predicted_ee_action'][0], actions[0])
|
||||
np.testing.assert_array_equal(rollout_npz['raw_predicted_ee_action'][1], actions[1])
|
||||
np.testing.assert_array_equal(rollout_npz['executed_ee_action'][0], actions[0])
|
||||
np.testing.assert_array_equal(
|
||||
rollout_npz['executed_ee_action'][1],
|
||||
(actions[0] + actions[1]) / 2.0,
|
||||
)
|
||||
np.testing.assert_array_equal(
|
||||
rollout_npz['left_ee_pos'],
|
||||
np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float32),
|
||||
)
|
||||
np.testing.assert_array_equal(
|
||||
rollout_npz['right_ee_pos'],
|
||||
np.array([[2.0, 2.1, 2.2], [3.0, 3.1, 3.2]], dtype=np.float32),
|
||||
)
|
||||
self.assertEqual(rollout_npz['obs_read_time_ms'].shape, (2,))
|
||||
self.assertEqual(rollout_npz['preprocess_time_ms'].shape, (2,))
|
||||
self.assertEqual(rollout_npz['inference_time_ms'].shape, (2,))
|
||||
self.assertEqual(rollout_npz['env_step_time_ms'].shape, (2,))
|
||||
self.assertEqual(rollout_npz['total_time_ms'].shape, (2,))
|
||||
|
||||
writer = writer_holder['writer']
|
||||
self.assertTrue(writer.released)
|
||||
self.assertEqual(len(writer.frames), 2)
|
||||
np.testing.assert_array_equal(writer.frames[0], np.zeros((6, 8, 3), dtype=np.uint8))
|
||||
np.testing.assert_array_equal(writer.frames[1], np.full((6, 8, 3), 1, dtype=np.uint8))
|
||||
|
||||
with summary_path.open('r', encoding='utf-8') as fh:
|
||||
saved_summary = json.load(fh)
|
||||
self.assertEqual(saved_summary['artifacts']['trajectory_npz'], str(trajectory_path))
|
||||
self.assertEqual(saved_summary['artifacts']['video_mp4'], str(video_path))
|
||||
self.assertEqual(saved_summary['episode_rewards'], [3.0])
|
||||
self.assertAlmostEqual(summary['avg_reward'], 3.0)
|
||||
self.assertIn('avg_obs_read_time_ms', summary)
|
||||
self.assertIn('avg_env_step_time_ms', summary)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user