From f1ede7690f79b7efb9b928fa32729d0f00896331 Mon Sep 17 00:00:00 2001 From: Logic Date: Thu, 23 Apr 2026 17:32:43 +0800 Subject: [PATCH] feat(scene): add ring and bar insertion scene assets --- .../DianaMed/bi_diana_ring_bar_ee.xml | 6 ++ .../DianaMed/ring_bar_objects.xml | 28 +++++++ roboimi/envs/double_air_insert_env.py | 70 +++++++++++++++-- tests/test_air_insert_env.py | 77 +++++++++++++++++++ 4 files changed, 175 insertions(+), 6 deletions(-) create mode 100644 roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml create mode 100644 roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml diff --git a/roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml b/roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml new file mode 100644 index 0000000..38c21f8 --- /dev/null +++ b/roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml b/roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml new file mode 100644 index 0000000..0545799 --- /dev/null +++ b/roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml @@ -0,0 +1,28 @@ + + + + + + + + + + + + + + + + + + diff --git a/roboimi/envs/double_air_insert_env.py b/roboimi/envs/double_air_insert_env.py index 60c6364..63f489f 100644 --- a/roboimi/envs/double_air_insert_env.py +++ b/roboimi/envs/double_air_insert_env.py @@ -1,13 +1,71 @@ +import copy as cp +import time + +import numpy as np + +from roboimi.envs.double_base import DualDianaMed from roboimi.envs.double_pos_ctrl_env import DualDianaMed_Pos_Ctrl +RING_JOINT_NAME = "ring_block_joint" +BAR_JOINT_NAME = "bar_block_joint" +REQUIRED_TASK_STATE_KEYS = ("ring_pos", "ring_quat", "bar_pos", "bar_quat") + + +def _set_free_joint_pose(joint, position, quat): + joint.qpos[:3] = np.asarray(position, dtype=np.float64) + joint.qpos[3:7] = np.asarray(quat, dtype=np.float64) + + +def set_ring_bar_task_state(mj_data, task_state): + if not isinstance(task_state, dict) or tuple(task_state.keys()) != REQUIRED_TASK_STATE_KEYS: + raise ValueError( + "task_state must be an ordered dict-like mapping with keys " + "ring_pos, ring_quat, bar_pos, bar_quat" + ) + + _set_free_joint_pose( + mj_data.joint(RING_JOINT_NAME), + task_state["ring_pos"], + task_state["ring_quat"], + ) + _set_free_joint_pose( + mj_data.joint(BAR_JOINT_NAME), + task_state["bar_pos"], + task_state["bar_quat"], + ) + + +def get_ring_bar_env_state(mj_data): + ring_qpos = cp.deepcopy(np.asarray(mj_data.joint(RING_JOINT_NAME).qpos[:7], dtype=np.float64)) + bar_qpos = cp.deepcopy(np.asarray(mj_data.joint(BAR_JOINT_NAME).qpos[:7], dtype=np.float64)) + return np.concatenate([ring_qpos, bar_qpos], dtype=np.float64) + + class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl): def reset(self, task_state): - required_keys = {"ring_pos", "ring_quat", "bar_pos", "bar_quat"} - if not isinstance(task_state, dict) or set(task_state.keys()) != required_keys: - raise ValueError( - "task_state must be a dict with ring_pos, ring_quat, bar_pos, and bar_quat" - ) + set_ring_bar_task_state(self.mj_data, task_state) + DualDianaMed.reset(self) + self.top = None + self.angle = None + self.r_vis = None + self.front = None + self.cam_flage = True + while self.cam_flage: + if ( + type(self.top) == type(None) + or type(self.angle) == type(None) + or type(self.r_vis) == type(None) + or type(self.front) == type(None) + ): + time.sleep(0.001) + else: + self.cam_flage = False + + def get_env_state(self): + return get_ring_bar_env_state(self.mj_data) + + def _get_reward(self): raise NotImplementedError( - "sim_air_insert_ring_bar reset wiring is intentionally deferred beyond Task 1" + "Task 2 wires reset/state only; reward logic is implemented in a later task." ) diff --git a/tests/test_air_insert_env.py b/tests/test_air_insert_env.py index 99d7c42..3f5237c 100644 --- a/tests/test_air_insert_env.py +++ b/tests/test_air_insert_env.py @@ -97,5 +97,82 @@ class AirInsertTaskRegistrationTest(unittest.TestCase): ) +class AirInsertResetAndStateHelpersTest(unittest.TestCase): + def test_set_ring_bar_task_state_writes_free_joint_qpos(self): + air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") + setter = getattr(air_insert_env, "set_ring_bar_task_state", None) + self.assertIsNotNone( + setter, + "Expected roboimi.envs.double_air_insert_env.set_ring_bar_task_state", + ) + + ring_qpos = np.zeros(7, dtype=np.float64) + bar_qpos = np.zeros(7, dtype=np.float64) + + class _FakeJoint: + def __init__(self, qpos): + self.qpos = qpos + + class _FakeData: + def joint(self, name): + if name == "ring_block_joint": + return _FakeJoint(ring_qpos) + if name == "bar_block_joint": + return _FakeJoint(bar_qpos) + raise AssertionError(f"Unexpected joint name: {name}") + + task_state = { + "ring_pos": np.array([-0.12, 0.90, 0.47], dtype=np.float64), + "ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64), + "bar_pos": np.array([0.12, 0.91, 0.47], dtype=np.float64), + "bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64), + } + + setter(_FakeData(), task_state) + + np.testing.assert_array_equal( + ring_qpos, + np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64), + ) + np.testing.assert_array_equal( + bar_qpos, + np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64), + ) + + def test_get_ring_bar_env_state_returns_stable_14d_vector(self): + air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") + getter = getattr(air_insert_env, "get_ring_bar_env_state", None) + self.assertIsNotNone( + getter, + "Expected roboimi.envs.double_air_insert_env.get_ring_bar_env_state", + ) + + ring_qpos = np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64) + bar_qpos = np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64) + + class _FakeJoint: + def __init__(self, qpos): + self.qpos = qpos + + class _FakeData: + def joint(self, name): + if name == "ring_block_joint": + return _FakeJoint(ring_qpos) + if name == "bar_block_joint": + return _FakeJoint(bar_qpos) + raise AssertionError(f"Unexpected joint name: {name}") + + env_state = getter(_FakeData()) + + self.assertEqual(env_state.shape, (14,)) + np.testing.assert_array_equal( + env_state, + np.array( + [-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], + dtype=np.float64, + ), + ) + + if __name__ == "__main__": unittest.main()