From fce6839daa3d5bf22528525ec63af4851fa25db9 Mon Sep 17 00:00:00 2001 From: Logic Date: Thu, 23 Apr 2026 17:05:33 +0800 Subject: [PATCH] feat(env): register sim air insert ring bar task --- roboimi/assets/robots/diana_med.py | 36 +++++++++ roboimi/demos/vla_scripts/eval_vla.py | 17 ++++- roboimi/envs/double_air_insert_env.py | 13 ++++ roboimi/envs/double_pos_ctrl_env.py | 12 +++ roboimi/utils/act_ex_utils.py | 21 +++++- roboimi/utils/constants.py | 7 ++ tests/test_air_insert_env.py | 101 ++++++++++++++++++++++++++ tests/test_eval_vla_headless.py | 67 ++++++++++++++++- tests/test_robot_asset_paths.py | 44 ++++++++++- 9 files changed, 311 insertions(+), 7 deletions(-) create mode 100644 roboimi/envs/double_air_insert_env.py create mode 100644 tests/test_air_insert_env.py diff --git a/roboimi/assets/robots/diana_med.py b/roboimi/assets/robots/diana_med.py index 0c26ca0..04ff249 100644 --- a/roboimi/assets/robots/diana_med.py +++ b/roboimi/assets/robots/diana_med.py @@ -90,4 +90,40 @@ class BiDianaMed(ArmBase): def init_qpos(self): """ Robot's init joint position. """ return np.array([0.0, 0.0, 0.0, 1.57, 0.0, 0.0, 0.0]) + + +class BiDianaMedRingBar(ArmBase): + def __init__(self): + super().__init__( + name="Bidiana_ring_bar", + urdf_path="roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf", + xml_path="roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml", + gripper=None + ) + self.left_arm = self.Arm(self, 'single', self.urdf_path) + self.left_arm.set_Arm_base_link('left_base_link') + self.left_arm.set_Arm_ee_link('left_link7') + self.left_arm.InitKDL + self.left_arm.joint_index = ['l_j1','l_j2','l_j3','l_j4','l_j5','l_j6','l_j7'] + self.left_arm.gripper_index = ['l_finger_joint_left','r_finger_joint_left'] + self.left_arm.actuator_index = ['a1_l','a2_l','a3_l','a4_l','a5_l','a6_l','a7_l','gripper_left'] + self.left_arm.setArmInitPose(self.init_qpos) + self.arms.append(self.left_arm) + self.right_arm = self.Arm(self,'single', self.urdf_path) + self.right_arm.set_Arm_base_link('right_base_link') + self.right_arm.set_Arm_ee_link('right_link7') + self.right_arm.InitKDL + self.right_arm.joint_index = ['r_j1','r_j2','r_j3','r_j4','r_j5','r_j6','r_j7'] + self.right_arm.gripper_index = ['l_finger_joint_right','r_finger_joint_right'] + self.right_arm.actuator_index = ['a1_r','a2_r','a3_r','a4_r','a5_r','a6_r','a7_r','gripper_right'] + self.right_arm.setArmInitPose(self.init_qpos) + self.arms.append(self.right_arm) + self.jnt_num = self.left_arm.jnt_num + self.right_arm.jnt_num + self.kp = 500 * np.ones(self.jnt_num) + self.kd = 44.57 * np.ones(self.jnt_num) + + @property + def init_qpos(self): + """ Robot's init joint position. """ + return np.array([0.0, 0.0, 0.0, 1.57, 0.0, 0.0, 0.0]) diff --git a/roboimi/demos/vla_scripts/eval_vla.py b/roboimi/demos/vla_scripts/eval_vla.py index de7e7d7..265e36a 100644 --- a/roboimi/demos/vla_scripts/eval_vla.py +++ b/roboimi/demos/vla_scripts/eval_vla.py @@ -26,7 +26,10 @@ from hydra.utils import instantiate from einops import rearrange from roboimi.envs.double_pos_ctrl_env import make_sim_env -from roboimi.utils.act_ex_utils import sample_transfer_pose +from roboimi.utils.act_ex_utils import ( + sample_air_insert_ring_bar_state, + sample_transfer_pose, +) from roboimi.vla.eval_utils import execute_policy_action sys.path.append(os.getcwd()) @@ -485,6 +488,14 @@ def _close_env(env): viewer.close() +def _sample_task_reset_state(task_name: str): + if task_name == 'sim_air_insert_ring_bar': + return sample_air_insert_ring_bar_state() + if 'sim_transfer' in task_name: + return sample_transfer_pose() + raise NotImplementedError(f'Unsupported eval task reset sampling: {task_name}') + + def _run_eval(cfg: DictConfig): """ 使用 agent 内置队列管理的简化版 VLA 评估 @@ -549,8 +560,8 @@ def _run_eval(cfg: DictConfig): print(f"回合 {episode_idx + 1}/{eval_cfg.num_episodes}") print(f"{'='*60}\n") - box_pos = sample_transfer_pose() - env.reset(box_pos) + task_state = _sample_task_reset_state(str(eval_cfg.task_name)) + env.reset(task_state) # 为新回合重置 agent 队列 agent.reset() diff --git a/roboimi/envs/double_air_insert_env.py b/roboimi/envs/double_air_insert_env.py new file mode 100644 index 0000000..60c6364 --- /dev/null +++ b/roboimi/envs/double_air_insert_env.py @@ -0,0 +1,13 @@ +from roboimi.envs.double_pos_ctrl_env import DualDianaMed_Pos_Ctrl + + +class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl): + def reset(self, task_state): + required_keys = {"ring_pos", "ring_quat", "bar_pos", "bar_quat"} + if not isinstance(task_state, dict) or set(task_state.keys()) != required_keys: + raise ValueError( + "task_state must be a dict with ring_pos, ring_quat, bar_pos, and bar_quat" + ) + raise NotImplementedError( + "sim_air_insert_ring_bar reset wiring is intentionally deferred beyond Task 1" + ) diff --git a/roboimi/envs/double_pos_ctrl_env.py b/roboimi/envs/double_pos_ctrl_env.py index 78cb1a6..31e8c86 100644 --- a/roboimi/envs/double_pos_ctrl_env.py +++ b/roboimi/envs/double_pos_ctrl_env.py @@ -134,6 +134,18 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed): def make_sim_env(task_name, headless=False): + if task_name == 'sim_air_insert_ring_bar': + from roboimi.assets.robots.diana_med import BiDianaMedRingBar + from roboimi.envs.double_air_insert_env import DualDianaMed_Air_Insert + + env = DualDianaMed_Air_Insert( + robot=BiDianaMedRingBar(), + is_render=not headless, + control_freq=30, + is_interpolate=True, + cam_view='angle' + ) + return env if 'sim_transfer' in task_name: from roboimi.assets.robots.diana_med import BiDianaMed env = DualDianaMed_Pos_Ctrl( diff --git a/roboimi/utils/act_ex_utils.py b/roboimi/utils/act_ex_utils.py index 2682f5f..6afc0bb 100644 --- a/roboimi/utils/act_ex_utils.py +++ b/roboimi/utils/act_ex_utils.py @@ -1,5 +1,6 @@ import numpy as np + def sample_insertion_pose(): # Peg x_range = [0.1, 0.2] @@ -35,4 +36,22 @@ def sample_transfer_pose(): box_position = np.random.uniform(ranges[:, 0], ranges[:, 1]) - return box_position \ No newline at end of file + return box_position + + +def sample_air_insert_ring_bar_state(): + ring_position = np.random.uniform( + low=np.array([-0.20, 0.70, 0.47], dtype=np.float32), + high=np.array([-0.05, 1.00, 0.47], dtype=np.float32), + ) + bar_position = np.random.uniform( + low=np.array([0.05, 0.70, 0.47], dtype=np.float32), + high=np.array([0.20, 1.00, 0.47], dtype=np.float32), + ) + fixed_quat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) + return { + "ring_pos": ring_position.astype(np.float32, copy=False), + "ring_quat": fixed_quat.copy(), + "bar_pos": bar_position.astype(np.float32, copy=False), + "bar_quat": fixed_quat.copy(), + } diff --git a/roboimi/utils/constants.py b/roboimi/utils/constants.py index 2f0d41b..10158e7 100644 --- a/roboimi/utils/constants.py +++ b/roboimi/utils/constants.py @@ -23,6 +23,13 @@ SIM_TASK_CONFIGS = { 'camera_names': ['top','r_vis','front'], 'xml_dir': HOME_PATH + '/assets' }, + 'sim_air_insert_ring_bar': { + 'dataset_dir': DATASET_DIR + '/sim_air_insert_ring_bar', + 'num_episodes': 20, + 'episode_len': 700, + 'camera_names': ['top', 'r_vis', 'front'], + 'xml_dir': HOME_PATH + '/assets' + }, } diff --git a/tests/test_air_insert_env.py b/tests/test_air_insert_env.py new file mode 100644 index 0000000..99d7c42 --- /dev/null +++ b/tests/test_air_insert_env.py @@ -0,0 +1,101 @@ +import importlib +import unittest +from unittest import mock + +import numpy as np + +from roboimi.envs.double_pos_ctrl_env import make_sim_env +from roboimi.utils import act_ex_utils +from roboimi.utils.constants import SIM_TASK_CONFIGS + + +class AirInsertTaskRegistrationTest(unittest.TestCase): + def test_sim_task_configs_registers_air_insert_ring_bar(self): + self.assertIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS) + + def test_sample_air_insert_ring_bar_state_returns_explicit_named_mapping(self): + sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None) + self.assertIsNotNone( + sampler, + "Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()", + ) + + task_state = sampler() + + self.assertEqual( + list(task_state.keys()), + ["ring_pos", "ring_quat", "bar_pos", "bar_quat"], + ) + self.assertEqual(task_state["ring_pos"].shape, (3,)) + self.assertEqual(task_state["ring_quat"].shape, (4,)) + self.assertEqual(task_state["bar_pos"].shape, (3,)) + self.assertEqual(task_state["bar_quat"].shape, (4,)) + + def test_sample_air_insert_ring_bar_state_uses_fixed_quats_and_left_right_planar_ranges(self): + sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None) + self.assertIsNotNone( + sampler, + "Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()", + ) + + task_state = sampler() + + np.testing.assert_array_equal(task_state["ring_quat"], np.array([1.0, 0.0, 0.0, 0.0])) + np.testing.assert_array_equal(task_state["bar_quat"], np.array([1.0, 0.0, 0.0, 0.0])) + self.assertGreaterEqual(task_state["ring_pos"][0], -0.20) + self.assertLessEqual(task_state["ring_pos"][0], -0.05) + self.assertGreaterEqual(task_state["ring_pos"][1], 0.70) + self.assertLessEqual(task_state["ring_pos"][1], 1.00) + self.assertAlmostEqual(float(task_state["ring_pos"][2]), 0.47) + self.assertGreaterEqual(task_state["bar_pos"][0], 0.05) + self.assertLessEqual(task_state["bar_pos"][0], 0.20) + self.assertGreaterEqual(task_state["bar_pos"][1], 0.70) + self.assertLessEqual(task_state["bar_pos"][1], 1.00) + self.assertAlmostEqual(float(task_state["bar_pos"][2]), 0.47) + + def test_make_sim_env_dispatches_air_insert_ring_bar_headless(self): + try: + air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") + except Exception as exc: + self.fail(f"Expected roboimi.envs.double_air_insert_env to be importable: {exc}") + + air_insert_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None) + self.assertIsNotNone( + air_insert_cls, + "Expected roboimi.envs.double_air_insert_env.DualDianaMed_Air_Insert", + ) + + diana_med = importlib.import_module("roboimi.assets.robots.diana_med") + ring_bar_robot_cls = getattr(diana_med, "BiDianaMedRingBar", None) + self.assertIsNotNone( + ring_bar_robot_cls, + "Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar", + ) + + fake_env = object() + with mock.patch.object( + diana_med, + "BiDianaMedRingBar", + return_value="robot", + ), mock.patch.object( + air_insert_env, + "DualDianaMed_Air_Insert", + return_value=fake_env, + ) as env_cls: + try: + env = make_sim_env("sim_air_insert_ring_bar", headless=True) + except Exception as exc: + self.fail(f"make_sim_env should dispatch sim_air_insert_ring_bar without error: {exc}") + + self.assertIs(env, fake_env) + env_cls.assert_called_once_with( + robot="robot", + is_render=False, + control_freq=30, + is_interpolate=True, + cam_view="angle", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_eval_vla_headless.py b/tests/test_eval_vla_headless.py index e6f4abb..da11bd2 100644 --- a/tests/test_eval_vla_headless.py +++ b/tests/test_eval_vla_headless.py @@ -36,8 +36,8 @@ class _FakeEnv: self.render_calls = 0 self.reset_calls = [] - def reset(self, box_pos): - self.reset_calls.append(np.array(box_pos)) + def reset(self, task_state): + self.reset_calls.append(task_state) def _get_image_obs(self): self.image_obs_calls += 1 @@ -254,6 +254,69 @@ class EvalVLAHeadlessTest(unittest.TestCase): self.assertAlmostEqual(summary["avg_reward"], 3.75) self.assertEqual(summary["num_episodes"], 2) + def test_run_eval_uses_air_insert_sampler_for_ring_bar_task(self): + self.assertTrue( + hasattr(eval_vla, "sample_air_insert_ring_bar_state"), + "Expected eval_vla to expose the new ring/bar reset sampler", + ) + + fake_env = _FakeEnv() + fake_agent = _FakeAgent() + sampled_task_state = { + "ring_pos": np.array([-0.10, 0.80, 0.47], dtype=np.float32), + "ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + "bar_pos": np.array([0.10, 0.82, 0.47], dtype=np.float32), + "bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + } + cfg = OmegaConf.create( + { + "agent": {}, + "eval": { + "ckpt_path": "checkpoints/vla_model_best.pt", + "num_episodes": 1, + "max_timesteps": 1, + "device": "cpu", + "task_name": "sim_air_insert_ring_bar", + "camera_names": ["front"], + "use_smoothing": False, + "smooth_alpha": 0.3, + "verbose_action": False, + "headless": True, + }, + } + ) + + with mock.patch.object( + eval_vla, + "load_checkpoint", + return_value=(fake_agent, None), + ), mock.patch.object( + eval_vla, + "make_sim_env", + return_value=fake_env, + ) as make_env, mock.patch.object( + eval_vla, + "sample_air_insert_ring_bar_state", + return_value=sampled_task_state, + ) as ring_bar_sampler, mock.patch.object( + eval_vla, + "sample_transfer_pose", + side_effect=AssertionError("sample_transfer_pose should not be used for sim_air_insert_ring_bar"), + ), mock.patch.object( + eval_vla, + "execute_policy_action", + ) as execute_policy_action, mock.patch.object( + eval_vla, + "tqdm", + side_effect=lambda iterable, **kwargs: iterable, + ): + eval_vla._run_eval(cfg) + + make_env.assert_called_once_with("sim_air_insert_ring_bar", headless=True) + ring_bar_sampler.assert_called_once_with() + execute_policy_action.assert_called_once() + self.assertEqual(fake_env.reset_calls, [sampled_task_state]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_robot_asset_paths.py b/tests/test_robot_asset_paths.py index 8412192..0a1e5de 100644 --- a/tests/test_robot_asset_paths.py +++ b/tests/test_robot_asset_paths.py @@ -4,7 +4,7 @@ import unittest from pathlib import Path from unittest import mock -from roboimi.assets.robots.diana_med import BiDianaMed +from roboimi.assets.robots import diana_med class _FakeKDL: @@ -24,6 +24,7 @@ class RobotAssetPathResolutionTest(unittest.TestCase): _FakeKDL.reset_calls = [] def test_bidianamed_resolves_robot_asset_paths_independent_of_cwd(self): + BiDianaMed = diana_med.BiDianaMed repo_root = Path(__file__).resolve().parents[1] expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_transfer_ee.xml' expected_urdf = repo_root / 'roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf' @@ -58,6 +59,47 @@ class RobotAssetPathResolutionTest(unittest.TestCase): self.assertEqual({Path(path) for path in _FakeKDL.init_calls}, {expected_urdf}) self.assertTrue(all(Path(path).is_absolute() for path in _FakeKDL.init_calls)) + def test_bidianamed_ring_bar_resolves_robot_asset_paths_independent_of_cwd(self): + BiDianaMedRingBar = getattr(diana_med, 'BiDianaMedRingBar', None) + self.assertIsNotNone( + BiDianaMedRingBar, + 'Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar', + ) + + repo_root = Path(__file__).resolve().parents[1] + expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml' + expected_urdf = repo_root / 'roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf' + xml_calls = [] + + def fake_from_xml_path(*, filename, assets=None): + xml_calls.append((filename, assets)) + return object() + + with tempfile.TemporaryDirectory() as tempdir: + previous_cwd = os.getcwd() + try: + os.chdir(tempdir) + with mock.patch( + 'roboimi.assets.robots.arm_base.mujoco.MjModel.from_xml_path', + side_effect=fake_from_xml_path, + ), mock.patch( + 'roboimi.assets.robots.arm_base.mujoco.MjData', + return_value=object(), + ), mock.patch( + 'roboimi.assets.robots.arm_base.KDL_utils', + _FakeKDL, + ): + BiDianaMedRingBar() + finally: + os.chdir(previous_cwd) + + self.assertEqual(len(xml_calls), 1) + self.assertEqual(Path(xml_calls[0][0]), expected_xml) + self.assertTrue(Path(xml_calls[0][0]).is_absolute()) + self.assertGreaterEqual(len(_FakeKDL.init_calls), 2) + self.assertEqual({Path(path) for path in _FakeKDL.init_calls}, {expected_urdf}) + self.assertTrue(all(Path(path).is_absolute() for path in _FakeKDL.init_calls)) + if __name__ == '__main__': unittest.main()