feat(env): register sim air insert ring bar task

This commit is contained in:
Logic
2026-04-23 17:05:33 +08:00
parent 3eb1a83940
commit fce6839daa
9 changed files with 311 additions and 7 deletions

View File

@@ -0,0 +1,101 @@
import importlib
import unittest
from unittest import mock
import numpy as np
from roboimi.envs.double_pos_ctrl_env import make_sim_env
from roboimi.utils import act_ex_utils
from roboimi.utils.constants import SIM_TASK_CONFIGS
class AirInsertTaskRegistrationTest(unittest.TestCase):
def test_sim_task_configs_registers_air_insert_ring_bar(self):
self.assertIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS)
def test_sample_air_insert_ring_bar_state_returns_explicit_named_mapping(self):
sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None)
self.assertIsNotNone(
sampler,
"Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()",
)
task_state = sampler()
self.assertEqual(
list(task_state.keys()),
["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
)
self.assertEqual(task_state["ring_pos"].shape, (3,))
self.assertEqual(task_state["ring_quat"].shape, (4,))
self.assertEqual(task_state["bar_pos"].shape, (3,))
self.assertEqual(task_state["bar_quat"].shape, (4,))
def test_sample_air_insert_ring_bar_state_uses_fixed_quats_and_left_right_planar_ranges(self):
sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None)
self.assertIsNotNone(
sampler,
"Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()",
)
task_state = sampler()
np.testing.assert_array_equal(task_state["ring_quat"], np.array([1.0, 0.0, 0.0, 0.0]))
np.testing.assert_array_equal(task_state["bar_quat"], np.array([1.0, 0.0, 0.0, 0.0]))
self.assertGreaterEqual(task_state["ring_pos"][0], -0.20)
self.assertLessEqual(task_state["ring_pos"][0], -0.05)
self.assertGreaterEqual(task_state["ring_pos"][1], 0.70)
self.assertLessEqual(task_state["ring_pos"][1], 1.00)
self.assertAlmostEqual(float(task_state["ring_pos"][2]), 0.47)
self.assertGreaterEqual(task_state["bar_pos"][0], 0.05)
self.assertLessEqual(task_state["bar_pos"][0], 0.20)
self.assertGreaterEqual(task_state["bar_pos"][1], 0.70)
self.assertLessEqual(task_state["bar_pos"][1], 1.00)
self.assertAlmostEqual(float(task_state["bar_pos"][2]), 0.47)
def test_make_sim_env_dispatches_air_insert_ring_bar_headless(self):
try:
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
except Exception as exc:
self.fail(f"Expected roboimi.envs.double_air_insert_env to be importable: {exc}")
air_insert_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None)
self.assertIsNotNone(
air_insert_cls,
"Expected roboimi.envs.double_air_insert_env.DualDianaMed_Air_Insert",
)
diana_med = importlib.import_module("roboimi.assets.robots.diana_med")
ring_bar_robot_cls = getattr(diana_med, "BiDianaMedRingBar", None)
self.assertIsNotNone(
ring_bar_robot_cls,
"Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar",
)
fake_env = object()
with mock.patch.object(
diana_med,
"BiDianaMedRingBar",
return_value="robot",
), mock.patch.object(
air_insert_env,
"DualDianaMed_Air_Insert",
return_value=fake_env,
) as env_cls:
try:
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
except Exception as exc:
self.fail(f"make_sim_env should dispatch sim_air_insert_ring_bar without error: {exc}")
self.assertIs(env, fake_env)
env_cls.assert_called_once_with(
robot="robot",
is_render=False,
control_freq=30,
is_interpolate=True,
cam_view="angle",
)
if __name__ == "__main__":
unittest.main()

View File

@@ -36,8 +36,8 @@ class _FakeEnv:
self.render_calls = 0
self.reset_calls = []
def reset(self, box_pos):
self.reset_calls.append(np.array(box_pos))
def reset(self, task_state):
self.reset_calls.append(task_state)
def _get_image_obs(self):
self.image_obs_calls += 1
@@ -254,6 +254,69 @@ class EvalVLAHeadlessTest(unittest.TestCase):
self.assertAlmostEqual(summary["avg_reward"], 3.75)
self.assertEqual(summary["num_episodes"], 2)
def test_run_eval_uses_air_insert_sampler_for_ring_bar_task(self):
self.assertTrue(
hasattr(eval_vla, "sample_air_insert_ring_bar_state"),
"Expected eval_vla to expose the new ring/bar reset sampler",
)
fake_env = _FakeEnv()
fake_agent = _FakeAgent()
sampled_task_state = {
"ring_pos": np.array([-0.10, 0.80, 0.47], dtype=np.float32),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"bar_pos": np.array([0.10, 0.82, 0.47], dtype=np.float32),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
cfg = OmegaConf.create(
{
"agent": {},
"eval": {
"ckpt_path": "checkpoints/vla_model_best.pt",
"num_episodes": 1,
"max_timesteps": 1,
"device": "cpu",
"task_name": "sim_air_insert_ring_bar",
"camera_names": ["front"],
"use_smoothing": False,
"smooth_alpha": 0.3,
"verbose_action": False,
"headless": True,
},
}
)
with mock.patch.object(
eval_vla,
"load_checkpoint",
return_value=(fake_agent, None),
), mock.patch.object(
eval_vla,
"make_sim_env",
return_value=fake_env,
) as make_env, mock.patch.object(
eval_vla,
"sample_air_insert_ring_bar_state",
return_value=sampled_task_state,
) as ring_bar_sampler, mock.patch.object(
eval_vla,
"sample_transfer_pose",
side_effect=AssertionError("sample_transfer_pose should not be used for sim_air_insert_ring_bar"),
), mock.patch.object(
eval_vla,
"execute_policy_action",
) as execute_policy_action, mock.patch.object(
eval_vla,
"tqdm",
side_effect=lambda iterable, **kwargs: iterable,
):
eval_vla._run_eval(cfg)
make_env.assert_called_once_with("sim_air_insert_ring_bar", headless=True)
ring_bar_sampler.assert_called_once_with()
execute_policy_action.assert_called_once()
self.assertEqual(fake_env.reset_calls, [sampled_task_state])
if __name__ == "__main__":
unittest.main()

View File

@@ -4,7 +4,7 @@ import unittest
from pathlib import Path
from unittest import mock
from roboimi.assets.robots.diana_med import BiDianaMed
from roboimi.assets.robots import diana_med
class _FakeKDL:
@@ -24,6 +24,7 @@ class RobotAssetPathResolutionTest(unittest.TestCase):
_FakeKDL.reset_calls = []
def test_bidianamed_resolves_robot_asset_paths_independent_of_cwd(self):
BiDianaMed = diana_med.BiDianaMed
repo_root = Path(__file__).resolve().parents[1]
expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_transfer_ee.xml'
expected_urdf = repo_root / 'roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf'
@@ -58,6 +59,47 @@ class RobotAssetPathResolutionTest(unittest.TestCase):
self.assertEqual({Path(path) for path in _FakeKDL.init_calls}, {expected_urdf})
self.assertTrue(all(Path(path).is_absolute() for path in _FakeKDL.init_calls))
def test_bidianamed_ring_bar_resolves_robot_asset_paths_independent_of_cwd(self):
BiDianaMedRingBar = getattr(diana_med, 'BiDianaMedRingBar', None)
self.assertIsNotNone(
BiDianaMedRingBar,
'Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar',
)
repo_root = Path(__file__).resolve().parents[1]
expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml'
expected_urdf = repo_root / 'roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf'
xml_calls = []
def fake_from_xml_path(*, filename, assets=None):
xml_calls.append((filename, assets))
return object()
with tempfile.TemporaryDirectory() as tempdir:
previous_cwd = os.getcwd()
try:
os.chdir(tempdir)
with mock.patch(
'roboimi.assets.robots.arm_base.mujoco.MjModel.from_xml_path',
side_effect=fake_from_xml_path,
), mock.patch(
'roboimi.assets.robots.arm_base.mujoco.MjData',
return_value=object(),
), mock.patch(
'roboimi.assets.robots.arm_base.KDL_utils',
_FakeKDL,
):
BiDianaMedRingBar()
finally:
os.chdir(previous_cwd)
self.assertEqual(len(xml_calls), 1)
self.assertEqual(Path(xml_calls[0][0]), expected_xml)
self.assertTrue(Path(xml_calls[0][0]).is_absolute())
self.assertGreaterEqual(len(_FakeKDL.init_calls), 2)
self.assertEqual({Path(path) for path in _FakeKDL.init_calls}, {expected_urdf})
self.assertTrue(all(Path(path).is_absolute() for path in _FakeKDL.init_calls))
if __name__ == '__main__':
unittest.main()