feat(env): register sim air insert ring bar task

This commit is contained in:
Logic
2026-04-23 17:05:33 +08:00
parent 3eb1a83940
commit fce6839daa
9 changed files with 311 additions and 7 deletions

View File

@@ -36,8 +36,8 @@ class _FakeEnv:
self.render_calls = 0
self.reset_calls = []
def reset(self, box_pos):
self.reset_calls.append(np.array(box_pos))
def reset(self, task_state):
self.reset_calls.append(task_state)
def _get_image_obs(self):
self.image_obs_calls += 1
@@ -254,6 +254,69 @@ class EvalVLAHeadlessTest(unittest.TestCase):
self.assertAlmostEqual(summary["avg_reward"], 3.75)
self.assertEqual(summary["num_episodes"], 2)
def test_run_eval_uses_air_insert_sampler_for_ring_bar_task(self):
self.assertTrue(
hasattr(eval_vla, "sample_air_insert_ring_bar_state"),
"Expected eval_vla to expose the new ring/bar reset sampler",
)
fake_env = _FakeEnv()
fake_agent = _FakeAgent()
sampled_task_state = {
"ring_pos": np.array([-0.10, 0.80, 0.47], dtype=np.float32),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"bar_pos": np.array([0.10, 0.82, 0.47], dtype=np.float32),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
cfg = OmegaConf.create(
{
"agent": {},
"eval": {
"ckpt_path": "checkpoints/vla_model_best.pt",
"num_episodes": 1,
"max_timesteps": 1,
"device": "cpu",
"task_name": "sim_air_insert_ring_bar",
"camera_names": ["front"],
"use_smoothing": False,
"smooth_alpha": 0.3,
"verbose_action": False,
"headless": True,
},
}
)
with mock.patch.object(
eval_vla,
"load_checkpoint",
return_value=(fake_agent, None),
), mock.patch.object(
eval_vla,
"make_sim_env",
return_value=fake_env,
) as make_env, mock.patch.object(
eval_vla,
"sample_air_insert_ring_bar_state",
return_value=sampled_task_state,
) as ring_bar_sampler, mock.patch.object(
eval_vla,
"sample_transfer_pose",
side_effect=AssertionError("sample_transfer_pose should not be used for sim_air_insert_ring_bar"),
), mock.patch.object(
eval_vla,
"execute_policy_action",
) as execute_policy_action, mock.patch.object(
eval_vla,
"tqdm",
side_effect=lambda iterable, **kwargs: iterable,
):
eval_vla._run_eval(cfg)
make_env.assert_called_once_with("sim_air_insert_ring_bar", headless=True)
ring_bar_sampler.assert_called_once_with()
execute_policy_action.assert_called_once()
self.assertEqual(fake_env.reset_calls, [sampled_task_state])
if __name__ == "__main__":
unittest.main()