feat(env): register sim air insert ring bar task
This commit is contained in:
@@ -36,8 +36,8 @@ class _FakeEnv:
|
||||
self.render_calls = 0
|
||||
self.reset_calls = []
|
||||
|
||||
def reset(self, box_pos):
|
||||
self.reset_calls.append(np.array(box_pos))
|
||||
def reset(self, task_state):
|
||||
self.reset_calls.append(task_state)
|
||||
|
||||
def _get_image_obs(self):
|
||||
self.image_obs_calls += 1
|
||||
@@ -254,6 +254,69 @@ class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
self.assertAlmostEqual(summary["avg_reward"], 3.75)
|
||||
self.assertEqual(summary["num_episodes"], 2)
|
||||
|
||||
def test_run_eval_uses_air_insert_sampler_for_ring_bar_task(self):
|
||||
self.assertTrue(
|
||||
hasattr(eval_vla, "sample_air_insert_ring_bar_state"),
|
||||
"Expected eval_vla to expose the new ring/bar reset sampler",
|
||||
)
|
||||
|
||||
fake_env = _FakeEnv()
|
||||
fake_agent = _FakeAgent()
|
||||
sampled_task_state = {
|
||||
"ring_pos": np.array([-0.10, 0.80, 0.47], dtype=np.float32),
|
||||
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
"bar_pos": np.array([0.10, 0.82, 0.47], dtype=np.float32),
|
||||
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
}
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
"agent": {},
|
||||
"eval": {
|
||||
"ckpt_path": "checkpoints/vla_model_best.pt",
|
||||
"num_episodes": 1,
|
||||
"max_timesteps": 1,
|
||||
"device": "cpu",
|
||||
"task_name": "sim_air_insert_ring_bar",
|
||||
"camera_names": ["front"],
|
||||
"use_smoothing": False,
|
||||
"smooth_alpha": 0.3,
|
||||
"verbose_action": False,
|
||||
"headless": True,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
with mock.patch.object(
|
||||
eval_vla,
|
||||
"load_checkpoint",
|
||||
return_value=(fake_agent, None),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"make_sim_env",
|
||||
return_value=fake_env,
|
||||
) as make_env, mock.patch.object(
|
||||
eval_vla,
|
||||
"sample_air_insert_ring_bar_state",
|
||||
return_value=sampled_task_state,
|
||||
) as ring_bar_sampler, mock.patch.object(
|
||||
eval_vla,
|
||||
"sample_transfer_pose",
|
||||
side_effect=AssertionError("sample_transfer_pose should not be used for sim_air_insert_ring_bar"),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"execute_policy_action",
|
||||
) as execute_policy_action, mock.patch.object(
|
||||
eval_vla,
|
||||
"tqdm",
|
||||
side_effect=lambda iterable, **kwargs: iterable,
|
||||
):
|
||||
eval_vla._run_eval(cfg)
|
||||
|
||||
make_env.assert_called_once_with("sim_air_insert_ring_bar", headless=True)
|
||||
ring_bar_sampler.assert_called_once_with()
|
||||
execute_policy_action.assert_called_once()
|
||||
self.assertEqual(fake_env.reset_calls, [sampled_task_state])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user