feat(scene): add ring and bar insertion scene assets
This commit is contained in:
@@ -0,0 +1,6 @@
|
|||||||
|
<mujoco model="bi_diana_ring_bar">
|
||||||
|
<include file="./empty_world.xml" />
|
||||||
|
<include file="./table_square.xml" />
|
||||||
|
<include file="./ring_bar_objects.xml" />
|
||||||
|
<include file="./BiDianaMed_rethink.xml" />
|
||||||
|
</mujoco>
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
<mujoco model="ring_bar_objects">
|
||||||
|
<worldbody>
|
||||||
|
<body name="ring_block" pos="-0.12 0.90 0.47">
|
||||||
|
<joint name="ring_block_joint" type="free" frictionloss="0.01" />
|
||||||
|
<inertial pos="0 0 0" mass="0.08" diaginertia="0.002 0.002 0.002" />
|
||||||
|
<geom name="ring_block_north" type="box" pos="0 0.025 0" size="0.034 0.009 0.009"
|
||||||
|
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||||
|
friction="1 0.005 0.0001" rgba="1 0 0 1" />
|
||||||
|
<geom name="ring_block_south" type="box" pos="0 -0.025 0" size="0.034 0.009 0.009"
|
||||||
|
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||||
|
friction="1 0.005 0.0001" rgba="1 0 0 1" />
|
||||||
|
<geom name="ring_block_east" type="box" pos="0.025 0 0" size="0.009 0.016 0.009"
|
||||||
|
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||||
|
friction="1 0.005 0.0001" rgba="1 0 0 1" />
|
||||||
|
<geom name="ring_block_west" type="box" pos="-0.025 0 0" size="0.009 0.016 0.009"
|
||||||
|
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||||
|
friction="1 0.005 0.0001" rgba="1 0 0 1" />
|
||||||
|
</body>
|
||||||
|
|
||||||
|
<body name="bar_block" pos="0.12 0.90 0.47">
|
||||||
|
<joint name="bar_block_joint" type="free" frictionloss="0.01" />
|
||||||
|
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||||
|
<geom name="bar_block" type="box" pos="0 0 0" size="0.045 0.009 0.009"
|
||||||
|
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||||
|
friction="1 0.005 0.0001" rgba="0 0.7 0.2 1" />
|
||||||
|
</body>
|
||||||
|
</worldbody>
|
||||||
|
</mujoco>
|
||||||
@@ -1,13 +1,71 @@
|
|||||||
|
import copy as cp
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from roboimi.envs.double_base import DualDianaMed
|
||||||
from roboimi.envs.double_pos_ctrl_env import DualDianaMed_Pos_Ctrl
|
from roboimi.envs.double_pos_ctrl_env import DualDianaMed_Pos_Ctrl
|
||||||
|
|
||||||
|
|
||||||
|
RING_JOINT_NAME = "ring_block_joint"
|
||||||
|
BAR_JOINT_NAME = "bar_block_joint"
|
||||||
|
REQUIRED_TASK_STATE_KEYS = ("ring_pos", "ring_quat", "bar_pos", "bar_quat")
|
||||||
|
|
||||||
|
|
||||||
|
def _set_free_joint_pose(joint, position, quat):
|
||||||
|
joint.qpos[:3] = np.asarray(position, dtype=np.float64)
|
||||||
|
joint.qpos[3:7] = np.asarray(quat, dtype=np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
def set_ring_bar_task_state(mj_data, task_state):
|
||||||
|
if not isinstance(task_state, dict) or tuple(task_state.keys()) != REQUIRED_TASK_STATE_KEYS:
|
||||||
|
raise ValueError(
|
||||||
|
"task_state must be an ordered dict-like mapping with keys "
|
||||||
|
"ring_pos, ring_quat, bar_pos, bar_quat"
|
||||||
|
)
|
||||||
|
|
||||||
|
_set_free_joint_pose(
|
||||||
|
mj_data.joint(RING_JOINT_NAME),
|
||||||
|
task_state["ring_pos"],
|
||||||
|
task_state["ring_quat"],
|
||||||
|
)
|
||||||
|
_set_free_joint_pose(
|
||||||
|
mj_data.joint(BAR_JOINT_NAME),
|
||||||
|
task_state["bar_pos"],
|
||||||
|
task_state["bar_quat"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_ring_bar_env_state(mj_data):
|
||||||
|
ring_qpos = cp.deepcopy(np.asarray(mj_data.joint(RING_JOINT_NAME).qpos[:7], dtype=np.float64))
|
||||||
|
bar_qpos = cp.deepcopy(np.asarray(mj_data.joint(BAR_JOINT_NAME).qpos[:7], dtype=np.float64))
|
||||||
|
return np.concatenate([ring_qpos, bar_qpos], dtype=np.float64)
|
||||||
|
|
||||||
|
|
||||||
class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
|
class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
|
||||||
def reset(self, task_state):
|
def reset(self, task_state):
|
||||||
required_keys = {"ring_pos", "ring_quat", "bar_pos", "bar_quat"}
|
set_ring_bar_task_state(self.mj_data, task_state)
|
||||||
if not isinstance(task_state, dict) or set(task_state.keys()) != required_keys:
|
DualDianaMed.reset(self)
|
||||||
raise ValueError(
|
self.top = None
|
||||||
"task_state must be a dict with ring_pos, ring_quat, bar_pos, and bar_quat"
|
self.angle = None
|
||||||
)
|
self.r_vis = None
|
||||||
|
self.front = None
|
||||||
|
self.cam_flage = True
|
||||||
|
while self.cam_flage:
|
||||||
|
if (
|
||||||
|
type(self.top) == type(None)
|
||||||
|
or type(self.angle) == type(None)
|
||||||
|
or type(self.r_vis) == type(None)
|
||||||
|
or type(self.front) == type(None)
|
||||||
|
):
|
||||||
|
time.sleep(0.001)
|
||||||
|
else:
|
||||||
|
self.cam_flage = False
|
||||||
|
|
||||||
|
def get_env_state(self):
|
||||||
|
return get_ring_bar_env_state(self.mj_data)
|
||||||
|
|
||||||
|
def _get_reward(self):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"sim_air_insert_ring_bar reset wiring is intentionally deferred beyond Task 1"
|
"Task 2 wires reset/state only; reward logic is implemented in a later task."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -97,5 +97,82 @@ class AirInsertTaskRegistrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AirInsertResetAndStateHelpersTest(unittest.TestCase):
|
||||||
|
def test_set_ring_bar_task_state_writes_free_joint_qpos(self):
|
||||||
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||||
|
setter = getattr(air_insert_env, "set_ring_bar_task_state", None)
|
||||||
|
self.assertIsNotNone(
|
||||||
|
setter,
|
||||||
|
"Expected roboimi.envs.double_air_insert_env.set_ring_bar_task_state",
|
||||||
|
)
|
||||||
|
|
||||||
|
ring_qpos = np.zeros(7, dtype=np.float64)
|
||||||
|
bar_qpos = np.zeros(7, dtype=np.float64)
|
||||||
|
|
||||||
|
class _FakeJoint:
|
||||||
|
def __init__(self, qpos):
|
||||||
|
self.qpos = qpos
|
||||||
|
|
||||||
|
class _FakeData:
|
||||||
|
def joint(self, name):
|
||||||
|
if name == "ring_block_joint":
|
||||||
|
return _FakeJoint(ring_qpos)
|
||||||
|
if name == "bar_block_joint":
|
||||||
|
return _FakeJoint(bar_qpos)
|
||||||
|
raise AssertionError(f"Unexpected joint name: {name}")
|
||||||
|
|
||||||
|
task_state = {
|
||||||
|
"ring_pos": np.array([-0.12, 0.90, 0.47], dtype=np.float64),
|
||||||
|
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||||
|
"bar_pos": np.array([0.12, 0.91, 0.47], dtype=np.float64),
|
||||||
|
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||||
|
}
|
||||||
|
|
||||||
|
setter(_FakeData(), task_state)
|
||||||
|
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
ring_qpos,
|
||||||
|
np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||||
|
)
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
bar_qpos,
|
||||||
|
np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_ring_bar_env_state_returns_stable_14d_vector(self):
|
||||||
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||||
|
getter = getattr(air_insert_env, "get_ring_bar_env_state", None)
|
||||||
|
self.assertIsNotNone(
|
||||||
|
getter,
|
||||||
|
"Expected roboimi.envs.double_air_insert_env.get_ring_bar_env_state",
|
||||||
|
)
|
||||||
|
|
||||||
|
ring_qpos = np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
|
||||||
|
bar_qpos = np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
|
||||||
|
|
||||||
|
class _FakeJoint:
|
||||||
|
def __init__(self, qpos):
|
||||||
|
self.qpos = qpos
|
||||||
|
|
||||||
|
class _FakeData:
|
||||||
|
def joint(self, name):
|
||||||
|
if name == "ring_block_joint":
|
||||||
|
return _FakeJoint(ring_qpos)
|
||||||
|
if name == "bar_block_joint":
|
||||||
|
return _FakeJoint(bar_qpos)
|
||||||
|
raise AssertionError(f"Unexpected joint name: {name}")
|
||||||
|
|
||||||
|
env_state = getter(_FakeData())
|
||||||
|
|
||||||
|
self.assertEqual(env_state.shape, (14,))
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
env_state,
|
||||||
|
np.array(
|
||||||
|
[-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0],
|
||||||
|
dtype=np.float64,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user