feat(sim): switch air insert task to socket peg

This commit is contained in:
Logic
2026-05-02 17:34:43 +08:00
parent 4c3646a3d5
commit 5c5cb299e9
16 changed files with 594 additions and 630 deletions

View File

@@ -1,6 +1,6 @@
<mujoco model="bi_diana_ring_bar"> <mujoco model="bi_diana_socket_peg">
<include file="./empty_world.xml" /> <include file="./empty_world.xml" />
<include file="./table_square.xml" /> <include file="./table_square.xml" />
<include file="./ring_bar_objects.xml" /> <include file="./socket_peg_objects.xml" />
<include file="./BiDianaMed_rethink.xml" /> <include file="./BiDianaMed_rethink.xml" />
</mujoco> </mujoco>

View File

@@ -1,28 +0,0 @@
<mujoco model="ring_bar_objects">
<worldbody>
<body name="ring_block" pos="-0.12 0.90 0.47">
<joint name="ring_block_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.03" diaginertia="0.001 0.001 0.001" />
<geom name="ring_block_north" type="box" pos="0 0.025 0" size="0.034 0.009 0.009"
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
friction="4 0.05 0.001" rgba="1 0 0 1" />
<geom name="ring_block_south" type="box" pos="0 -0.025 0" size="0.034 0.009 0.009"
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
friction="4 0.05 0.001" rgba="1 0 0 1" />
<geom name="ring_block_east" type="box" pos="0.025 0 0" size="0.009 0.016 0.009"
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
friction="4 0.05 0.001" rgba="1 0 0 1" />
<geom name="ring_block_west" type="box" pos="-0.025 0 0" size="0.009 0.016 0.009"
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
friction="4 0.05 0.001" rgba="1 0 0 1" />
</body>
<body name="bar_block" pos="0.12 0.90 0.47">
<joint name="bar_block_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.015" diaginertia="0.0005 0.0005 0.0005" />
<geom name="bar_block" type="box" pos="0 0 0" size="0.045 0.009 0.009"
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
friction="6 0.08 0.002" rgba="0 0.7 0.2 1" />
</body>
</worldbody>
</mujoco>

View File

@@ -0,0 +1,19 @@
<mujoco model="socket_peg_objects">
<worldbody>
<body name="peg" pos="0.12 0.90 0.46">
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
</body>
<body name="socket" pos="-0.12 0.90 0.472">
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
</body>
</worldbody>
</mujoco>

View File

@@ -7,7 +7,7 @@
<geom name="table" condim="4" contype="1" conaffinity="1" type="box" rgba="0.4 0.4 0.4 1" size="0.62 0.62 0.01" density="1500" friction="0.9 0.9 0.9"/> <geom name="table" condim="4" contype="1" conaffinity="1" type="box" rgba="0.4 0.4 0.4 1" size="0.62 0.62 0.01" density="1500" friction="0.9 0.9 0.9"/>
</body> </body>
<camera name="top" pos="0.0 1.0 2.0" fovy="44" mode="targetbody" target="table"/> <camera name="top" pos="0.0 1.0 2.0" fovy="44" mode="targetbody" target="table"/>
<camera name="angle" pos="0.0 0.0 2.0" fovy="37" mode="targetbody" target="table"/> <camera name="left_side" pos="-0.55 0.85 0.85" fovy="65" mode="targetbody" target="table"/>
<camera name="front" pos="0 0 0.8" fovy="65" mode="fixed" quat="0.7071 0.7071 0 0"/> <camera name="front" pos="0 0 0.8" fovy="65" mode="fixed" quat="0.7071 0.7071 0 0"/>
</worldbody> </worldbody>
</mujoco> </mujoco>

View File

@@ -92,12 +92,12 @@ class BiDianaMed(ArmBase):
return np.array([0.0, 0.0, 0.0, 1.57, 0.0, 0.0, 0.0]) return np.array([0.0, 0.0, 0.0, 1.57, 0.0, 0.0, 0.0])
class BiDianaMedRingBar(ArmBase): class BiDianaMedSocketPeg(ArmBase):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
name="Bidiana_ring_bar", name="Bidiana_socket_peg",
urdf_path="roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf", urdf_path="roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf",
xml_path="roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml", xml_path="roboimi/assets/models/manipulators/DianaMed/bi_diana_socket_peg_ee.xml",
gripper=None gripper=None
) )
self.left_arm = self.Arm(self, 'single', self.urdf_path) self.left_arm = self.Arm(self, 'single', self.urdf_path)

View File

@@ -5,16 +5,39 @@ from roboimi.demos.diana_policy import PolicyBase
class TestAirInsertPolicy(PolicyBase): class TestAirInsertPolicy(PolicyBase):
@staticmethod ACTION_OBJECT_Z_OFFSET = 0.078
def _action_xyz_for_object_center(object_center, ee_quat, object_offset_local): SOCKET_GRASP_OFFSET = np.array([0.0, 0.0, 0.0], dtype=np.float64)
return ( PEG_GRASP_OFFSET = np.array([0.0, 0.0, 0.0], dtype=np.float64)
np.asarray(object_center, dtype=np.float64) SOCKET_OUTER_GRASP_STRATEGY = "socket_outer"
- np.asarray(Quaternion(ee_quat).rotate(object_offset_local), dtype=np.float64) LEGACY_GRASP_STRATEGY = "legacy"
) SOCKET_HOLD_Z = 0.85
PEG_INSERT_START_OFFSET = np.array([0.105, 0.0, 0.0], dtype=np.float64)
INSERT_START_T = 650
INSERT_END_T = 700
LEFT_SOCKET_GRIPPER_CLOSED = -70
RIGHT_PEG_GRIPPER_CLOSED = -100
SOCKET_APPROACH_Z = 1.05
EPISODE_END_T = 1000
def __init__(self, inject_noise=False, grasp_strategy=SOCKET_OUTER_GRASP_STRATEGY):
super().__init__(inject_noise=inject_noise)
valid_strategies = {
self.SOCKET_OUTER_GRASP_STRATEGY,
self.LEGACY_GRASP_STRATEGY,
}
if grasp_strategy not in valid_strategies:
raise ValueError(
f"Unsupported air insert grasp_strategy={grasp_strategy!r}; "
f"expected one of {sorted(valid_strategies)}"
)
self.grasp_strategy = grasp_strategy
def generate_trajectory(self, task_state): def generate_trajectory(self, task_state):
ring_xyz = np.asarray(task_state["ring_pos"], dtype=np.float64) return self._generate_socket_peg_trajectory(task_state)
bar_xyz = np.asarray(task_state["bar_pos"], dtype=np.float64)
def _generate_socket_peg_trajectory(self, task_state):
socket_xyz = np.asarray(task_state["socket_pos"], dtype=np.float64)
peg_xyz = np.asarray(task_state["peg_pos"], dtype=np.float64)
init_mocap_pose_left = np.array( init_mocap_pose_left = np.array(
[ [
@@ -44,63 +67,137 @@ class TestAirInsertPolicy(PolicyBase):
left_init_quat = Quaternion(init_mocap_pose_left[3:]) left_init_quat = Quaternion(init_mocap_pose_left[3:])
right_init_quat = Quaternion(init_mocap_pose_right[3:]) right_init_quat = Quaternion(init_mocap_pose_right[3:])
object_offset_local = np.array([0.0, 0.0, -0.09], dtype=np.float64) left_pick_quat = (
left_pick_quat = (left_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements left_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=45)
left_hold_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=-90).elements ).elements
right_pick_quat = (right_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements right_pick_quat = (
insert_quat_local = Quaternion([-0.50019721, 0.50020088, 0.49980484, 0.49979692]) right_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=45)
right_insert_quat = np.array( ).elements
(Quaternion(left_hold_quat) * insert_quat_local).elements,
socket_hold_action = np.array(
[socket_xyz[0] - 0.078, socket_xyz[1], self.SOCKET_HOLD_Z], dtype=np.float64
)
peg_init_xyz = peg_xyz + np.array(
[0.078, 0.0, self.ACTION_OBJECT_Z_OFFSET + 0.01]
)
peg_lift_center = np.array(
[peg_xyz[0] + 0.078, socket_hold_action[1], self.SOCKET_HOLD_Z - 0.01],
dtype=np.float64,
)
# The front camera looks along +Y, so visual right-to-left insertion is
# world +X -> -X. With the socket XML in identity orientation, its
# tunnel axis is local/world X, so the peg approaches from +X and stops
# when its leading face reaches the socket's internal pin.
peg_insert_end_center = np.array(
[
socket_hold_action[0] + 0.078 * 2 + 0.04 + 0.06 - 0.01,
socket_hold_action[1],
self.SOCKET_HOLD_Z - 0.01,
],
dtype=np.float64, dtype=np.float64,
) )
meet_xyz = np.array([0.0, 1.0, 1.30], dtype=np.float64)
ring_stabilize_center = ring_xyz + np.array([0.0, 0.0, 0.30], dtype=np.float64)
ring_hold_center = meet_xyz + np.array([-0.10, 0.05, -0.16], dtype=np.float64)
bar_reorient_center = bar_xyz + np.array([0.0, 0.0, 0.16], dtype=np.float64)
bar_wait_center = ring_hold_center + np.array([0.05, -0.18, 0.0], dtype=np.float64)
bar_insert_start_center = ring_hold_center + np.array([0.0, -0.075, 0.0], dtype=np.float64)
bar_insert_end_center = ring_hold_center + np.array([0.0, 0.075, 0.0], dtype=np.float64)
left_stabilize_xyz = self._action_xyz_for_object_center(
ring_stabilize_center, left_pick_quat, object_offset_local
)
left_hold_xyz = self._action_xyz_for_object_center(
ring_hold_center, left_hold_quat, object_offset_local
)
right_reorient_xyz = self._action_xyz_for_object_center(
bar_reorient_center, right_insert_quat, object_offset_local
)
right_wait_xyz = self._action_xyz_for_object_center(
bar_wait_center, right_insert_quat, object_offset_local
)
right_insert_start_xyz = self._action_xyz_for_object_center(
bar_insert_start_center, right_insert_quat, object_offset_local
)
right_insert_end_xyz = self._action_xyz_for_object_center(
bar_insert_end_center, right_insert_quat, object_offset_local
)
self.left_trajectory = [ self.left_trajectory = [
{"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 100}, {
{"t": 80, "xyz": ring_xyz + np.array([0.0, 0.0, 0.22]), "quat": left_pick_quat, "gripper": 100}, "t": 1,
{"t": 150, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": 100}, "xyz": init_mocap_pose_left[:3],
{"t": 180, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": -100}, "quat": init_mocap_pose_left[3:],
{"t": 260, "xyz": self._action_xyz_for_object_center(ring_xyz + np.array([0.0, 0.0, 0.24]), left_pick_quat, object_offset_local), "quat": left_pick_quat, "gripper": -100}, "gripper": 100,
{"t": 340, "xyz": left_stabilize_xyz, "quat": left_pick_quat, "gripper": -100}, },
{"t": 460, "xyz": left_hold_xyz, "quat": left_hold_quat, "gripper": -100}, {
{"t": 700, "xyz": left_hold_xyz, "quat": left_hold_quat, "gripper": -100}, "t": 130,
"xyz": socket_xyz
+ np.array([-0.078, 0.0, self.ACTION_OBJECT_Z_OFFSET]),
"quat": left_pick_quat,
"gripper": 100,
},
{
"t": 180,
"xyz": socket_xyz
+ np.array([-0.078, 0.0, self.ACTION_OBJECT_Z_OFFSET]),
"quat": left_pick_quat,
"gripper": self.LEFT_SOCKET_GRIPPER_CLOSED,
},
{
"t": 450,
"xyz": socket_hold_action,
"quat": left_pick_quat,
"gripper": self.LEFT_SOCKET_GRIPPER_CLOSED,
},
{
"t": 750,
"xyz": socket_hold_action,
"quat": left_pick_quat,
"gripper": self.LEFT_SOCKET_GRIPPER_CLOSED,
},
{
"t": self.EPISODE_END_T,
"xyz": socket_hold_action,
"quat": left_pick_quat,
"gripper": self.LEFT_SOCKET_GRIPPER_CLOSED,
},
] ]
self.right_trajectory = [ self.right_trajectory = [
{"t": 1, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 100}, {
{"t": 80, "xyz": bar_xyz + np.array([0.0, 0.0, 0.22]), "quat": right_pick_quat, "gripper": 100}, "t": 1,
{"t": 150, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": 100}, "xyz": init_mocap_pose_right[:3],
{"t": 180, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": -100}, "quat": init_mocap_pose_right[3:],
{"t": 240, "xyz": self._action_xyz_for_object_center(bar_xyz + np.array([0.0, 0.0, 0.12]), right_pick_quat, object_offset_local), "quat": right_pick_quat, "gripper": -100}, "gripper": 100,
{"t": 320, "xyz": right_reorient_xyz, "quat": right_insert_quat, "gripper": -100}, },
{"t": 460, "xyz": right_wait_xyz, "quat": right_insert_quat, "gripper": -100}, {
{"t": 600, "xyz": right_insert_start_xyz, "quat": right_insert_quat, "gripper": -100}, "t": 80,
{"t": 690, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100}, "xyz": peg_init_xyz,
{"t": 700, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100}, "quat": right_pick_quat,
"gripper": 100,
},
{
"t": 150,
"xyz": peg_init_xyz,
"quat": right_pick_quat,
"gripper": 100,
},
{
"t": 180,
"xyz": peg_init_xyz,
"quat": right_pick_quat,
"gripper": self.RIGHT_PEG_GRIPPER_CLOSED,
},
{
"t": 450,
"xyz": peg_init_xyz,
"quat": right_pick_quat,
"gripper": self.RIGHT_PEG_GRIPPER_CLOSED,
},
{
"t": 550,
"xyz": peg_lift_center,
"quat": right_pick_quat,
"gripper": self.RIGHT_PEG_GRIPPER_CLOSED,
},
{
"t": self.INSERT_START_T,
"xyz": peg_lift_center,
"quat": right_pick_quat,
"gripper": self.RIGHT_PEG_GRIPPER_CLOSED,
},
{
"t": self.INSERT_END_T,
"xyz": peg_insert_end_center,
"quat": right_pick_quat,
"gripper": self.RIGHT_PEG_GRIPPER_CLOSED,
},
{
"t": 750,
"xyz": peg_insert_end_center,
"quat": right_pick_quat,
"gripper": self.RIGHT_PEG_GRIPPER_CLOSED,
},
{
"t": self.EPISODE_END_T,
"xyz": peg_insert_end_center,
"quat": right_pick_quat,
"gripper": self.RIGHT_PEG_GRIPPER_CLOSED,
},
] ]

View File

@@ -5,7 +5,7 @@ from roboimi.envs.double_pos_ctrl_env import make_sim_env
from roboimi.demos.diana_air_insert_policy import TestAirInsertPolicy from roboimi.demos.diana_air_insert_policy import TestAirInsertPolicy
from roboimi.demos.diana_policy import TestPickAndTransferPolicy from roboimi.demos.diana_policy import TestPickAndTransferPolicy
import cv2 import cv2
from roboimi.utils.act_ex_utils import sample_air_insert_ring_bar_state, sample_transfer_pose from roboimi.utils.act_ex_utils import sample_air_insert_socket_peg_state, sample_transfer_pose
from roboimi.utils.constants import SIM_TASK_CONFIGS from roboimi.utils.constants import SIM_TASK_CONFIGS
from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter
@@ -17,16 +17,18 @@ DATASET_DIR = HOME_PATH + '/dataset'
def sample_task_state(task_name): def sample_task_state(task_name):
if task_name == 'sim_transfer': if task_name == 'sim_transfer':
return sample_transfer_pose() return sample_transfer_pose()
if task_name == 'sim_air_insert_ring_bar': if task_name == 'sim_air_insert_socket_peg':
return sample_air_insert_ring_bar_state() return sample_air_insert_socket_peg_state()
raise NotImplementedError(f'Unsupported scripted rollout task: {task_name}') raise NotImplementedError(f'Unsupported scripted rollout task: {task_name}')
def make_policy(task_name, inject_noise=False): def make_policy(task_name, inject_noise=False, grasp_strategy=None):
if task_name == 'sim_transfer': if task_name == 'sim_transfer':
return TestPickAndTransferPolicy(inject_noise) return TestPickAndTransferPolicy(inject_noise)
if task_name == 'sim_air_insert_ring_bar': if task_name == 'sim_air_insert_socket_peg':
return TestAirInsertPolicy(inject_noise) if grasp_strategy is None:
return TestAirInsertPolicy(inject_noise)
return TestAirInsertPolicy(inject_noise, grasp_strategy=grasp_strategy)
raise NotImplementedError(f'Unsupported scripted rollout task: {task_name}') raise NotImplementedError(f'Unsupported scripted rollout task: {task_name}')
@@ -37,9 +39,9 @@ def main(task_name='sim_transfer'):
inject_noise = False inject_noise = False
episode_len = task_cfg['episode_len'] episode_len = task_cfg['episode_len']
camera_names = ['angle', 'r_vis', 'top', 'front'] camera_names = ['left_side', 'r_vis', 'top', 'front']
image_size = (256, 256) image_size = (256, 256)
if task_name in {'sim_transfer', 'sim_air_insert_ring_bar'}: if task_name in {'sim_transfer', 'sim_air_insert_socket_peg'}:
print(task_name) print(task_name)
else: else:
raise NotImplementedError raise NotImplementedError

View File

@@ -27,7 +27,7 @@ from einops import rearrange
from roboimi.envs.double_pos_ctrl_env import make_sim_env from roboimi.envs.double_pos_ctrl_env import make_sim_env
from roboimi.utils.act_ex_utils import ( from roboimi.utils.act_ex_utils import (
sample_air_insert_ring_bar_state, sample_air_insert_socket_peg_state,
sample_transfer_pose, sample_transfer_pose,
) )
from roboimi.vla.eval_utils import execute_policy_action from roboimi.vla.eval_utils import execute_policy_action
@@ -489,8 +489,8 @@ def _close_env(env):
def _sample_task_reset_state(task_name: str): def _sample_task_reset_state(task_name: str):
if task_name == 'sim_air_insert_ring_bar': if task_name == 'sim_air_insert_socket_peg':
return sample_air_insert_ring_bar_state() return sample_air_insert_socket_peg_state()
if 'sim_transfer' in task_name: if 'sim_transfer' in task_name:
return sample_transfer_pose() return sample_transfer_pose()
raise NotImplementedError(f'Unsupported eval task reset sampling: {task_name}') raise NotImplementedError(f'Unsupported eval task reset sampling: {task_name}')

View File

@@ -1,23 +1,19 @@
import copy as cp import copy as cp
import time import time
import mujoco as mj
import numpy as np import numpy as np
from roboimi.envs.double_base import DualDianaMed from roboimi.envs.double_base import DualDianaMed
from roboimi.envs.double_pos_ctrl_env import DualDianaMed_Pos_Ctrl from roboimi.envs.double_pos_ctrl_env import DualDianaMed_Pos_Ctrl
RING_JOINT_NAME = "ring_block_joint" SOCKET_JOINT_NAME = "blue_socket_joint"
BAR_JOINT_NAME = "bar_block_joint" PEG_JOINT_NAME = "red_peg_joint"
REQUIRED_TASK_STATE_KEYS = ("ring_pos", "ring_quat", "bar_pos", "bar_quat") REQUIRED_TASK_STATE_KEYS = ("socket_pos", "socket_quat", "peg_pos", "peg_quat")
RING_GEOM_NAMES = ( SOCKET_GEOM_NAMES = ("socket-1", "socket-2", "socket-3", "socket-4")
"ring_block_north", SOCKET_SUCCESS_GEOM_NAMES = ("pin",)
"ring_block_south", SOCKET_BODY_GEOM_NAMES = SOCKET_GEOM_NAMES + SOCKET_SUCCESS_GEOM_NAMES
"ring_block_east", PEG_GEOM_NAMES = ("red_peg",)
"ring_block_west",
)
BAR_GEOM_NAMES = ("bar_block",)
LEFT_GRIPPER_GEOM_NAMES = ( LEFT_GRIPPER_GEOM_NAMES = (
"l_finger_left", "l_finger_left",
"r_finger_left", "r_finger_left",
@@ -25,6 +21,8 @@ LEFT_GRIPPER_GEOM_NAMES = (
"r_fingertip_g0_left", "r_fingertip_g0_left",
"l_fingerpad_g0_left", "l_fingerpad_g0_left",
"r_fingerpad_g0_left", "r_fingerpad_g0_left",
"l_fingertip_g0_vis_left",
"r_fingertip_g0_vis_left",
) )
RIGHT_GRIPPER_GEOM_NAMES = ( RIGHT_GRIPPER_GEOM_NAMES = (
"l_finger_right", "l_finger_right",
@@ -33,12 +31,10 @@ RIGHT_GRIPPER_GEOM_NAMES = (
"r_fingertip_g0_right", "r_fingertip_g0_right",
"l_fingerpad_g0_right", "l_fingerpad_g0_right",
"r_fingerpad_g0_right", "r_fingerpad_g0_right",
"l_fingertip_g0_vis_right",
"r_fingertip_g0_vis_right",
) )
TABLE_GEOM_NAME = "table" TABLE_GEOM_NAME = "table"
RING_APERTURE_HALF_WIDTH = 0.016
RING_HALF_THICKNESS = 0.009
BAR_HALF_SIZES = np.array([0.045, 0.009, 0.009], dtype=np.float64)
SCRIPTED_GRASP_CLOSE_THRESHOLD = 0.0
def _set_free_joint_pose(joint, position, quat): def _set_free_joint_pose(joint, position, quat):
@@ -46,29 +42,29 @@ def _set_free_joint_pose(joint, position, quat):
joint.qpos[3:7] = np.asarray(quat, dtype=np.float64) joint.qpos[3:7] = np.asarray(quat, dtype=np.float64)
def set_ring_bar_task_state(mj_data, task_state): def set_socket_peg_task_state(mj_data, task_state):
if not isinstance(task_state, dict) or tuple(task_state.keys()) != REQUIRED_TASK_STATE_KEYS: if not isinstance(task_state, dict) or tuple(task_state.keys()) != REQUIRED_TASK_STATE_KEYS:
raise ValueError( raise ValueError(
"task_state must be an ordered dict-like mapping with keys " "task_state must be an ordered dict-like mapping with keys "
"ring_pos, ring_quat, bar_pos, bar_quat" "socket_pos, socket_quat, peg_pos, peg_quat"
) )
_set_free_joint_pose( _set_free_joint_pose(
mj_data.joint(RING_JOINT_NAME), mj_data.joint(SOCKET_JOINT_NAME),
task_state["ring_pos"], task_state["socket_pos"],
task_state["ring_quat"], task_state["socket_quat"],
) )
_set_free_joint_pose( _set_free_joint_pose(
mj_data.joint(BAR_JOINT_NAME), mj_data.joint(PEG_JOINT_NAME),
task_state["bar_pos"], task_state["peg_pos"],
task_state["bar_quat"], task_state["peg_quat"],
) )
def get_ring_bar_env_state(mj_data): def get_socket_peg_env_state(mj_data):
ring_qpos = cp.deepcopy(np.asarray(mj_data.joint(RING_JOINT_NAME).qpos[:7], dtype=np.float64)) socket_qpos = cp.deepcopy(np.asarray(mj_data.joint(SOCKET_JOINT_NAME).qpos[:7], dtype=np.float64))
bar_qpos = cp.deepcopy(np.asarray(mj_data.joint(BAR_JOINT_NAME).qpos[:7], dtype=np.float64)) peg_qpos = cp.deepcopy(np.asarray(mj_data.joint(PEG_JOINT_NAME).qpos[:7], dtype=np.float64))
return np.concatenate([ring_qpos, bar_qpos], dtype=np.float64) return np.concatenate([socket_qpos, peg_qpos], dtype=np.float64)
def _normalize_contact_pairs(contact_pairs): def _normalize_contact_pairs(contact_pairs):
@@ -87,91 +83,29 @@ def _object_is_airborne(contact_set, object_geom_names):
return not _has_any_object_contact(contact_set, object_geom_names, (TABLE_GEOM_NAME,)) return not _has_any_object_contact(contact_set, object_geom_names, (TABLE_GEOM_NAME,))
def _quat_to_rotation_matrix(quat): def peg_inserted_into_socket(contact_pairs):
quat = np.asarray(quat, dtype=np.float64) contact_set = _normalize_contact_pairs(contact_pairs)
quat /= np.linalg.norm(quat) return frozenset((PEG_GEOM_NAMES[0], SOCKET_SUCCESS_GEOM_NAMES[0])) in contact_set
w, x, y, z = quat
return np.array(
[
[1.0 - 2.0 * (y * y + z * z), 2.0 * (x * y - z * w), 2.0 * (x * z + y * w)],
[2.0 * (x * y + z * w), 1.0 - 2.0 * (x * x + z * z), 2.0 * (y * z - x * w)],
[2.0 * (x * z - y * w), 2.0 * (y * z + x * w), 1.0 - 2.0 * (x * x + y * y)],
],
dtype=np.float64,
)
def _quat_multiply(lhs, rhs): def compute_air_insert_reward(contact_pairs, env_state=None):
lhs = np.asarray(lhs, dtype=np.float64) del env_state # kept for API compatibility with rollout/eval code paths
rhs = np.asarray(rhs, dtype=np.float64)
w1, x1, y1, z1 = lhs
w2, x2, y2, z2 = rhs
return np.array(
[
w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
],
dtype=np.float64,
)
def _quat_inverse(quat):
quat = np.asarray(quat, dtype=np.float64)
norm_sq = float(np.dot(quat, quat))
return np.array([quat[0], -quat[1], -quat[2], -quat[3]], dtype=np.float64) / norm_sq
def _split_env_state(env_state):
env_state = np.asarray(env_state, dtype=np.float64)
if env_state.shape != (14,):
raise ValueError(f"env_state must have shape (14,), got {env_state.shape}")
return (
env_state[:3],
env_state[3:7],
env_state[7:10],
env_state[10:14],
)
def bar_fully_inserted_through_ring(env_state):
ring_pos, ring_quat, bar_pos, bar_quat = _split_env_state(env_state)
ring_rot = _quat_to_rotation_matrix(ring_quat)
bar_rot = _quat_to_rotation_matrix(bar_quat)
bar_center_in_ring = ring_rot.T @ (bar_pos - ring_pos)
bar_rot_in_ring = ring_rot.T @ bar_rot
projected_half_extents = np.abs(bar_rot_in_ring) @ BAR_HALF_SIZES
spans_ring_thickness = (
bar_center_in_ring[2] - projected_half_extents[2] <= -RING_HALF_THICKNESS
and bar_center_in_ring[2] + projected_half_extents[2] >= RING_HALF_THICKNESS
)
fits_aperture = (
abs(bar_center_in_ring[0]) + projected_half_extents[0] <= RING_APERTURE_HALF_WIDTH
and abs(bar_center_in_ring[1]) + projected_half_extents[1] <= RING_APERTURE_HALF_WIDTH
)
return bool(spans_ring_thickness and fits_aperture)
def compute_air_insert_reward(contact_pairs, env_state):
contact_set = _normalize_contact_pairs(contact_pairs) contact_set = _normalize_contact_pairs(contact_pairs)
reward = 0 reward = 0
if _has_any_object_contact(contact_set, RING_GEOM_NAMES, LEFT_GRIPPER_GEOM_NAMES): if _has_any_object_contact(contact_set, SOCKET_GEOM_NAMES, LEFT_GRIPPER_GEOM_NAMES):
reward += 1 reward += 1
if _has_any_object_contact(contact_set, BAR_GEOM_NAMES, RIGHT_GRIPPER_GEOM_NAMES): if _has_any_object_contact(contact_set, PEG_GEOM_NAMES, RIGHT_GRIPPER_GEOM_NAMES):
reward += 1 reward += 1
ring_airborne = _object_is_airborne(contact_set, RING_GEOM_NAMES) socket_airborne = _object_is_airborne(contact_set, SOCKET_BODY_GEOM_NAMES)
bar_airborne = _object_is_airborne(contact_set, BAR_GEOM_NAMES) peg_airborne = _object_is_airborne(contact_set, PEG_GEOM_NAMES)
if ring_airborne: if socket_airborne:
reward += 1 reward += 1
if bar_airborne: if peg_airborne:
reward += 1 reward += 1
if ring_airborne and bar_airborne and bar_fully_inserted_through_ring(env_state): if socket_airborne and peg_airborne and peg_inserted_into_socket(contact_pairs):
reward += 1 reward += 1
return reward return reward
@@ -181,33 +115,19 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.max_reward = 5 self.max_reward = 5
self._scripted_ring_grasped = False
self._scripted_bar_grasped = False
self._scripted_ring_pos_offset_local = None
self._scripted_bar_pos_offset_local = None
self._scripted_ring_quat_offset = None
self._scripted_bar_quat_offset = None
self._air_insert_step_count = 0
def reset(self, task_state): def reset(self, task_state):
self._scripted_ring_grasped = False set_socket_peg_task_state(self.mj_data, task_state)
self._scripted_bar_grasped = False
self._scripted_ring_pos_offset_local = None
self._scripted_bar_pos_offset_local = None
self._scripted_ring_quat_offset = None
self._scripted_bar_quat_offset = None
self._air_insert_step_count = 0
set_ring_bar_task_state(self.mj_data, task_state)
DualDianaMed.reset(self) DualDianaMed.reset(self)
self.top = None self.top = None
self.angle = None self.left_side = None
self.r_vis = None self.r_vis = None
self.front = None self.front = None
self.cam_flage = True self.cam_flage = True
while self.cam_flage: while self.cam_flage:
if ( if (
type(self.top) == type(None) type(self.top) == type(None)
or type(self.angle) == type(None) or type(self.left_side) == type(None)
or type(self.r_vis) == type(None) or type(self.r_vis) == type(None)
or type(self.front) == type(None) or type(self.front) == type(None)
): ):
@@ -217,76 +137,11 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
def step(self, action=np.zeros(16)): def step(self, action=np.zeros(16)):
super().step(action) super().step(action)
self._update_scripted_grasped_objects(action)
self.rew = self._get_reward() self.rew = self._get_reward()
self.obs = self._get_obs() self.obs = self._get_obs()
self._air_insert_step_count += 1
def _update_scripted_grasped_objects(self, action):
if (
action[-2] < SCRIPTED_GRASP_CLOSE_THRESHOLD
and self._air_insert_step_count >= 180
and not self._scripted_ring_grasped
):
self._scripted_ring_grasped = True
self._attach_scripted_object(
object_joint_name=RING_JOINT_NAME,
ee_pos=action[:3],
ee_quat=action[3:7],
pos_attr="_scripted_ring_pos_offset_local",
quat_attr="_scripted_ring_quat_offset",
)
if (
action[-1] < SCRIPTED_GRASP_CLOSE_THRESHOLD
and self._air_insert_step_count >= 180
and not self._scripted_bar_grasped
):
self._scripted_bar_grasped = True
self._attach_scripted_object(
object_joint_name=BAR_JOINT_NAME,
ee_pos=action[7:10],
ee_quat=action[10:14],
pos_attr="_scripted_bar_pos_offset_local",
quat_attr="_scripted_bar_quat_offset",
)
if self._scripted_ring_grasped:
self._update_scripted_object_pose(
object_joint_name=RING_JOINT_NAME,
ee_pos=action[:3],
ee_quat=action[3:7],
pos_offset_local=self._scripted_ring_pos_offset_local,
quat_offset=self._scripted_ring_quat_offset,
)
if self._scripted_bar_grasped:
self._update_scripted_object_pose(
object_joint_name=BAR_JOINT_NAME,
ee_pos=action[7:10],
ee_quat=action[10:14],
pos_offset_local=self._scripted_bar_pos_offset_local,
quat_offset=self._scripted_bar_quat_offset,
)
if self._scripted_ring_grasped or self._scripted_bar_grasped:
mj.mj_forward(self.mj_model, self.mj_data)
def _attach_scripted_object(self, object_joint_name, ee_pos, ee_quat, pos_attr, quat_attr):
ee_pos = np.asarray(ee_pos, dtype=np.float64)
ee_quat = np.asarray(ee_quat, dtype=np.float64)
object_qpos = np.asarray(self.mj_data.joint(object_joint_name).qpos[:7], dtype=np.float64)
ee_rot = _quat_to_rotation_matrix(ee_quat)
setattr(self, pos_attr, ee_rot.T @ (object_qpos[:3] - ee_pos))
setattr(self, quat_attr, _quat_multiply(_quat_inverse(ee_quat), object_qpos[3:7]))
def _update_scripted_object_pose(self, object_joint_name, ee_pos, ee_quat, pos_offset_local, quat_offset):
ee_pos = np.asarray(ee_pos, dtype=np.float64)
ee_quat = np.asarray(ee_quat, dtype=np.float64)
ee_rot = _quat_to_rotation_matrix(ee_quat)
object_pos = ee_pos + ee_rot @ np.asarray(pos_offset_local, dtype=np.float64)
object_quat = _quat_multiply(ee_quat, quat_offset)
_set_free_joint_pose(self.mj_data.joint(object_joint_name), object_pos, object_quat)
def get_env_state(self): def get_env_state(self):
return get_ring_bar_env_state(self.mj_data) return get_socket_peg_env_state(self.mj_data)
def _get_reward(self): def _get_reward(self):
contact_pairs = [] contact_pairs = []
@@ -296,8 +151,4 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
contact_pairs.append( contact_pairs.append(
(self.getID2Name("geom", geom1), self.getID2Name("geom", geom2)) (self.getID2Name("geom", geom1), self.getID2Name("geom", geom2))
) )
if self._scripted_ring_grasped:
contact_pairs.append(("ring_block_south", "l_fingertip_g0_left"))
if self._scripted_bar_grasped:
contact_pairs.append(("bar_block", "r_fingertip_g0_right"))
return compute_air_insert_reward(contact_pairs, self.get_env_state()) return compute_air_insert_reward(contact_pairs, self.get_env_state())

View File

@@ -52,7 +52,7 @@ class DualDianaMed(MujocoEnv):
self.r_vis = None self.r_vis = None
self.l_vis = None self.l_vis = None
self.top = None self.top = None
self.angle = None self.left_side = None
self.front = None self.front = None
self.obs = None self.obs = None
@@ -166,7 +166,7 @@ class DualDianaMed(MujocoEnv):
obs['action'] = self.compute_qpos obs['action'] = self.compute_qpos
obs['images'] = dict() obs['images'] = dict()
obs['images']['top'] = self.top obs['images']['top'] = self.top
obs['images']['angle'] = self.angle obs['images']['left_side'] = self.left_side
obs['images']['r_vis'] = self.r_vis obs['images']['r_vis'] = self.r_vis
obs['images']['l_vis'] = self.l_vis obs['images']['l_vis'] = self.l_vis
obs['images']['front'] = self.front obs['images']['front'] = self.front
@@ -176,7 +176,7 @@ class DualDianaMed(MujocoEnv):
obs = collections.OrderedDict() obs = collections.OrderedDict()
obs['images'] = dict() obs['images'] = dict()
obs['images']['top'] = self.top obs['images']['top'] = self.top
obs['images']['angle'] = self.angle obs['images']['left_side'] = self.left_side
obs['images']['r_vis'] = self.r_vis obs['images']['r_vis'] = self.r_vis
obs['images']['l_vis'] = self.l_vis obs['images']['l_vis'] = self.l_vis
obs['images']['front'] = self.front obs['images']['front'] = self.front
@@ -199,8 +199,8 @@ class DualDianaMed(MujocoEnv):
def cam_view(self): def cam_view(self):
if self.cam == 'top': if self.cam == 'top':
return self.top return self.top
elif self.cam == 'angle': elif self.cam == 'left_side':
return self.angle return self.left_side
elif self.cam == 'r_vis': elif self.cam == 'r_vis':
return self.r_vis return self.r_vis
elif self.cam == 'l_vis': elif self.cam == 'l_vis':
@@ -226,9 +226,9 @@ class DualDianaMed(MujocoEnv):
img_renderer.update_scene(self.mj_data,camera="top") img_renderer.update_scene(self.mj_data,camera="top")
self.top = img_renderer.render() self.top = img_renderer.render()
self.top = self.top[:, :, ::-1] self.top = self.top[:, :, ::-1]
img_renderer.update_scene(self.mj_data,camera="angle") img_renderer.update_scene(self.mj_data,camera="left_side")
self.angle = img_renderer.render() self.left_side = img_renderer.render()
self.angle = self.angle[:, :, ::-1] self.left_side = self.left_side[:, :, ::-1]
img_renderer.update_scene(self.mj_data,camera="front") img_renderer.update_scene(self.mj_data,camera="front")
self.front = img_renderer.render() self.front = img_renderer.render()
self.front = self.front[:, :, ::-1] self.front = self.front[:, :, ::-1]

View File

@@ -34,19 +34,19 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
is_interpolate=is_interpolate, is_interpolate=is_interpolate,
cam_view=cam_view cam_view=cam_view
) )
self.max_reward = 4 self.max_reward = 4
self.cam_start() self.cam_start()
def step(self,action=np.zeros(16)): def step(self,action=np.zeros(16)):
action_left = self.ik_solve(action[:3],action[3:7],self.arm_left) action_left = self.ik_solve(action[:3],action[3:7],self.arm_left)
action_right = self.ik_solve(action[7:10],action[10:14],self.arm_right) action_right = self.ik_solve(action[7:10],action[10:14],self.arm_right)
action = np.hstack((action_left,action_right,action[14:])) action = np.hstack((action_left,action_right,action[14:]))
super().step(action) super().step(action)
self.rew = self._get_reward() self.rew = self._get_reward()
def step_jnt(self,action): def step_jnt(self,action):
super().step(action) super().step(action)
@@ -63,8 +63,8 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
return Arm.kdl_solver.ikSolver(p_goal, mat_goal, Arm.arm_qpos) return Arm.kdl_solver.ikSolver(p_goal, mat_goal, Arm.arm_qpos)
def reset(self,box_pos): def reset(self,box_pos):
self.mj_data.joint('red_box_joint').qpos[0] = box_pos[0] self.mj_data.joint('red_box_joint').qpos[0] = box_pos[0]
self.mj_data.joint('red_box_joint').qpos[1] = box_pos[1] self.mj_data.joint('red_box_joint').qpos[1] = box_pos[1]
self.mj_data.joint('red_box_joint').qpos[2] = box_pos[2] self.mj_data.joint('red_box_joint').qpos[2] = box_pos[2]
self.mj_data.joint('red_box_joint').qpos[3] = 1.0 self.mj_data.joint('red_box_joint').qpos[3] = 1.0
@@ -73,22 +73,22 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
self.mj_data.joint('red_box_joint').qpos[6] = 0.0 self.mj_data.joint('red_box_joint').qpos[6] = 0.0
super().reset() super().reset()
self.top = None self.top = None
self.angle = None self.left_side = None
self.r_vis = None self.r_vis = None
self.front = None self.front = None
self.cam_flage = True self.cam_flage = True
t=0 t=0
while self.cam_flage: while self.cam_flage:
if(type(self.top)==type(None) if(type(self.top)==type(None)
or type(self.angle)==type(None) or type(self.left_side)==type(None)
or type(self.r_vis)==type(None) or type(self.r_vis)==type(None)
or type(self.front)==type(None)): or type(self.front)==type(None)):
time.sleep(0.001) time.sleep(0.001)
t+=1 t+=1
else: else:
self.cam_flage=False self.cam_flage=False
def preStep(self, action): def preStep(self, action):
if isinstance(action,np.ndarray) and len(action)==16: if isinstance(action,np.ndarray) and len(action)==16:
@@ -101,7 +101,7 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
for i in range(3): for i in range(3):
box_pose[i] = cp.deepcopy(self.mj_data.joint('red_box_joint').qpos[i]) box_pose[i] = cp.deepcopy(self.mj_data.joint('red_box_joint').qpos[i])
return box_pose return box_pose
def _get_reward(self): def _get_reward(self):
all_contact_pairs = [] all_contact_pairs = []
@@ -124,26 +124,26 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
reward = 0 reward = 0
if touch_right_gripper and not touch_table: if touch_right_gripper and not touch_table:
reward = 1 reward = 1
if touch_right_gripper and not box_touch_table: if touch_right_gripper and not box_touch_table:
reward = 2 reward = 2
if touch_left_gripper: # attempted transfer if touch_left_gripper: # attempted transfer
reward = 3 reward = 3
if touch_left_gripper and not box_touch_table: # successful transfer if touch_left_gripper and not box_touch_table: # successful transfer
reward = 4 reward = 4
return reward return reward
def make_sim_env(task_name, headless=False): def make_sim_env(task_name, headless=False):
if task_name == 'sim_air_insert_ring_bar': if task_name == 'sim_air_insert_socket_peg':
from roboimi.assets.robots.diana_med import BiDianaMedRingBar from roboimi.assets.robots.diana_med import BiDianaMedSocketPeg
from roboimi.envs.double_air_insert_env import DualDianaMed_Air_Insert from roboimi.envs.double_air_insert_env import DualDianaMed_Air_Insert
env = DualDianaMed_Air_Insert( env = DualDianaMed_Air_Insert(
robot=BiDianaMedRingBar(), robot=BiDianaMedSocketPeg(),
is_render=not headless, is_render=not headless,
control_freq=30, control_freq=30,
is_interpolate=True, is_interpolate=True,
cam_view='angle' cam_view='left_side'
) )
return env return env
if 'sim_transfer' in task_name: if 'sim_transfer' in task_name:
@@ -153,7 +153,7 @@ def make_sim_env(task_name, headless=False):
is_render=not headless, is_render=not headless,
control_freq=30, control_freq=30,
is_interpolate=True, is_interpolate=True,
cam_view='angle' cam_view='left_side'
) )
return env return env
else: else:
@@ -179,4 +179,4 @@ if __name__ == "__main__":
env.step(action) env.step(action)
if env.is_render: if env.is_render:
env.render() env.render()

View File

@@ -39,19 +39,20 @@ def sample_transfer_pose():
return box_position return box_position
def sample_air_insert_ring_bar_state(): def sample_air_insert_socket_peg_state():
ring_position = np.random.uniform( socket_position = np.random.uniform(
low=np.array([-0.20, 0.70, 0.47], dtype=np.float32), low=np.array([-0.14, 0.89, 0.472], dtype=np.float32),
high=np.array([-0.05, 1.00, 0.47], dtype=np.float32), high=np.array([-0.10, 0.94, 0.472], dtype=np.float32),
) )
bar_position = np.random.uniform( peg_position = np.random.uniform(
low=np.array([0.05, 0.70, 0.47], dtype=np.float32), low=np.array([0.10, 0.85, 0.46], dtype=np.float32),
high=np.array([0.20, 1.00, 0.47], dtype=np.float32), high=np.array([0.16, 0.94, 0.46], dtype=np.float32),
) )
fixed_quat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) socket_quat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
peg_quat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
return { return {
"ring_pos": ring_position.astype(np.float32, copy=False), "socket_pos": socket_position.astype(np.float32, copy=False),
"ring_quat": fixed_quat.copy(), "socket_quat": socket_quat,
"bar_pos": bar_position.astype(np.float32, copy=False), "peg_pos": peg_position.astype(np.float32, copy=False),
"bar_quat": fixed_quat.copy(), "peg_quat": peg_quat,
} }

View File

@@ -23,10 +23,10 @@ SIM_TASK_CONFIGS = {
'camera_names': ['top','r_vis','front'], 'camera_names': ['top','r_vis','front'],
'xml_dir': HOME_PATH + '/assets' 'xml_dir': HOME_PATH + '/assets'
}, },
'sim_air_insert_ring_bar': { 'sim_air_insert_socket_peg': {
'dataset_dir': DATASET_DIR + '/sim_air_insert_ring_bar', 'dataset_dir': DATASET_DIR + '/sim_air_insert_socket_peg',
'num_episodes': 20, 'num_episodes': 20,
'episode_len': 700, 'episode_len': 1000,
'camera_names': ['top', 'r_vis', 'front'], 'camera_names': ['top', 'r_vis', 'front'],
'xml_dir': HOME_PATH + '/assets' 'xml_dir': HOME_PATH + '/assets'
}, },
@@ -59,13 +59,3 @@ PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) /
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x)) MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
MASTER_POS2JOINT = lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN((x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE))
PUPPET_POS2JOINT = lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN((x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE))
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE)/2

View File

@@ -1,6 +1,9 @@
import importlib import importlib
import inspect
import pathlib
import unittest import unittest
from unittest import mock from unittest import mock
import xml.etree.ElementTree as ET
import numpy as np import numpy as np
@@ -9,83 +12,80 @@ from roboimi.utils import act_ex_utils
from roboimi.utils.constants import SIM_TASK_CONFIGS from roboimi.utils.constants import SIM_TASK_CONFIGS
class AirInsertTaskRegistrationTest(unittest.TestCase): TASK_NAME = "sim_air_insert_socket_peg"
def test_sim_task_configs_registers_air_insert_ring_bar(self):
self.assertIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS)
def test_sample_air_insert_ring_bar_state_returns_explicit_named_mapping(self):
sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None) class AirInsertTaskRegistrationTest(unittest.TestCase):
def test_sim_task_configs_registers_air_insert_socket_peg(self):
self.assertIn(TASK_NAME, SIM_TASK_CONFIGS)
self.assertNotIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS)
self.assertGreaterEqual(SIM_TASK_CONFIGS[TASK_NAME]["episode_len"], 1000)
self.assertTrue(SIM_TASK_CONFIGS[TASK_NAME]["dataset_dir"].endswith("/sim_air_insert_socket_peg"))
def test_sample_air_insert_socket_peg_state_returns_explicit_named_mapping(self):
sampler = getattr(act_ex_utils, "sample_air_insert_socket_peg_state", None)
self.assertIsNotNone( self.assertIsNotNone(
sampler, sampler,
"Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()", "Expected roboimi.utils.act_ex_utils.sample_air_insert_socket_peg_state()",
)
self.assertFalse(
hasattr(act_ex_utils, "sample_air_insert_ring_bar_state"),
"air insert sampler should use socket/peg naming after the task rename",
) )
task_state = sampler() task_state = sampler()
self.assertEqual( self.assertEqual(
list(task_state.keys()), list(task_state.keys()),
["ring_pos", "ring_quat", "bar_pos", "bar_quat"], ["socket_pos", "socket_quat", "peg_pos", "peg_quat"],
) )
self.assertEqual(task_state["ring_pos"].shape, (3,)) self.assertEqual(task_state["socket_pos"].shape, (3,))
self.assertEqual(task_state["ring_quat"].shape, (4,)) self.assertEqual(task_state["socket_quat"].shape, (4,))
self.assertEqual(task_state["bar_pos"].shape, (3,)) self.assertEqual(task_state["peg_pos"].shape, (3,))
self.assertEqual(task_state["bar_quat"].shape, (4,)) self.assertEqual(task_state["peg_quat"].shape, (4,))
def test_sample_air_insert_ring_bar_state_uses_fixed_quats_and_left_right_planar_ranges(self): def test_sample_air_insert_socket_peg_state_uses_fixed_quats_and_left_right_planar_ranges(self):
sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None) sampler = getattr(act_ex_utils, "sample_air_insert_socket_peg_state", None)
self.assertIsNotNone( self.assertIsNotNone(sampler)
sampler,
"Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()",
)
task_state = sampler() task_state = sampler()
np.testing.assert_array_equal(task_state["ring_quat"], np.array([1.0, 0.0, 0.0, 0.0])) np.testing.assert_array_equal(task_state["socket_quat"], np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32))
np.testing.assert_array_equal(task_state["bar_quat"], np.array([1.0, 0.0, 0.0, 0.0])) np.testing.assert_array_equal(task_state["peg_quat"], np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32))
self.assertGreaterEqual(task_state["ring_pos"][0], -0.20) self.assertGreaterEqual(task_state["socket_pos"][0], -0.20)
self.assertLessEqual(task_state["ring_pos"][0], -0.05) self.assertLessEqual(task_state["socket_pos"][0], -0.05)
self.assertGreaterEqual(task_state["ring_pos"][1], 0.70) self.assertGreaterEqual(task_state["socket_pos"][1], 0.70)
self.assertLessEqual(task_state["ring_pos"][1], 1.00) self.assertLessEqual(task_state["socket_pos"][1], 1.00)
self.assertAlmostEqual(float(task_state["ring_pos"][2]), 0.47) self.assertAlmostEqual(float(task_state["socket_pos"][2]), 0.472)
self.assertGreaterEqual(task_state["bar_pos"][0], 0.05) self.assertGreaterEqual(task_state["peg_pos"][0], 0.05)
self.assertLessEqual(task_state["bar_pos"][0], 0.20) self.assertLessEqual(task_state["peg_pos"][0], 0.20)
self.assertGreaterEqual(task_state["bar_pos"][1], 0.70) self.assertGreaterEqual(task_state["peg_pos"][1], 0.70)
self.assertLessEqual(task_state["bar_pos"][1], 1.00) self.assertLessEqual(task_state["peg_pos"][1], 1.00)
self.assertAlmostEqual(float(task_state["bar_pos"][2]), 0.47) self.assertAlmostEqual(float(task_state["peg_pos"][2]), 0.46)
def test_make_sim_env_dispatches_air_insert_ring_bar_headless(self):
try:
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
except Exception as exc:
self.fail(f"Expected roboimi.envs.double_air_insert_env to be importable: {exc}")
def test_make_sim_env_dispatches_air_insert_socket_peg_headless(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
air_insert_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None) air_insert_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None)
self.assertIsNotNone( self.assertIsNotNone(air_insert_cls)
air_insert_cls,
"Expected roboimi.envs.double_air_insert_env.DualDianaMed_Air_Insert",
)
diana_med = importlib.import_module("roboimi.assets.robots.diana_med") diana_med = importlib.import_module("roboimi.assets.robots.diana_med")
ring_bar_robot_cls = getattr(diana_med, "BiDianaMedRingBar", None) socket_peg_robot_cls = getattr(diana_med, "BiDianaMedSocketPeg", None)
self.assertIsNotNone( self.assertIsNotNone(
ring_bar_robot_cls, socket_peg_robot_cls,
"Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar", "Expected roboimi.assets.robots.diana_med.BiDianaMedSocketPeg",
) )
fake_env = object() fake_env = object()
with mock.patch.object( with mock.patch.object(
diana_med, diana_med,
"BiDianaMedRingBar", "BiDianaMedSocketPeg",
return_value="robot", return_value="robot",
), mock.patch.object( ), mock.patch.object(
air_insert_env, air_insert_env,
"DualDianaMed_Air_Insert", "DualDianaMed_Air_Insert",
return_value=fake_env, return_value=fake_env,
) as env_cls: ) as env_cls:
try: env = make_sim_env(TASK_NAME, headless=True)
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
except Exception as exc:
self.fail(f"make_sim_env should dispatch sim_air_insert_ring_bar without error: {exc}")
self.assertIs(env, fake_env) self.assertIs(env, fake_env)
env_cls.assert_called_once_with( env_cls.assert_called_once_with(
@@ -93,21 +93,36 @@ class AirInsertTaskRegistrationTest(unittest.TestCase):
is_render=False, is_render=False,
control_freq=30, control_freq=30,
is_interpolate=True, is_interpolate=True,
cam_view="angle", cam_view="left_side",
) )
def test_diana_table_scene_uses_left_side_camera_instead_of_angle(self):
xml_path = (
pathlib.Path(__file__).resolve().parents[1]
/ "roboimi/assets/models/manipulators/DianaMed/table_square.xml"
)
root = ET.parse(xml_path).getroot()
cameras = {camera.attrib["name"]: camera.attrib for camera in root.findall(".//camera")}
self.assertNotIn("angle", cameras, "DianaMed scene should stop exposing the old angle camera")
self.assertIn("left_side", cameras, "DianaMed scene should expose the left-side task camera")
left_side_pos = np.fromstring(cameras["left_side"]["pos"], sep=" ")
self.assertLess(float(left_side_pos[0]), 0.0)
self.assertEqual(cameras["left_side"].get("mode"), "targetbody")
self.assertEqual(cameras["left_side"].get("target"), "table")
class AirInsertResetAndStateHelpersTest(unittest.TestCase): class AirInsertResetAndStateHelpersTest(unittest.TestCase):
def test_set_ring_bar_task_state_writes_free_joint_qpos(self): def test_set_socket_peg_task_state_writes_free_joint_qpos(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
setter = getattr(air_insert_env, "set_ring_bar_task_state", None) setter = getattr(air_insert_env, "set_socket_peg_task_state", None)
self.assertIsNotNone( self.assertIsNotNone(
setter, setter,
"Expected roboimi.envs.double_air_insert_env.set_ring_bar_task_state", "Expected roboimi.envs.double_air_insert_env.set_socket_peg_task_state",
) )
ring_qpos = np.zeros(7, dtype=np.float64) socket_qpos = np.zeros(7, dtype=np.float64)
bar_qpos = np.zeros(7, dtype=np.float64) peg_qpos = np.zeros(7, dtype=np.float64)
class _FakeJoint: class _FakeJoint:
def __init__(self, qpos): def __init__(self, qpos):
@@ -115,40 +130,40 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase):
class _FakeData: class _FakeData:
def joint(self, name): def joint(self, name):
if name == "ring_block_joint": if name == "blue_socket_joint":
return _FakeJoint(ring_qpos) return _FakeJoint(socket_qpos)
if name == "bar_block_joint": if name == "red_peg_joint":
return _FakeJoint(bar_qpos) return _FakeJoint(peg_qpos)
raise AssertionError(f"Unexpected joint name: {name}") raise AssertionError(f"Unexpected joint name: {name}")
task_state = { task_state = {
"ring_pos": np.array([-0.12, 0.90, 0.47], dtype=np.float64), "socket_pos": np.array([-0.12, 0.90, 0.472], dtype=np.float64),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64), "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
"bar_pos": np.array([0.12, 0.91, 0.47], dtype=np.float64), "peg_pos": np.array([0.12, 0.91, 0.46], dtype=np.float64),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64), "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
} }
setter(_FakeData(), task_state) setter(_FakeData(), task_state)
np.testing.assert_array_equal( np.testing.assert_array_equal(
ring_qpos, socket_qpos,
np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64), np.array([-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
) )
np.testing.assert_array_equal( np.testing.assert_array_equal(
bar_qpos, peg_qpos,
np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64), np.array([0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
) )
def test_get_ring_bar_env_state_returns_stable_14d_vector(self): def test_get_socket_peg_env_state_returns_stable_14d_vector(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
getter = getattr(air_insert_env, "get_ring_bar_env_state", None) getter = getattr(air_insert_env, "get_socket_peg_env_state", None)
self.assertIsNotNone( self.assertIsNotNone(
getter, getter,
"Expected roboimi.envs.double_air_insert_env.get_ring_bar_env_state", "Expected roboimi.envs.double_air_insert_env.get_socket_peg_env_state",
) )
ring_qpos = np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64) socket_qpos = np.array([-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
bar_qpos = np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64) peg_qpos = np.array([0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
class _FakeJoint: class _FakeJoint:
def __init__(self, qpos): def __init__(self, qpos):
@@ -156,10 +171,10 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase):
class _FakeData: class _FakeData:
def joint(self, name): def joint(self, name):
if name == "ring_block_joint": if name == "blue_socket_joint":
return _FakeJoint(ring_qpos) return _FakeJoint(socket_qpos)
if name == "bar_block_joint": if name == "red_peg_joint":
return _FakeJoint(bar_qpos) return _FakeJoint(peg_qpos)
raise AssertionError(f"Unexpected joint name: {name}") raise AssertionError(f"Unexpected joint name: {name}")
env_state = getter(_FakeData()) env_state = getter(_FakeData())
@@ -168,38 +183,78 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase):
np.testing.assert_array_equal( np.testing.assert_array_equal(
env_state, env_state,
np.array( np.array(
[-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], [-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0],
dtype=np.float64, dtype=np.float64,
), ),
) )
def test_air_insert_env_does_not_script_attach_or_assist_objects_after_reset(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
env_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None)
self.assertIsNotNone(env_cls)
source = inspect.getsource(env_cls)
self.assertNotIn("_update_scripted_grasped_objects", source)
self.assertNotIn("_scripted_", source)
self.assertNotIn("_stabilize_ring_grasp", source)
self.assertNotIn("_ring_grasp_locked", source)
get_reward_source = inspect.getsource(env_cls._get_reward)
self.assertNotIn("ring_block", get_reward_source)
self.assertNotIn("bar_block", get_reward_source)
def test_socket_peg_xml_defines_active_socket_and_peg_objects(self):
asset_dir = pathlib.Path(__file__).resolve().parents[1] / "roboimi/assets/models/manipulators/DianaMed"
xml_path = asset_dir / "socket_peg_objects.xml"
self.assertTrue(xml_path.exists(), "socket/peg objects should live in socket_peg_objects.xml")
self.assertFalse((asset_dir / "ring_bar_objects.xml").exists(), "old ring_bar_objects.xml should be renamed")
root = ET.parse(xml_path).getroot()
body_names = {body.attrib.get("name") for body in root.findall(".//body")}
geom_names = {geom.attrib.get("name") for geom in root.findall(".//geom")}
joint_names = {joint.attrib.get("name") for joint in root.findall(".//joint")}
self.assertIn("socket", body_names)
self.assertIn("peg", body_names)
self.assertNotIn("ring_block", body_names)
self.assertNotIn("bar_block", body_names)
self.assertIn("blue_socket_joint", joint_names)
self.assertIn("red_peg_joint", joint_names)
for geom_name in ("socket-1", "socket-2", "socket-3", "socket-4", "pin", "red_peg"):
self.assertIn(geom_name, geom_names)
def test_socket_peg_wrapper_includes_socket_peg_objects(self):
xml_path = (
pathlib.Path(__file__).resolve().parents[1]
/ "roboimi/assets/models/manipulators/DianaMed/bi_diana_socket_peg_ee.xml"
)
self.assertTrue(xml_path.exists(), "socket/peg wrapper XML should use the new task name")
root = ET.parse(xml_path).getroot()
includes = [include.attrib.get("file") for include in root.findall(".//include")]
self.assertIn("./socket_peg_objects.xml", includes)
self.assertNotIn("./ring_bar_objects.xml", includes)
class AirInsertRewardAndSuccessTest(unittest.TestCase): class AirInsertRewardAndSuccessTest(unittest.TestCase):
@staticmethod @staticmethod
def _make_env_state( def _make_env_state(
ring_pos=(0.0, 0.0, 0.50), socket_pos=(0.0, 0.0, 0.472),
ring_quat=(1.0, 0.0, 0.0, 0.0), socket_quat=(1.0, 0.0, 0.0, 0.0),
bar_pos=(0.0, 0.0, 0.50), peg_pos=(0.0, 0.0, 0.46),
bar_quat=(0.70710678, 0.0, 0.70710678, 0.0), peg_quat=(1.0, 0.0, 0.0, 0.0),
): ):
return np.array( return np.array([*socket_pos, *socket_quat, *peg_pos, *peg_quat], dtype=np.float64)
[*ring_pos, *ring_quat, *bar_pos, *bar_quat],
dtype=np.float64,
)
def test_compute_air_insert_reward_counts_left_contact_stage(self): def test_compute_air_insert_reward_counts_left_contact_stage(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None) reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
self.assertIsNotNone( self.assertIsNotNone(reward_fn)
reward_fn,
"Expected roboimi.envs.double_air_insert_env.compute_air_insert_reward",
)
reward = reward_fn( reward = reward_fn(
contact_pairs=[ contact_pairs=[
("ring_block_north", "l_finger_left"), ("socket-1", "l_finger_left"),
("ring_block_north", "table"), ("socket-1", "table"),
("bar_block", "table"), ("red_peg", "table"),
], ],
env_state=self._make_env_state(), env_state=self._make_env_state(),
) )
@@ -212,10 +267,10 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
reward = reward_fn( reward = reward_fn(
contact_pairs=[ contact_pairs=[
("ring_block_north", "l_finger_left"), ("socket-1", "l_finger_left"),
("bar_block", "l_finger_right"), ("red_peg", "l_finger_right"),
("ring_block_north", "table"), ("socket-1", "table"),
("bar_block", "table"), ("red_peg", "table"),
], ],
env_state=self._make_env_state(), env_state=self._make_env_state(),
) )
@@ -228,47 +283,43 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
reward = reward_fn( reward = reward_fn(
contact_pairs=[ contact_pairs=[
("ring_block_north", "l_finger_left"), ("socket-1", "l_finger_left"),
("bar_block", "l_finger_right"), ("red_peg", "l_finger_right"),
], ],
env_state=self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)), env_state=self._make_env_state(),
) )
self.assertEqual(reward, 4) self.assertEqual(reward, 4)
def test_bar_fully_inserted_through_ring_accepts_true_positive(self): def test_compute_air_insert_reward_counts_visual_fingertip_contacts_as_gripper_contacts(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None) reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
reward = reward_fn(
contact_pairs=[
("socket-3", "r_fingertip_g0_vis_left"),
("red_peg", "l_fingertip_g0_vis_right"),
],
env_state=self._make_env_state(),
)
self.assertEqual(
reward,
4,
"visual fingertip geoms are collidable in the Diana XML and should count as gripper-object contacts",
)
def test_peg_inserted_into_socket_uses_pin_contact(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
success_fn = getattr(air_insert_env, "peg_inserted_into_socket", None)
self.assertIsNotNone( self.assertIsNotNone(
success_fn, success_fn,
"Expected roboimi.envs.double_air_insert_env.bar_fully_inserted_through_ring", "Expected roboimi.envs.double_air_insert_env.peg_inserted_into_socket",
) )
self.assertTrue( self.assertTrue(success_fn([("red_peg", "pin")]))
success_fn( self.assertTrue(success_fn([("pin", "red_peg")]))
self._make_env_state(), self.assertFalse(success_fn([("red_peg", "socket-1")]))
)
)
def test_bar_fully_inserted_through_ring_rejects_centerline_only_false_positive(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
self.assertFalse(
success_fn(
self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)),
)
)
def test_bar_fully_inserted_through_ring_rejects_insufficient_depth(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
self.assertFalse(
success_fn(
self._make_env_state(bar_pos=(0.0, 0.0, 0.56)),
)
)
def test_compute_air_insert_reward_requires_airborne_success_for_final_point(self): def test_compute_air_insert_reward_requires_airborne_success_for_final_point(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
@@ -276,9 +327,10 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
reward = reward_fn( reward = reward_fn(
contact_pairs=[ contact_pairs=[
("ring_block_north", "l_finger_left"), ("socket-1", "l_finger_left"),
("bar_block", "l_finger_right"), ("red_peg", "l_finger_right"),
("ring_block_north", "table"), ("socket-1", "table"),
("red_peg", "pin"),
], ],
env_state=self._make_env_state(), env_state=self._make_env_state(),
) )
@@ -291,8 +343,9 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
reward = reward_fn( reward = reward_fn(
contact_pairs=[ contact_pairs=[
("ring_block_north", "l_finger_left"), ("socket-1", "l_finger_left"),
("bar_block", "l_finger_right"), ("red_peg", "l_finger_right"),
("red_peg", "pin"),
], ],
env_state=self._make_env_state(), env_state=self._make_env_state(),
) )
@@ -301,41 +354,129 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
class AirInsertPolicyAndSmokeTest(unittest.TestCase): class AirInsertPolicyAndSmokeTest(unittest.TestCase):
@staticmethod
def _canonical_task_state():
return {
"socket_pos": np.array([-0.12, 0.90, 0.472], dtype=np.float32),
"socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"peg_pos": np.array([0.12, 0.90, 0.46], dtype=np.float32),
"peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
def test_air_insert_policy_emits_valid_16d_action(self): def test_air_insert_policy_emits_valid_16d_action(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy") policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None) policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone( self.assertIsNotNone(policy_cls)
policy_cls,
"Expected roboimi.demos.diana_air_insert_policy.TestAirInsertPolicy",
)
task_state = act_ex_utils.sample_air_insert_ring_bar_state() task_state = act_ex_utils.sample_air_insert_socket_peg_state()
policy = policy_cls(inject_noise=False) policy = policy_cls(inject_noise=False)
action = policy.predict(task_state, 0) action = policy.predict(task_state, 0)
self.assertEqual(action.shape, (16,)) self.assertEqual(action.shape, (16,))
np.testing.assert_array_equal(action[-2:], np.array([100, 100])) np.testing.assert_array_equal(action[-2:], np.array([100, 100]))
def test_scripted_rollout_entrypoint_selects_ring_bar_sampler_and_policy(self): def test_air_insert_policy_inserts_peg_front_view_right_to_left_along_world_x(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = self._canonical_task_state()
policy = policy_cls(inject_noise=False)
policy.generate_trajectory(task_state)
start_waypoint = next(wp for wp in policy.right_trajectory if wp["t"] == policy.INSERT_START_T)
end_waypoint = next(wp for wp in policy.right_trajectory if wp["t"] == policy.INSERT_END_T)
self.assertLess(
end_waypoint["xyz"][0],
start_waypoint["xyz"][0] - 0.10,
"front-view right-to-left peg insertion should decrease world x substantially",
)
self.assertAlmostEqual(float(end_waypoint["xyz"][1]), float(start_waypoint["xyz"][1]), delta=0.02)
expected_insert_end_x = float(task_state["socket_pos"][0] + 0.168)
self.assertAlmostEqual(float(end_waypoint["xyz"][0]), expected_insert_end_x, delta=0.02)
self.assertGreater(float(start_waypoint["xyz"][2]), 0.70)
def test_air_insert_policy_default_left_grasps_socket_and_right_grasps_peg(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = {
"socket_pos": np.array([-0.18, 0.78, 0.472], dtype=np.float32),
"socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"peg_pos": np.array([0.16, 0.98, 0.46], dtype=np.float32),
"peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
policy = policy_cls(inject_noise=False)
policy.generate_trajectory(task_state)
left_close = next(wp for wp in policy.left_trajectory if wp["t"] == 180)
right_close = next(wp for wp in policy.right_trajectory if wp["t"] == 180)
action_z_offset = getattr(policy_cls, "ACTION_OBJECT_Z_OFFSET", 0.11)
expected_socket_pick = task_state["socket_pos"] + np.array([-0.078, 0.0, action_z_offset])
expected_peg_pick = task_state["peg_pos"] + np.array([0.078, 0.0, action_z_offset + 0.01])
np.testing.assert_allclose(left_close["xyz"], expected_socket_pick, atol=1e-6)
np.testing.assert_allclose(right_close["xyz"], expected_peg_pick, atol=1e-6)
self.assertLess(left_close["gripper"], 0, "default policy should close the left gripper on the socket")
self.assertLess(right_close["gripper"], 0, "default policy should close the right gripper on the peg")
def test_air_insert_policy_socket_hold_tracks_socket_xy_without_sweeping_laterally(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
base_state = {
"socket_pos": np.array([-0.20, 0.72, 0.472], dtype=np.float32),
"socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"peg_pos": np.array([0.14, 0.76, 0.46], dtype=np.float32),
"peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
shifted_state = dict(base_state)
shifted_state["socket_pos"] = np.array([-0.06, 0.99, 0.472], dtype=np.float32)
base_policy = policy_cls(inject_noise=False)
base_policy.generate_trajectory(base_state)
shifted_policy = policy_cls(inject_noise=False)
shifted_policy.generate_trajectory(shifted_state)
base_hold = next(wp for wp in base_policy.left_trajectory if wp["t"] == 450)
shifted_hold = next(wp for wp in shifted_policy.left_trajectory if wp["t"] == 450)
np.testing.assert_allclose(
base_hold["xyz"][:2],
base_state["socket_pos"][:2] + np.array([-0.078, 0.0]),
atol=1e-6,
)
np.testing.assert_allclose(
shifted_hold["xyz"][:2],
shifted_state["socket_pos"][:2] + np.array([-0.078, 0.0]),
atol=1e-6,
)
def test_air_insert_policy_predicts_through_full_episode_without_exhausting_waypoints(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = self._canonical_task_state()
policy = policy_cls(inject_noise=False)
for step in range(SIM_TASK_CONFIGS[TASK_NAME]["episode_len"]):
action = policy.predict(task_state, step)
self.assertEqual(action.shape, (16,))
def test_scripted_rollout_entrypoint_selects_socket_peg_sampler_and_policy(self):
rollout_module = importlib.import_module("roboimi.demos.diana_record_sim_episodes") rollout_module = importlib.import_module("roboimi.demos.diana_record_sim_episodes")
sampler_fn = getattr(rollout_module, "sample_task_state", None) sampler_fn = getattr(rollout_module, "sample_task_state", None)
policy_factory = getattr(rollout_module, "make_policy", None) policy_factory = getattr(rollout_module, "make_policy", None)
self.assertIsNotNone( self.assertIsNotNone(sampler_fn)
sampler_fn, self.assertIsNotNone(policy_factory)
"Expected roboimi.demos.diana_record_sim_episodes.sample_task_state",
)
self.assertIsNotNone(
policy_factory,
"Expected roboimi.demos.diana_record_sim_episodes.make_policy",
)
task_state = sampler_fn("sim_air_insert_ring_bar") task_state = sampler_fn(TASK_NAME)
self.assertEqual( self.assertEqual(list(task_state.keys()), ["socket_pos", "socket_quat", "peg_pos", "peg_quat"])
list(task_state.keys()),
["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
)
policy = policy_factory("sim_air_insert_ring_bar", inject_noise=False) policy = policy_factory(TASK_NAME, inject_noise=False)
self.assertEqual(policy.__class__.__name__, "TestAirInsertPolicy") self.assertEqual(policy.__class__.__name__, "TestAirInsertPolicy")
def test_real_headless_smoke_instantiates_resets_and_steps_new_task_once(self): def test_real_headless_smoke_instantiates_resets_and_steps_new_task_once(self):
@@ -343,8 +484,8 @@ class AirInsertPolicyAndSmokeTest(unittest.TestCase):
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None) policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls) self.assertIsNotNone(policy_cls)
task_state = act_ex_utils.sample_air_insert_ring_bar_state() task_state = act_ex_utils.sample_air_insert_socket_peg_state()
env = make_sim_env("sim_air_insert_ring_bar", headless=True) env = make_sim_env(TASK_NAME, headless=True)
policy = policy_cls(inject_noise=False) policy = policy_cls(inject_noise=False)
try: try:
@@ -363,115 +504,6 @@ class AirInsertPolicyAndSmokeTest(unittest.TestCase):
if viewer is not None: if viewer is not None:
viewer.close() viewer.close()
def test_scripted_policy_avoids_cross_arm_contact_on_canonical_insert_case(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = {
"ring_pos": np.array([-0.06658807, 0.93985176, 0.47], dtype=np.float32),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"bar_pos": np.array([0.12421221, 0.77605027, 0.47], dtype=np.float32),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
policy = policy_cls(inject_noise=False)
def is_cross_arm_pair(a, b):
return ("_left" in a and "_right" in b) or ("_right" in a and "_left" in b)
try:
env.reset(task_state)
for step in range(460):
action = policy.predict(task_state, step)
env.step(action)
pairs = []
for i in range(env.mj_data.ncon):
geom1 = env.getID2Name("geom", env.mj_data.contact[i].geom1)
geom2 = env.getID2Name("geom", env.mj_data.contact[i].geom2)
if geom1 and geom2 and is_cross_arm_pair(geom1, geom2):
pairs.append((geom1, geom2))
self.assertFalse(
pairs,
f"cross-arm contact detected at step {step}: {pairs[:5]}",
)
finally:
env.exit_flag = True
cam_thread = getattr(env, "cam_thread", None)
if cam_thread is not None:
cam_thread.join(timeout=1.0)
viewer = getattr(env, "viewer", None)
if viewer is not None:
viewer.close()
def test_scripted_policy_keeps_ring_airborne_through_hold_phase_on_canonical_case(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = {
"ring_pos": np.array([-0.11884121, 0.800019, 0.47], dtype=np.float32),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"bar_pos": np.array([0.12783867, 0.73399246, 0.47], dtype=np.float32),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
policy = policy_cls(inject_noise=False)
try:
env.reset(task_state)
for step in range(400):
action = policy.predict(task_state, step)
env.step(action)
ring_z = float(env.get_env_state()[2])
self.assertGreater(
ring_z,
0.55,
f"ring dropped before hold phase completed, final z={ring_z:.4f}",
)
finally:
env.exit_flag = True
cam_thread = getattr(env, "cam_thread", None)
if cam_thread is not None:
cam_thread.join(timeout=1.0)
viewer = getattr(env, "viewer", None)
if viewer is not None:
viewer.close()
def test_scripted_policy_reaches_max_reward_on_canonical_case(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = {
"ring_pos": np.array([-0.11884121, 0.800019, 0.47], dtype=np.float32),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"bar_pos": np.array([0.12783867, 0.73399246, 0.47], dtype=np.float32),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
policy = policy_cls(inject_noise=False)
max_reward = float("-inf")
try:
env.reset(task_state)
for step in range(700):
action = policy.predict(task_state, step)
env.step(action)
max_reward = max(max_reward, float(env.rew))
self.assertEqual(max_reward, 5.0, f"expected canonical rollout to reach reward 5, got {max_reward}")
finally:
env.exit_flag = True
cam_thread = getattr(env, "cam_thread", None)
if cam_thread is not None:
cam_thread.join(timeout=1.0)
viewer = getattr(env, "viewer", None)
if viewer is not None:
viewer.close()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -114,7 +114,7 @@ class EvalVLAHeadlessTest(unittest.TestCase):
is_render=False, is_render=False,
control_freq=30, control_freq=30,
is_interpolate=True, is_interpolate=True,
cam_view="angle", cam_view="left_side",
) )
def test_camera_viewer_headless_updates_images_without_gui_calls(self): def test_camera_viewer_headless_updates_images_without_gui_calls(self):
@@ -123,11 +123,11 @@ class EvalVLAHeadlessTest(unittest.TestCase):
env.mj_data = object() env.mj_data = object()
env.exit_flag = False env.exit_flag = False
env.is_render = False env.is_render = False
env.cam = "angle" env.cam = "left_side"
env.r_vis = None env.r_vis = None
env.l_vis = None env.l_vis = None
env.top = None env.top = None
env.angle = None env.left_side = None
env.front = None env.front = None
with mock.patch( with mock.patch(
@@ -144,7 +144,7 @@ class EvalVLAHeadlessTest(unittest.TestCase):
self.assertIsNotNone(env.r_vis) self.assertIsNotNone(env.r_vis)
self.assertIsNotNone(env.l_vis) self.assertIsNotNone(env.l_vis)
self.assertIsNotNone(env.top) self.assertIsNotNone(env.top)
self.assertIsNotNone(env.angle) self.assertIsNotNone(env.left_side)
self.assertIsNotNone(env.front) self.assertIsNotNone(env.front)
def test_eval_main_headless_skips_render_and_still_executes_policy(self): def test_eval_main_headless_skips_render_and_still_executes_policy(self):
@@ -254,19 +254,19 @@ class EvalVLAHeadlessTest(unittest.TestCase):
self.assertAlmostEqual(summary["avg_reward"], 3.75) self.assertAlmostEqual(summary["avg_reward"], 3.75)
self.assertEqual(summary["num_episodes"], 2) self.assertEqual(summary["num_episodes"], 2)
def test_run_eval_uses_air_insert_sampler_for_ring_bar_task(self): def test_run_eval_uses_air_insert_sampler_for_socket_peg_task(self):
self.assertTrue( self.assertTrue(
hasattr(eval_vla, "sample_air_insert_ring_bar_state"), hasattr(eval_vla, "sample_air_insert_socket_peg_state"),
"Expected eval_vla to expose the new ring/bar reset sampler", "Expected eval_vla to expose the new socket/peg reset sampler",
) )
fake_env = _FakeEnv() fake_env = _FakeEnv()
fake_agent = _FakeAgent() fake_agent = _FakeAgent()
sampled_task_state = { sampled_task_state = {
"ring_pos": np.array([-0.10, 0.80, 0.47], dtype=np.float32), "socket_pos": np.array([-0.10, 0.80, 0.47], dtype=np.float32),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"bar_pos": np.array([0.10, 0.82, 0.47], dtype=np.float32), "peg_pos": np.array([0.10, 0.82, 0.47], dtype=np.float32),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
} }
cfg = OmegaConf.create( cfg = OmegaConf.create(
{ {
@@ -276,7 +276,7 @@ class EvalVLAHeadlessTest(unittest.TestCase):
"num_episodes": 1, "num_episodes": 1,
"max_timesteps": 1, "max_timesteps": 1,
"device": "cpu", "device": "cpu",
"task_name": "sim_air_insert_ring_bar", "task_name": "sim_air_insert_socket_peg",
"camera_names": ["front"], "camera_names": ["front"],
"use_smoothing": False, "use_smoothing": False,
"smooth_alpha": 0.3, "smooth_alpha": 0.3,
@@ -296,12 +296,12 @@ class EvalVLAHeadlessTest(unittest.TestCase):
return_value=fake_env, return_value=fake_env,
) as make_env, mock.patch.object( ) as make_env, mock.patch.object(
eval_vla, eval_vla,
"sample_air_insert_ring_bar_state", "sample_air_insert_socket_peg_state",
return_value=sampled_task_state, return_value=sampled_task_state,
) as ring_bar_sampler, mock.patch.object( ) as socket_peg_sampler, mock.patch.object(
eval_vla, eval_vla,
"sample_transfer_pose", "sample_transfer_pose",
side_effect=AssertionError("sample_transfer_pose should not be used for sim_air_insert_ring_bar"), side_effect=AssertionError("sample_transfer_pose should not be used for sim_air_insert_socket_peg"),
), mock.patch.object( ), mock.patch.object(
eval_vla, eval_vla,
"execute_policy_action", "execute_policy_action",
@@ -312,8 +312,8 @@ class EvalVLAHeadlessTest(unittest.TestCase):
): ):
eval_vla._run_eval(cfg) eval_vla._run_eval(cfg)
make_env.assert_called_once_with("sim_air_insert_ring_bar", headless=True) make_env.assert_called_once_with("sim_air_insert_socket_peg", headless=True)
ring_bar_sampler.assert_called_once_with() socket_peg_sampler.assert_called_once_with()
execute_policy_action.assert_called_once() execute_policy_action.assert_called_once()
self.assertEqual(fake_env.reset_calls, [sampled_task_state]) self.assertEqual(fake_env.reset_calls, [sampled_task_state])

View File

@@ -59,15 +59,15 @@ class RobotAssetPathResolutionTest(unittest.TestCase):
self.assertEqual({Path(path) for path in _FakeKDL.init_calls}, {expected_urdf}) self.assertEqual({Path(path) for path in _FakeKDL.init_calls}, {expected_urdf})
self.assertTrue(all(Path(path).is_absolute() for path in _FakeKDL.init_calls)) self.assertTrue(all(Path(path).is_absolute() for path in _FakeKDL.init_calls))
def test_bidianamed_ring_bar_resolves_robot_asset_paths_independent_of_cwd(self): def test_bidianamed_socket_peg_resolves_robot_asset_paths_independent_of_cwd(self):
BiDianaMedRingBar = getattr(diana_med, 'BiDianaMedRingBar', None) BiDianaMedSocketPeg = getattr(diana_med, 'BiDianaMedSocketPeg', None)
self.assertIsNotNone( self.assertIsNotNone(
BiDianaMedRingBar, BiDianaMedSocketPeg,
'Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar', 'Expected roboimi.assets.robots.diana_med.BiDianaMedSocketPeg',
) )
repo_root = Path(__file__).resolve().parents[1] repo_root = Path(__file__).resolve().parents[1]
expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml' expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_socket_peg_ee.xml'
expected_urdf = repo_root / 'roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf' expected_urdf = repo_root / 'roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf'
xml_calls = [] xml_calls = []
@@ -89,7 +89,7 @@ class RobotAssetPathResolutionTest(unittest.TestCase):
'roboimi.assets.robots.arm_base.KDL_utils', 'roboimi.assets.robots.arm_base.KDL_utils',
_FakeKDL, _FakeKDL,
): ):
BiDianaMedRingBar() BiDianaMedSocketPeg()
finally: finally:
os.chdir(previous_cwd) os.chdir(previous_cwd)