diff --git a/roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml b/roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml
new file mode 100644
index 0000000..38c21f8
--- /dev/null
+++ b/roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
diff --git a/roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml b/roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml
new file mode 100644
index 0000000..0545799
--- /dev/null
+++ b/roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml
@@ -0,0 +1,28 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/roboimi/envs/double_air_insert_env.py b/roboimi/envs/double_air_insert_env.py
index 60c6364..63f489f 100644
--- a/roboimi/envs/double_air_insert_env.py
+++ b/roboimi/envs/double_air_insert_env.py
@@ -1,13 +1,71 @@
+import copy as cp
+import time
+
+import numpy as np
+
+from roboimi.envs.double_base import DualDianaMed
from roboimi.envs.double_pos_ctrl_env import DualDianaMed_Pos_Ctrl
+RING_JOINT_NAME = "ring_block_joint"
+BAR_JOINT_NAME = "bar_block_joint"
+REQUIRED_TASK_STATE_KEYS = ("ring_pos", "ring_quat", "bar_pos", "bar_quat")
+
+
+def _set_free_joint_pose(joint, position, quat):
+ joint.qpos[:3] = np.asarray(position, dtype=np.float64)
+ joint.qpos[3:7] = np.asarray(quat, dtype=np.float64)
+
+
+def set_ring_bar_task_state(mj_data, task_state):
+ if not isinstance(task_state, dict) or tuple(task_state.keys()) != REQUIRED_TASK_STATE_KEYS:
+ raise ValueError(
+ "task_state must be an ordered dict-like mapping with keys "
+ "ring_pos, ring_quat, bar_pos, bar_quat"
+ )
+
+ _set_free_joint_pose(
+ mj_data.joint(RING_JOINT_NAME),
+ task_state["ring_pos"],
+ task_state["ring_quat"],
+ )
+ _set_free_joint_pose(
+ mj_data.joint(BAR_JOINT_NAME),
+ task_state["bar_pos"],
+ task_state["bar_quat"],
+ )
+
+
+def get_ring_bar_env_state(mj_data):
+ ring_qpos = cp.deepcopy(np.asarray(mj_data.joint(RING_JOINT_NAME).qpos[:7], dtype=np.float64))
+ bar_qpos = cp.deepcopy(np.asarray(mj_data.joint(BAR_JOINT_NAME).qpos[:7], dtype=np.float64))
+ return np.concatenate([ring_qpos, bar_qpos], dtype=np.float64)
+
+
class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
def reset(self, task_state):
- required_keys = {"ring_pos", "ring_quat", "bar_pos", "bar_quat"}
- if not isinstance(task_state, dict) or set(task_state.keys()) != required_keys:
- raise ValueError(
- "task_state must be a dict with ring_pos, ring_quat, bar_pos, and bar_quat"
- )
+ set_ring_bar_task_state(self.mj_data, task_state)
+ DualDianaMed.reset(self)
+ self.top = None
+ self.angle = None
+ self.r_vis = None
+ self.front = None
+ self.cam_flage = True
+ while self.cam_flage:
+ if (
+ type(self.top) == type(None)
+ or type(self.angle) == type(None)
+ or type(self.r_vis) == type(None)
+ or type(self.front) == type(None)
+ ):
+ time.sleep(0.001)
+ else:
+ self.cam_flage = False
+
+ def get_env_state(self):
+ return get_ring_bar_env_state(self.mj_data)
+
+ def _get_reward(self):
raise NotImplementedError(
- "sim_air_insert_ring_bar reset wiring is intentionally deferred beyond Task 1"
+ "Task 2 wires reset/state only; reward logic is implemented in a later task."
)
diff --git a/tests/test_air_insert_env.py b/tests/test_air_insert_env.py
index 99d7c42..3f5237c 100644
--- a/tests/test_air_insert_env.py
+++ b/tests/test_air_insert_env.py
@@ -97,5 +97,82 @@ class AirInsertTaskRegistrationTest(unittest.TestCase):
)
+class AirInsertResetAndStateHelpersTest(unittest.TestCase):
+ def test_set_ring_bar_task_state_writes_free_joint_qpos(self):
+ air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
+ setter = getattr(air_insert_env, "set_ring_bar_task_state", None)
+ self.assertIsNotNone(
+ setter,
+ "Expected roboimi.envs.double_air_insert_env.set_ring_bar_task_state",
+ )
+
+ ring_qpos = np.zeros(7, dtype=np.float64)
+ bar_qpos = np.zeros(7, dtype=np.float64)
+
+ class _FakeJoint:
+ def __init__(self, qpos):
+ self.qpos = qpos
+
+ class _FakeData:
+ def joint(self, name):
+ if name == "ring_block_joint":
+ return _FakeJoint(ring_qpos)
+ if name == "bar_block_joint":
+ return _FakeJoint(bar_qpos)
+ raise AssertionError(f"Unexpected joint name: {name}")
+
+ task_state = {
+ "ring_pos": np.array([-0.12, 0.90, 0.47], dtype=np.float64),
+ "ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
+ "bar_pos": np.array([0.12, 0.91, 0.47], dtype=np.float64),
+ "bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
+ }
+
+ setter(_FakeData(), task_state)
+
+ np.testing.assert_array_equal(
+ ring_qpos,
+ np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
+ )
+ np.testing.assert_array_equal(
+ bar_qpos,
+ np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
+ )
+
+ def test_get_ring_bar_env_state_returns_stable_14d_vector(self):
+ air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
+ getter = getattr(air_insert_env, "get_ring_bar_env_state", None)
+ self.assertIsNotNone(
+ getter,
+ "Expected roboimi.envs.double_air_insert_env.get_ring_bar_env_state",
+ )
+
+ ring_qpos = np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
+ bar_qpos = np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
+
+ class _FakeJoint:
+ def __init__(self, qpos):
+ self.qpos = qpos
+
+ class _FakeData:
+ def joint(self, name):
+ if name == "ring_block_joint":
+ return _FakeJoint(ring_qpos)
+ if name == "bar_block_joint":
+ return _FakeJoint(bar_qpos)
+ raise AssertionError(f"Unexpected joint name: {name}")
+
+ env_state = getter(_FakeData())
+
+ self.assertEqual(env_state.shape, (14,))
+ np.testing.assert_array_equal(
+ env_state,
+ np.array(
+ [-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0],
+ dtype=np.float64,
+ ),
+ )
+
+
if __name__ == "__main__":
unittest.main()