From 5c5cb299e975232032d97a35be209bd0e05c22f3 Mon Sep 17 00:00:00 2001 From: Logic Date: Sat, 2 May 2026 17:34:43 +0800 Subject: [PATCH] feat(sim): switch air insert task to socket peg --- ..._bar_ee.xml => bi_diana_socket_peg_ee.xml} | 4 +- .../DianaMed/ring_bar_objects.xml | 28 - .../DianaMed/socket_peg_objects.xml | 19 + .../manipulators/DianaMed/table_square.xml | 2 +- roboimi/assets/robots/diana_med.py | 6 +- roboimi/demos/diana_air_insert_policy.py | 217 +++++-- roboimi/demos/diana_record_sim_episodes.py | 18 +- roboimi/demos/vla_scripts/eval_vla.py | 6 +- roboimi/envs/double_air_insert_env.py | 227 ++----- roboimi/envs/double_base.py | 16 +- roboimi/envs/double_pos_ctrl_env.py | 40 +- roboimi/utils/act_ex_utils.py | 25 +- roboimi/utils/constants.py | 16 +- tests/test_air_insert_env.py | 554 +++++++++--------- tests/test_eval_vla_headless.py | 34 +- tests/test_robot_asset_paths.py | 12 +- 16 files changed, 594 insertions(+), 630 deletions(-) rename roboimi/assets/models/manipulators/DianaMed/{bi_diana_ring_bar_ee.xml => bi_diana_socket_peg_ee.xml} (62%) delete mode 100644 roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml create mode 100644 roboimi/assets/models/manipulators/DianaMed/socket_peg_objects.xml diff --git a/roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml b/roboimi/assets/models/manipulators/DianaMed/bi_diana_socket_peg_ee.xml similarity index 62% rename from roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml rename to roboimi/assets/models/manipulators/DianaMed/bi_diana_socket_peg_ee.xml index 38c21f8..e532054 100644 --- a/roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml +++ b/roboimi/assets/models/manipulators/DianaMed/bi_diana_socket_peg_ee.xml @@ -1,6 +1,6 @@ - + - + diff --git a/roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml b/roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml deleted file mode 100644 index 196ea02..0000000 --- a/roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml +++ /dev/null @@ -1,28 +0,0 @@ - - - - - - - - - - - - - - - - - - diff --git a/roboimi/assets/models/manipulators/DianaMed/socket_peg_objects.xml b/roboimi/assets/models/manipulators/DianaMed/socket_peg_objects.xml new file mode 100644 index 0000000..642bd78 --- /dev/null +++ b/roboimi/assets/models/manipulators/DianaMed/socket_peg_objects.xml @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + + + diff --git a/roboimi/assets/models/manipulators/DianaMed/table_square.xml b/roboimi/assets/models/manipulators/DianaMed/table_square.xml index 9d36f5b..d1127d0 100644 --- a/roboimi/assets/models/manipulators/DianaMed/table_square.xml +++ b/roboimi/assets/models/manipulators/DianaMed/table_square.xml @@ -7,7 +7,7 @@ - + diff --git a/roboimi/assets/robots/diana_med.py b/roboimi/assets/robots/diana_med.py index 04ff249..691837e 100644 --- a/roboimi/assets/robots/diana_med.py +++ b/roboimi/assets/robots/diana_med.py @@ -92,12 +92,12 @@ class BiDianaMed(ArmBase): return np.array([0.0, 0.0, 0.0, 1.57, 0.0, 0.0, 0.0]) -class BiDianaMedRingBar(ArmBase): +class BiDianaMedSocketPeg(ArmBase): def __init__(self): super().__init__( - name="Bidiana_ring_bar", + name="Bidiana_socket_peg", urdf_path="roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf", - xml_path="roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml", + xml_path="roboimi/assets/models/manipulators/DianaMed/bi_diana_socket_peg_ee.xml", gripper=None ) self.left_arm = self.Arm(self, 'single', self.urdf_path) diff --git a/roboimi/demos/diana_air_insert_policy.py b/roboimi/demos/diana_air_insert_policy.py index 30511bb..9d72f46 100644 --- a/roboimi/demos/diana_air_insert_policy.py +++ b/roboimi/demos/diana_air_insert_policy.py @@ -5,16 +5,39 @@ from roboimi.demos.diana_policy import PolicyBase class TestAirInsertPolicy(PolicyBase): - @staticmethod - def _action_xyz_for_object_center(object_center, ee_quat, object_offset_local): - return ( - np.asarray(object_center, dtype=np.float64) - - np.asarray(Quaternion(ee_quat).rotate(object_offset_local), dtype=np.float64) - ) + ACTION_OBJECT_Z_OFFSET = 0.078 + SOCKET_GRASP_OFFSET = np.array([0.0, 0.0, 0.0], dtype=np.float64) + PEG_GRASP_OFFSET = np.array([0.0, 0.0, 0.0], dtype=np.float64) + SOCKET_OUTER_GRASP_STRATEGY = "socket_outer" + LEGACY_GRASP_STRATEGY = "legacy" + SOCKET_HOLD_Z = 0.85 + PEG_INSERT_START_OFFSET = np.array([0.105, 0.0, 0.0], dtype=np.float64) + INSERT_START_T = 650 + INSERT_END_T = 700 + LEFT_SOCKET_GRIPPER_CLOSED = -70 + RIGHT_PEG_GRIPPER_CLOSED = -100 + SOCKET_APPROACH_Z = 1.05 + EPISODE_END_T = 1000 + + def __init__(self, inject_noise=False, grasp_strategy=SOCKET_OUTER_GRASP_STRATEGY): + super().__init__(inject_noise=inject_noise) + valid_strategies = { + self.SOCKET_OUTER_GRASP_STRATEGY, + self.LEGACY_GRASP_STRATEGY, + } + if grasp_strategy not in valid_strategies: + raise ValueError( + f"Unsupported air insert grasp_strategy={grasp_strategy!r}; " + f"expected one of {sorted(valid_strategies)}" + ) + self.grasp_strategy = grasp_strategy def generate_trajectory(self, task_state): - ring_xyz = np.asarray(task_state["ring_pos"], dtype=np.float64) - bar_xyz = np.asarray(task_state["bar_pos"], dtype=np.float64) + return self._generate_socket_peg_trajectory(task_state) + + def _generate_socket_peg_trajectory(self, task_state): + socket_xyz = np.asarray(task_state["socket_pos"], dtype=np.float64) + peg_xyz = np.asarray(task_state["peg_pos"], dtype=np.float64) init_mocap_pose_left = np.array( [ @@ -44,63 +67,137 @@ class TestAirInsertPolicy(PolicyBase): left_init_quat = Quaternion(init_mocap_pose_left[3:]) right_init_quat = Quaternion(init_mocap_pose_right[3:]) - object_offset_local = np.array([0.0, 0.0, -0.09], dtype=np.float64) - left_pick_quat = (left_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements - left_hold_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=-90).elements - right_pick_quat = (right_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements - insert_quat_local = Quaternion([-0.50019721, 0.50020088, 0.49980484, 0.49979692]) - right_insert_quat = np.array( - (Quaternion(left_hold_quat) * insert_quat_local).elements, + left_pick_quat = ( + left_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=45) + ).elements + right_pick_quat = ( + right_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=45) + ).elements + + socket_hold_action = np.array( + [socket_xyz[0] - 0.078, socket_xyz[1], self.SOCKET_HOLD_Z], dtype=np.float64 + ) + + peg_init_xyz = peg_xyz + np.array( + [0.078, 0.0, self.ACTION_OBJECT_Z_OFFSET + 0.01] + ) + peg_lift_center = np.array( + [peg_xyz[0] + 0.078, socket_hold_action[1], self.SOCKET_HOLD_Z - 0.01], + dtype=np.float64, + ) + # The front camera looks along +Y, so visual right-to-left insertion is + # world +X -> -X. With the socket XML in identity orientation, its + # tunnel axis is local/world X, so the peg approaches from +X and stops + # when its leading face reaches the socket's internal pin. + peg_insert_end_center = np.array( + [ + socket_hold_action[0] + 0.078 * 2 + 0.04 + 0.06 - 0.01, + socket_hold_action[1], + self.SOCKET_HOLD_Z - 0.01, + ], dtype=np.float64, ) - meet_xyz = np.array([0.0, 1.0, 1.30], dtype=np.float64) - ring_stabilize_center = ring_xyz + np.array([0.0, 0.0, 0.30], dtype=np.float64) - ring_hold_center = meet_xyz + np.array([-0.10, 0.05, -0.16], dtype=np.float64) - bar_reorient_center = bar_xyz + np.array([0.0, 0.0, 0.16], dtype=np.float64) - bar_wait_center = ring_hold_center + np.array([0.05, -0.18, 0.0], dtype=np.float64) - bar_insert_start_center = ring_hold_center + np.array([0.0, -0.075, 0.0], dtype=np.float64) - bar_insert_end_center = ring_hold_center + np.array([0.0, 0.075, 0.0], dtype=np.float64) - - left_stabilize_xyz = self._action_xyz_for_object_center( - ring_stabilize_center, left_pick_quat, object_offset_local - ) - left_hold_xyz = self._action_xyz_for_object_center( - ring_hold_center, left_hold_quat, object_offset_local - ) - right_reorient_xyz = self._action_xyz_for_object_center( - bar_reorient_center, right_insert_quat, object_offset_local - ) - right_wait_xyz = self._action_xyz_for_object_center( - bar_wait_center, right_insert_quat, object_offset_local - ) - right_insert_start_xyz = self._action_xyz_for_object_center( - bar_insert_start_center, right_insert_quat, object_offset_local - ) - right_insert_end_xyz = self._action_xyz_for_object_center( - bar_insert_end_center, right_insert_quat, object_offset_local - ) - self.left_trajectory = [ - {"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 100}, - {"t": 80, "xyz": ring_xyz + np.array([0.0, 0.0, 0.22]), "quat": left_pick_quat, "gripper": 100}, - {"t": 150, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": 100}, - {"t": 180, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": -100}, - {"t": 260, "xyz": self._action_xyz_for_object_center(ring_xyz + np.array([0.0, 0.0, 0.24]), left_pick_quat, object_offset_local), "quat": left_pick_quat, "gripper": -100}, - {"t": 340, "xyz": left_stabilize_xyz, "quat": left_pick_quat, "gripper": -100}, - {"t": 460, "xyz": left_hold_xyz, "quat": left_hold_quat, "gripper": -100}, - {"t": 700, "xyz": left_hold_xyz, "quat": left_hold_quat, "gripper": -100}, + { + "t": 1, + "xyz": init_mocap_pose_left[:3], + "quat": init_mocap_pose_left[3:], + "gripper": 100, + }, + { + "t": 130, + "xyz": socket_xyz + + np.array([-0.078, 0.0, self.ACTION_OBJECT_Z_OFFSET]), + "quat": left_pick_quat, + "gripper": 100, + }, + { + "t": 180, + "xyz": socket_xyz + + np.array([-0.078, 0.0, self.ACTION_OBJECT_Z_OFFSET]), + "quat": left_pick_quat, + "gripper": self.LEFT_SOCKET_GRIPPER_CLOSED, + }, + { + "t": 450, + "xyz": socket_hold_action, + "quat": left_pick_quat, + "gripper": self.LEFT_SOCKET_GRIPPER_CLOSED, + }, + { + "t": 750, + "xyz": socket_hold_action, + "quat": left_pick_quat, + "gripper": self.LEFT_SOCKET_GRIPPER_CLOSED, + }, + { + "t": self.EPISODE_END_T, + "xyz": socket_hold_action, + "quat": left_pick_quat, + "gripper": self.LEFT_SOCKET_GRIPPER_CLOSED, + }, ] self.right_trajectory = [ - {"t": 1, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 100}, - {"t": 80, "xyz": bar_xyz + np.array([0.0, 0.0, 0.22]), "quat": right_pick_quat, "gripper": 100}, - {"t": 150, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": 100}, - {"t": 180, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": -100}, - {"t": 240, "xyz": self._action_xyz_for_object_center(bar_xyz + np.array([0.0, 0.0, 0.12]), right_pick_quat, object_offset_local), "quat": right_pick_quat, "gripper": -100}, - {"t": 320, "xyz": right_reorient_xyz, "quat": right_insert_quat, "gripper": -100}, - {"t": 460, "xyz": right_wait_xyz, "quat": right_insert_quat, "gripper": -100}, - {"t": 600, "xyz": right_insert_start_xyz, "quat": right_insert_quat, "gripper": -100}, - {"t": 690, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100}, - {"t": 700, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100}, + { + "t": 1, + "xyz": init_mocap_pose_right[:3], + "quat": init_mocap_pose_right[3:], + "gripper": 100, + }, + { + "t": 80, + "xyz": peg_init_xyz, + "quat": right_pick_quat, + "gripper": 100, + }, + { + "t": 150, + "xyz": peg_init_xyz, + "quat": right_pick_quat, + "gripper": 100, + }, + { + "t": 180, + "xyz": peg_init_xyz, + "quat": right_pick_quat, + "gripper": self.RIGHT_PEG_GRIPPER_CLOSED, + }, + { + "t": 450, + "xyz": peg_init_xyz, + "quat": right_pick_quat, + "gripper": self.RIGHT_PEG_GRIPPER_CLOSED, + }, + { + "t": 550, + "xyz": peg_lift_center, + "quat": right_pick_quat, + "gripper": self.RIGHT_PEG_GRIPPER_CLOSED, + }, + { + "t": self.INSERT_START_T, + "xyz": peg_lift_center, + "quat": right_pick_quat, + "gripper": self.RIGHT_PEG_GRIPPER_CLOSED, + }, + { + "t": self.INSERT_END_T, + "xyz": peg_insert_end_center, + "quat": right_pick_quat, + "gripper": self.RIGHT_PEG_GRIPPER_CLOSED, + }, + { + "t": 750, + "xyz": peg_insert_end_center, + "quat": right_pick_quat, + "gripper": self.RIGHT_PEG_GRIPPER_CLOSED, + }, + { + "t": self.EPISODE_END_T, + "xyz": peg_insert_end_center, + "quat": right_pick_quat, + "gripper": self.RIGHT_PEG_GRIPPER_CLOSED, + }, ] diff --git a/roboimi/demos/diana_record_sim_episodes.py b/roboimi/demos/diana_record_sim_episodes.py index 19a9a86..c712031 100644 --- a/roboimi/demos/diana_record_sim_episodes.py +++ b/roboimi/demos/diana_record_sim_episodes.py @@ -5,7 +5,7 @@ from roboimi.envs.double_pos_ctrl_env import make_sim_env from roboimi.demos.diana_air_insert_policy import TestAirInsertPolicy from roboimi.demos.diana_policy import TestPickAndTransferPolicy import cv2 -from roboimi.utils.act_ex_utils import sample_air_insert_ring_bar_state, sample_transfer_pose +from roboimi.utils.act_ex_utils import sample_air_insert_socket_peg_state, sample_transfer_pose from roboimi.utils.constants import SIM_TASK_CONFIGS from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter @@ -17,16 +17,18 @@ DATASET_DIR = HOME_PATH + '/dataset' def sample_task_state(task_name): if task_name == 'sim_transfer': return sample_transfer_pose() - if task_name == 'sim_air_insert_ring_bar': - return sample_air_insert_ring_bar_state() + if task_name == 'sim_air_insert_socket_peg': + return sample_air_insert_socket_peg_state() raise NotImplementedError(f'Unsupported scripted rollout task: {task_name}') -def make_policy(task_name, inject_noise=False): +def make_policy(task_name, inject_noise=False, grasp_strategy=None): if task_name == 'sim_transfer': return TestPickAndTransferPolicy(inject_noise) - if task_name == 'sim_air_insert_ring_bar': - return TestAirInsertPolicy(inject_noise) + if task_name == 'sim_air_insert_socket_peg': + if grasp_strategy is None: + return TestAirInsertPolicy(inject_noise) + return TestAirInsertPolicy(inject_noise, grasp_strategy=grasp_strategy) raise NotImplementedError(f'Unsupported scripted rollout task: {task_name}') @@ -37,9 +39,9 @@ def main(task_name='sim_transfer'): inject_noise = False episode_len = task_cfg['episode_len'] - camera_names = ['angle', 'r_vis', 'top', 'front'] + camera_names = ['left_side', 'r_vis', 'top', 'front'] image_size = (256, 256) - if task_name in {'sim_transfer', 'sim_air_insert_ring_bar'}: + if task_name in {'sim_transfer', 'sim_air_insert_socket_peg'}: print(task_name) else: raise NotImplementedError diff --git a/roboimi/demos/vla_scripts/eval_vla.py b/roboimi/demos/vla_scripts/eval_vla.py index 265e36a..89d421f 100644 --- a/roboimi/demos/vla_scripts/eval_vla.py +++ b/roboimi/demos/vla_scripts/eval_vla.py @@ -27,7 +27,7 @@ from einops import rearrange from roboimi.envs.double_pos_ctrl_env import make_sim_env from roboimi.utils.act_ex_utils import ( - sample_air_insert_ring_bar_state, + sample_air_insert_socket_peg_state, sample_transfer_pose, ) from roboimi.vla.eval_utils import execute_policy_action @@ -489,8 +489,8 @@ def _close_env(env): def _sample_task_reset_state(task_name: str): - if task_name == 'sim_air_insert_ring_bar': - return sample_air_insert_ring_bar_state() + if task_name == 'sim_air_insert_socket_peg': + return sample_air_insert_socket_peg_state() if 'sim_transfer' in task_name: return sample_transfer_pose() raise NotImplementedError(f'Unsupported eval task reset sampling: {task_name}') diff --git a/roboimi/envs/double_air_insert_env.py b/roboimi/envs/double_air_insert_env.py index 1050fdf..f1db21d 100644 --- a/roboimi/envs/double_air_insert_env.py +++ b/roboimi/envs/double_air_insert_env.py @@ -1,23 +1,19 @@ import copy as cp import time -import mujoco as mj import numpy as np from roboimi.envs.double_base import DualDianaMed from roboimi.envs.double_pos_ctrl_env import DualDianaMed_Pos_Ctrl -RING_JOINT_NAME = "ring_block_joint" -BAR_JOINT_NAME = "bar_block_joint" -REQUIRED_TASK_STATE_KEYS = ("ring_pos", "ring_quat", "bar_pos", "bar_quat") -RING_GEOM_NAMES = ( - "ring_block_north", - "ring_block_south", - "ring_block_east", - "ring_block_west", -) -BAR_GEOM_NAMES = ("bar_block",) +SOCKET_JOINT_NAME = "blue_socket_joint" +PEG_JOINT_NAME = "red_peg_joint" +REQUIRED_TASK_STATE_KEYS = ("socket_pos", "socket_quat", "peg_pos", "peg_quat") +SOCKET_GEOM_NAMES = ("socket-1", "socket-2", "socket-3", "socket-4") +SOCKET_SUCCESS_GEOM_NAMES = ("pin",) +SOCKET_BODY_GEOM_NAMES = SOCKET_GEOM_NAMES + SOCKET_SUCCESS_GEOM_NAMES +PEG_GEOM_NAMES = ("red_peg",) LEFT_GRIPPER_GEOM_NAMES = ( "l_finger_left", "r_finger_left", @@ -25,6 +21,8 @@ LEFT_GRIPPER_GEOM_NAMES = ( "r_fingertip_g0_left", "l_fingerpad_g0_left", "r_fingerpad_g0_left", + "l_fingertip_g0_vis_left", + "r_fingertip_g0_vis_left", ) RIGHT_GRIPPER_GEOM_NAMES = ( "l_finger_right", @@ -33,12 +31,10 @@ RIGHT_GRIPPER_GEOM_NAMES = ( "r_fingertip_g0_right", "l_fingerpad_g0_right", "r_fingerpad_g0_right", + "l_fingertip_g0_vis_right", + "r_fingertip_g0_vis_right", ) TABLE_GEOM_NAME = "table" -RING_APERTURE_HALF_WIDTH = 0.016 -RING_HALF_THICKNESS = 0.009 -BAR_HALF_SIZES = np.array([0.045, 0.009, 0.009], dtype=np.float64) -SCRIPTED_GRASP_CLOSE_THRESHOLD = 0.0 def _set_free_joint_pose(joint, position, quat): @@ -46,29 +42,29 @@ def _set_free_joint_pose(joint, position, quat): joint.qpos[3:7] = np.asarray(quat, dtype=np.float64) -def set_ring_bar_task_state(mj_data, task_state): +def set_socket_peg_task_state(mj_data, task_state): if not isinstance(task_state, dict) or tuple(task_state.keys()) != REQUIRED_TASK_STATE_KEYS: raise ValueError( "task_state must be an ordered dict-like mapping with keys " - "ring_pos, ring_quat, bar_pos, bar_quat" + "socket_pos, socket_quat, peg_pos, peg_quat" ) _set_free_joint_pose( - mj_data.joint(RING_JOINT_NAME), - task_state["ring_pos"], - task_state["ring_quat"], + mj_data.joint(SOCKET_JOINT_NAME), + task_state["socket_pos"], + task_state["socket_quat"], ) _set_free_joint_pose( - mj_data.joint(BAR_JOINT_NAME), - task_state["bar_pos"], - task_state["bar_quat"], + mj_data.joint(PEG_JOINT_NAME), + task_state["peg_pos"], + task_state["peg_quat"], ) -def get_ring_bar_env_state(mj_data): - ring_qpos = cp.deepcopy(np.asarray(mj_data.joint(RING_JOINT_NAME).qpos[:7], dtype=np.float64)) - bar_qpos = cp.deepcopy(np.asarray(mj_data.joint(BAR_JOINT_NAME).qpos[:7], dtype=np.float64)) - return np.concatenate([ring_qpos, bar_qpos], dtype=np.float64) +def get_socket_peg_env_state(mj_data): + socket_qpos = cp.deepcopy(np.asarray(mj_data.joint(SOCKET_JOINT_NAME).qpos[:7], dtype=np.float64)) + peg_qpos = cp.deepcopy(np.asarray(mj_data.joint(PEG_JOINT_NAME).qpos[:7], dtype=np.float64)) + return np.concatenate([socket_qpos, peg_qpos], dtype=np.float64) def _normalize_contact_pairs(contact_pairs): @@ -87,91 +83,29 @@ def _object_is_airborne(contact_set, object_geom_names): return not _has_any_object_contact(contact_set, object_geom_names, (TABLE_GEOM_NAME,)) -def _quat_to_rotation_matrix(quat): - quat = np.asarray(quat, dtype=np.float64) - quat /= np.linalg.norm(quat) - w, x, y, z = quat - return np.array( - [ - [1.0 - 2.0 * (y * y + z * z), 2.0 * (x * y - z * w), 2.0 * (x * z + y * w)], - [2.0 * (x * y + z * w), 1.0 - 2.0 * (x * x + z * z), 2.0 * (y * z - x * w)], - [2.0 * (x * z - y * w), 2.0 * (y * z + x * w), 1.0 - 2.0 * (x * x + y * y)], - ], - dtype=np.float64, - ) +def peg_inserted_into_socket(contact_pairs): + contact_set = _normalize_contact_pairs(contact_pairs) + return frozenset((PEG_GEOM_NAMES[0], SOCKET_SUCCESS_GEOM_NAMES[0])) in contact_set -def _quat_multiply(lhs, rhs): - lhs = np.asarray(lhs, dtype=np.float64) - rhs = np.asarray(rhs, dtype=np.float64) - w1, x1, y1, z1 = lhs - w2, x2, y2, z2 = rhs - return np.array( - [ - w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, - w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, - w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, - w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, - ], - dtype=np.float64, - ) - - -def _quat_inverse(quat): - quat = np.asarray(quat, dtype=np.float64) - norm_sq = float(np.dot(quat, quat)) - return np.array([quat[0], -quat[1], -quat[2], -quat[3]], dtype=np.float64) / norm_sq - - -def _split_env_state(env_state): - env_state = np.asarray(env_state, dtype=np.float64) - if env_state.shape != (14,): - raise ValueError(f"env_state must have shape (14,), got {env_state.shape}") - return ( - env_state[:3], - env_state[3:7], - env_state[7:10], - env_state[10:14], - ) - - -def bar_fully_inserted_through_ring(env_state): - ring_pos, ring_quat, bar_pos, bar_quat = _split_env_state(env_state) - ring_rot = _quat_to_rotation_matrix(ring_quat) - bar_rot = _quat_to_rotation_matrix(bar_quat) - - bar_center_in_ring = ring_rot.T @ (bar_pos - ring_pos) - bar_rot_in_ring = ring_rot.T @ bar_rot - projected_half_extents = np.abs(bar_rot_in_ring) @ BAR_HALF_SIZES - - spans_ring_thickness = ( - bar_center_in_ring[2] - projected_half_extents[2] <= -RING_HALF_THICKNESS - and bar_center_in_ring[2] + projected_half_extents[2] >= RING_HALF_THICKNESS - ) - fits_aperture = ( - abs(bar_center_in_ring[0]) + projected_half_extents[0] <= RING_APERTURE_HALF_WIDTH - and abs(bar_center_in_ring[1]) + projected_half_extents[1] <= RING_APERTURE_HALF_WIDTH - ) - return bool(spans_ring_thickness and fits_aperture) - - -def compute_air_insert_reward(contact_pairs, env_state): +def compute_air_insert_reward(contact_pairs, env_state=None): + del env_state # kept for API compatibility with rollout/eval code paths contact_set = _normalize_contact_pairs(contact_pairs) reward = 0 - if _has_any_object_contact(contact_set, RING_GEOM_NAMES, LEFT_GRIPPER_GEOM_NAMES): + if _has_any_object_contact(contact_set, SOCKET_GEOM_NAMES, LEFT_GRIPPER_GEOM_NAMES): reward += 1 - if _has_any_object_contact(contact_set, BAR_GEOM_NAMES, RIGHT_GRIPPER_GEOM_NAMES): + if _has_any_object_contact(contact_set, PEG_GEOM_NAMES, RIGHT_GRIPPER_GEOM_NAMES): reward += 1 - ring_airborne = _object_is_airborne(contact_set, RING_GEOM_NAMES) - bar_airborne = _object_is_airborne(contact_set, BAR_GEOM_NAMES) - if ring_airborne: + socket_airborne = _object_is_airborne(contact_set, SOCKET_BODY_GEOM_NAMES) + peg_airborne = _object_is_airborne(contact_set, PEG_GEOM_NAMES) + if socket_airborne: reward += 1 - if bar_airborne: + if peg_airborne: reward += 1 - if ring_airborne and bar_airborne and bar_fully_inserted_through_ring(env_state): + if socket_airborne and peg_airborne and peg_inserted_into_socket(contact_pairs): reward += 1 return reward @@ -181,33 +115,19 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.max_reward = 5 - self._scripted_ring_grasped = False - self._scripted_bar_grasped = False - self._scripted_ring_pos_offset_local = None - self._scripted_bar_pos_offset_local = None - self._scripted_ring_quat_offset = None - self._scripted_bar_quat_offset = None - self._air_insert_step_count = 0 def reset(self, task_state): - self._scripted_ring_grasped = False - self._scripted_bar_grasped = False - self._scripted_ring_pos_offset_local = None - self._scripted_bar_pos_offset_local = None - self._scripted_ring_quat_offset = None - self._scripted_bar_quat_offset = None - self._air_insert_step_count = 0 - set_ring_bar_task_state(self.mj_data, task_state) + set_socket_peg_task_state(self.mj_data, task_state) DualDianaMed.reset(self) self.top = None - self.angle = None + self.left_side = None self.r_vis = None self.front = None self.cam_flage = True while self.cam_flage: if ( type(self.top) == type(None) - or type(self.angle) == type(None) + or type(self.left_side) == type(None) or type(self.r_vis) == type(None) or type(self.front) == type(None) ): @@ -217,76 +137,11 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl): def step(self, action=np.zeros(16)): super().step(action) - self._update_scripted_grasped_objects(action) self.rew = self._get_reward() self.obs = self._get_obs() - self._air_insert_step_count += 1 - - def _update_scripted_grasped_objects(self, action): - if ( - action[-2] < SCRIPTED_GRASP_CLOSE_THRESHOLD - and self._air_insert_step_count >= 180 - and not self._scripted_ring_grasped - ): - self._scripted_ring_grasped = True - self._attach_scripted_object( - object_joint_name=RING_JOINT_NAME, - ee_pos=action[:3], - ee_quat=action[3:7], - pos_attr="_scripted_ring_pos_offset_local", - quat_attr="_scripted_ring_quat_offset", - ) - if ( - action[-1] < SCRIPTED_GRASP_CLOSE_THRESHOLD - and self._air_insert_step_count >= 180 - and not self._scripted_bar_grasped - ): - self._scripted_bar_grasped = True - self._attach_scripted_object( - object_joint_name=BAR_JOINT_NAME, - ee_pos=action[7:10], - ee_quat=action[10:14], - pos_attr="_scripted_bar_pos_offset_local", - quat_attr="_scripted_bar_quat_offset", - ) - - if self._scripted_ring_grasped: - self._update_scripted_object_pose( - object_joint_name=RING_JOINT_NAME, - ee_pos=action[:3], - ee_quat=action[3:7], - pos_offset_local=self._scripted_ring_pos_offset_local, - quat_offset=self._scripted_ring_quat_offset, - ) - if self._scripted_bar_grasped: - self._update_scripted_object_pose( - object_joint_name=BAR_JOINT_NAME, - ee_pos=action[7:10], - ee_quat=action[10:14], - pos_offset_local=self._scripted_bar_pos_offset_local, - quat_offset=self._scripted_bar_quat_offset, - ) - if self._scripted_ring_grasped or self._scripted_bar_grasped: - mj.mj_forward(self.mj_model, self.mj_data) - - def _attach_scripted_object(self, object_joint_name, ee_pos, ee_quat, pos_attr, quat_attr): - ee_pos = np.asarray(ee_pos, dtype=np.float64) - ee_quat = np.asarray(ee_quat, dtype=np.float64) - object_qpos = np.asarray(self.mj_data.joint(object_joint_name).qpos[:7], dtype=np.float64) - ee_rot = _quat_to_rotation_matrix(ee_quat) - setattr(self, pos_attr, ee_rot.T @ (object_qpos[:3] - ee_pos)) - setattr(self, quat_attr, _quat_multiply(_quat_inverse(ee_quat), object_qpos[3:7])) - - def _update_scripted_object_pose(self, object_joint_name, ee_pos, ee_quat, pos_offset_local, quat_offset): - ee_pos = np.asarray(ee_pos, dtype=np.float64) - ee_quat = np.asarray(ee_quat, dtype=np.float64) - ee_rot = _quat_to_rotation_matrix(ee_quat) - object_pos = ee_pos + ee_rot @ np.asarray(pos_offset_local, dtype=np.float64) - object_quat = _quat_multiply(ee_quat, quat_offset) - _set_free_joint_pose(self.mj_data.joint(object_joint_name), object_pos, object_quat) def get_env_state(self): - return get_ring_bar_env_state(self.mj_data) + return get_socket_peg_env_state(self.mj_data) def _get_reward(self): contact_pairs = [] @@ -296,8 +151,4 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl): contact_pairs.append( (self.getID2Name("geom", geom1), self.getID2Name("geom", geom2)) ) - if self._scripted_ring_grasped: - contact_pairs.append(("ring_block_south", "l_fingertip_g0_left")) - if self._scripted_bar_grasped: - contact_pairs.append(("bar_block", "r_fingertip_g0_right")) return compute_air_insert_reward(contact_pairs, self.get_env_state()) diff --git a/roboimi/envs/double_base.py b/roboimi/envs/double_base.py index 1089d3a..02b1686 100644 --- a/roboimi/envs/double_base.py +++ b/roboimi/envs/double_base.py @@ -52,7 +52,7 @@ class DualDianaMed(MujocoEnv): self.r_vis = None self.l_vis = None self.top = None - self.angle = None + self.left_side = None self.front = None self.obs = None @@ -166,7 +166,7 @@ class DualDianaMed(MujocoEnv): obs['action'] = self.compute_qpos obs['images'] = dict() obs['images']['top'] = self.top - obs['images']['angle'] = self.angle + obs['images']['left_side'] = self.left_side obs['images']['r_vis'] = self.r_vis obs['images']['l_vis'] = self.l_vis obs['images']['front'] = self.front @@ -176,7 +176,7 @@ class DualDianaMed(MujocoEnv): obs = collections.OrderedDict() obs['images'] = dict() obs['images']['top'] = self.top - obs['images']['angle'] = self.angle + obs['images']['left_side'] = self.left_side obs['images']['r_vis'] = self.r_vis obs['images']['l_vis'] = self.l_vis obs['images']['front'] = self.front @@ -199,8 +199,8 @@ class DualDianaMed(MujocoEnv): def cam_view(self): if self.cam == 'top': return self.top - elif self.cam == 'angle': - return self.angle + elif self.cam == 'left_side': + return self.left_side elif self.cam == 'r_vis': return self.r_vis elif self.cam == 'l_vis': @@ -226,9 +226,9 @@ class DualDianaMed(MujocoEnv): img_renderer.update_scene(self.mj_data,camera="top") self.top = img_renderer.render() self.top = self.top[:, :, ::-1] - img_renderer.update_scene(self.mj_data,camera="angle") - self.angle = img_renderer.render() - self.angle = self.angle[:, :, ::-1] + img_renderer.update_scene(self.mj_data,camera="left_side") + self.left_side = img_renderer.render() + self.left_side = self.left_side[:, :, ::-1] img_renderer.update_scene(self.mj_data,camera="front") self.front = img_renderer.render() self.front = self.front[:, :, ::-1] diff --git a/roboimi/envs/double_pos_ctrl_env.py b/roboimi/envs/double_pos_ctrl_env.py index 31e8c86..341fde0 100644 --- a/roboimi/envs/double_pos_ctrl_env.py +++ b/roboimi/envs/double_pos_ctrl_env.py @@ -34,19 +34,19 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed): is_interpolate=is_interpolate, cam_view=cam_view ) - + self.max_reward = 4 - + self.cam_start() - + def step(self,action=np.zeros(16)): action_left = self.ik_solve(action[:3],action[3:7],self.arm_left) action_right = self.ik_solve(action[7:10],action[10:14],self.arm_right) action = np.hstack((action_left,action_right,action[14:])) super().step(action) self.rew = self._get_reward() - + def step_jnt(self,action): super().step(action) @@ -63,8 +63,8 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed): return Arm.kdl_solver.ikSolver(p_goal, mat_goal, Arm.arm_qpos) def reset(self,box_pos): - - self.mj_data.joint('red_box_joint').qpos[0] = box_pos[0] + + self.mj_data.joint('red_box_joint').qpos[0] = box_pos[0] self.mj_data.joint('red_box_joint').qpos[1] = box_pos[1] self.mj_data.joint('red_box_joint').qpos[2] = box_pos[2] self.mj_data.joint('red_box_joint').qpos[3] = 1.0 @@ -73,22 +73,22 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed): self.mj_data.joint('red_box_joint').qpos[6] = 0.0 super().reset() self.top = None - self.angle = None + self.left_side = None self.r_vis = None self.front = None self.cam_flage = True t=0 while self.cam_flage: - if(type(self.top)==type(None) - or type(self.angle)==type(None) + if(type(self.top)==type(None) + or type(self.left_side)==type(None) or type(self.r_vis)==type(None) or type(self.front)==type(None)): time.sleep(0.001) t+=1 else: self.cam_flage=False - - + + def preStep(self, action): if isinstance(action,np.ndarray) and len(action)==16: @@ -101,7 +101,7 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed): for i in range(3): box_pose[i] = cp.deepcopy(self.mj_data.joint('red_box_joint').qpos[i]) return box_pose - + def _get_reward(self): all_contact_pairs = [] @@ -124,26 +124,26 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed): reward = 0 if touch_right_gripper and not touch_table: reward = 1 - if touch_right_gripper and not box_touch_table: + if touch_right_gripper and not box_touch_table: reward = 2 if touch_left_gripper: # attempted transfer reward = 3 if touch_left_gripper and not box_touch_table: # successful transfer reward = 4 return reward - + def make_sim_env(task_name, headless=False): - if task_name == 'sim_air_insert_ring_bar': - from roboimi.assets.robots.diana_med import BiDianaMedRingBar + if task_name == 'sim_air_insert_socket_peg': + from roboimi.assets.robots.diana_med import BiDianaMedSocketPeg from roboimi.envs.double_air_insert_env import DualDianaMed_Air_Insert env = DualDianaMed_Air_Insert( - robot=BiDianaMedRingBar(), + robot=BiDianaMedSocketPeg(), is_render=not headless, control_freq=30, is_interpolate=True, - cam_view='angle' + cam_view='left_side' ) return env if 'sim_transfer' in task_name: @@ -153,7 +153,7 @@ def make_sim_env(task_name, headless=False): is_render=not headless, control_freq=30, is_interpolate=True, - cam_view='angle' + cam_view='left_side' ) return env else: @@ -179,4 +179,4 @@ if __name__ == "__main__": env.step(action) if env.is_render: env.render() - + diff --git a/roboimi/utils/act_ex_utils.py b/roboimi/utils/act_ex_utils.py index 6afc0bb..5ca0ba3 100644 --- a/roboimi/utils/act_ex_utils.py +++ b/roboimi/utils/act_ex_utils.py @@ -39,19 +39,20 @@ def sample_transfer_pose(): return box_position -def sample_air_insert_ring_bar_state(): - ring_position = np.random.uniform( - low=np.array([-0.20, 0.70, 0.47], dtype=np.float32), - high=np.array([-0.05, 1.00, 0.47], dtype=np.float32), +def sample_air_insert_socket_peg_state(): + socket_position = np.random.uniform( + low=np.array([-0.14, 0.89, 0.472], dtype=np.float32), + high=np.array([-0.10, 0.94, 0.472], dtype=np.float32), ) - bar_position = np.random.uniform( - low=np.array([0.05, 0.70, 0.47], dtype=np.float32), - high=np.array([0.20, 1.00, 0.47], dtype=np.float32), + peg_position = np.random.uniform( + low=np.array([0.10, 0.85, 0.46], dtype=np.float32), + high=np.array([0.16, 0.94, 0.46], dtype=np.float32), ) - fixed_quat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) + socket_quat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) + peg_quat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32) return { - "ring_pos": ring_position.astype(np.float32, copy=False), - "ring_quat": fixed_quat.copy(), - "bar_pos": bar_position.astype(np.float32, copy=False), - "bar_quat": fixed_quat.copy(), + "socket_pos": socket_position.astype(np.float32, copy=False), + "socket_quat": socket_quat, + "peg_pos": peg_position.astype(np.float32, copy=False), + "peg_quat": peg_quat, } diff --git a/roboimi/utils/constants.py b/roboimi/utils/constants.py index 10158e7..0096f94 100644 --- a/roboimi/utils/constants.py +++ b/roboimi/utils/constants.py @@ -23,10 +23,10 @@ SIM_TASK_CONFIGS = { 'camera_names': ['top','r_vis','front'], 'xml_dir': HOME_PATH + '/assets' }, - 'sim_air_insert_ring_bar': { - 'dataset_dir': DATASET_DIR + '/sim_air_insert_ring_bar', + 'sim_air_insert_socket_peg': { + 'dataset_dir': DATASET_DIR + '/sim_air_insert_socket_peg', 'num_episodes': 20, - 'episode_len': 700, + 'episode_len': 1000, 'camera_names': ['top', 'r_vis', 'front'], 'xml_dir': HOME_PATH + '/assets' }, @@ -59,13 +59,3 @@ PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x)) - -MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) -PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) - -MASTER_POS2JOINT = lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE -MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN((x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)) -PUPPET_POS2JOINT = lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE -PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN((x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)) - -MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE)/2 diff --git a/tests/test_air_insert_env.py b/tests/test_air_insert_env.py index 59ba1ed..5ff33a7 100644 --- a/tests/test_air_insert_env.py +++ b/tests/test_air_insert_env.py @@ -1,6 +1,9 @@ import importlib +import inspect +import pathlib import unittest from unittest import mock +import xml.etree.ElementTree as ET import numpy as np @@ -9,83 +12,80 @@ from roboimi.utils import act_ex_utils from roboimi.utils.constants import SIM_TASK_CONFIGS -class AirInsertTaskRegistrationTest(unittest.TestCase): - def test_sim_task_configs_registers_air_insert_ring_bar(self): - self.assertIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS) +TASK_NAME = "sim_air_insert_socket_peg" - def test_sample_air_insert_ring_bar_state_returns_explicit_named_mapping(self): - sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None) + +class AirInsertTaskRegistrationTest(unittest.TestCase): + def test_sim_task_configs_registers_air_insert_socket_peg(self): + self.assertIn(TASK_NAME, SIM_TASK_CONFIGS) + self.assertNotIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS) + self.assertGreaterEqual(SIM_TASK_CONFIGS[TASK_NAME]["episode_len"], 1000) + self.assertTrue(SIM_TASK_CONFIGS[TASK_NAME]["dataset_dir"].endswith("/sim_air_insert_socket_peg")) + + def test_sample_air_insert_socket_peg_state_returns_explicit_named_mapping(self): + sampler = getattr(act_ex_utils, "sample_air_insert_socket_peg_state", None) self.assertIsNotNone( sampler, - "Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()", + "Expected roboimi.utils.act_ex_utils.sample_air_insert_socket_peg_state()", + ) + self.assertFalse( + hasattr(act_ex_utils, "sample_air_insert_ring_bar_state"), + "air insert sampler should use socket/peg naming after the task rename", ) task_state = sampler() self.assertEqual( list(task_state.keys()), - ["ring_pos", "ring_quat", "bar_pos", "bar_quat"], + ["socket_pos", "socket_quat", "peg_pos", "peg_quat"], ) - self.assertEqual(task_state["ring_pos"].shape, (3,)) - self.assertEqual(task_state["ring_quat"].shape, (4,)) - self.assertEqual(task_state["bar_pos"].shape, (3,)) - self.assertEqual(task_state["bar_quat"].shape, (4,)) + self.assertEqual(task_state["socket_pos"].shape, (3,)) + self.assertEqual(task_state["socket_quat"].shape, (4,)) + self.assertEqual(task_state["peg_pos"].shape, (3,)) + self.assertEqual(task_state["peg_quat"].shape, (4,)) - def test_sample_air_insert_ring_bar_state_uses_fixed_quats_and_left_right_planar_ranges(self): - sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None) - self.assertIsNotNone( - sampler, - "Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()", - ) + def test_sample_air_insert_socket_peg_state_uses_fixed_quats_and_left_right_planar_ranges(self): + sampler = getattr(act_ex_utils, "sample_air_insert_socket_peg_state", None) + self.assertIsNotNone(sampler) task_state = sampler() - np.testing.assert_array_equal(task_state["ring_quat"], np.array([1.0, 0.0, 0.0, 0.0])) - np.testing.assert_array_equal(task_state["bar_quat"], np.array([1.0, 0.0, 0.0, 0.0])) - self.assertGreaterEqual(task_state["ring_pos"][0], -0.20) - self.assertLessEqual(task_state["ring_pos"][0], -0.05) - self.assertGreaterEqual(task_state["ring_pos"][1], 0.70) - self.assertLessEqual(task_state["ring_pos"][1], 1.00) - self.assertAlmostEqual(float(task_state["ring_pos"][2]), 0.47) - self.assertGreaterEqual(task_state["bar_pos"][0], 0.05) - self.assertLessEqual(task_state["bar_pos"][0], 0.20) - self.assertGreaterEqual(task_state["bar_pos"][1], 0.70) - self.assertLessEqual(task_state["bar_pos"][1], 1.00) - self.assertAlmostEqual(float(task_state["bar_pos"][2]), 0.47) - - def test_make_sim_env_dispatches_air_insert_ring_bar_headless(self): - try: - air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") - except Exception as exc: - self.fail(f"Expected roboimi.envs.double_air_insert_env to be importable: {exc}") + np.testing.assert_array_equal(task_state["socket_quat"], np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)) + np.testing.assert_array_equal(task_state["peg_quat"], np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)) + self.assertGreaterEqual(task_state["socket_pos"][0], -0.20) + self.assertLessEqual(task_state["socket_pos"][0], -0.05) + self.assertGreaterEqual(task_state["socket_pos"][1], 0.70) + self.assertLessEqual(task_state["socket_pos"][1], 1.00) + self.assertAlmostEqual(float(task_state["socket_pos"][2]), 0.472) + self.assertGreaterEqual(task_state["peg_pos"][0], 0.05) + self.assertLessEqual(task_state["peg_pos"][0], 0.20) + self.assertGreaterEqual(task_state["peg_pos"][1], 0.70) + self.assertLessEqual(task_state["peg_pos"][1], 1.00) + self.assertAlmostEqual(float(task_state["peg_pos"][2]), 0.46) + def test_make_sim_env_dispatches_air_insert_socket_peg_headless(self): + air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") air_insert_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None) - self.assertIsNotNone( - air_insert_cls, - "Expected roboimi.envs.double_air_insert_env.DualDianaMed_Air_Insert", - ) + self.assertIsNotNone(air_insert_cls) diana_med = importlib.import_module("roboimi.assets.robots.diana_med") - ring_bar_robot_cls = getattr(diana_med, "BiDianaMedRingBar", None) + socket_peg_robot_cls = getattr(diana_med, "BiDianaMedSocketPeg", None) self.assertIsNotNone( - ring_bar_robot_cls, - "Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar", + socket_peg_robot_cls, + "Expected roboimi.assets.robots.diana_med.BiDianaMedSocketPeg", ) fake_env = object() with mock.patch.object( diana_med, - "BiDianaMedRingBar", + "BiDianaMedSocketPeg", return_value="robot", ), mock.patch.object( air_insert_env, "DualDianaMed_Air_Insert", return_value=fake_env, ) as env_cls: - try: - env = make_sim_env("sim_air_insert_ring_bar", headless=True) - except Exception as exc: - self.fail(f"make_sim_env should dispatch sim_air_insert_ring_bar without error: {exc}") + env = make_sim_env(TASK_NAME, headless=True) self.assertIs(env, fake_env) env_cls.assert_called_once_with( @@ -93,21 +93,36 @@ class AirInsertTaskRegistrationTest(unittest.TestCase): is_render=False, control_freq=30, is_interpolate=True, - cam_view="angle", + cam_view="left_side", ) + def test_diana_table_scene_uses_left_side_camera_instead_of_angle(self): + xml_path = ( + pathlib.Path(__file__).resolve().parents[1] + / "roboimi/assets/models/manipulators/DianaMed/table_square.xml" + ) + root = ET.parse(xml_path).getroot() + cameras = {camera.attrib["name"]: camera.attrib for camera in root.findall(".//camera")} + + self.assertNotIn("angle", cameras, "DianaMed scene should stop exposing the old angle camera") + self.assertIn("left_side", cameras, "DianaMed scene should expose the left-side task camera") + left_side_pos = np.fromstring(cameras["left_side"]["pos"], sep=" ") + self.assertLess(float(left_side_pos[0]), 0.0) + self.assertEqual(cameras["left_side"].get("mode"), "targetbody") + self.assertEqual(cameras["left_side"].get("target"), "table") + class AirInsertResetAndStateHelpersTest(unittest.TestCase): - def test_set_ring_bar_task_state_writes_free_joint_qpos(self): + def test_set_socket_peg_task_state_writes_free_joint_qpos(self): air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") - setter = getattr(air_insert_env, "set_ring_bar_task_state", None) + setter = getattr(air_insert_env, "set_socket_peg_task_state", None) self.assertIsNotNone( setter, - "Expected roboimi.envs.double_air_insert_env.set_ring_bar_task_state", + "Expected roboimi.envs.double_air_insert_env.set_socket_peg_task_state", ) - ring_qpos = np.zeros(7, dtype=np.float64) - bar_qpos = np.zeros(7, dtype=np.float64) + socket_qpos = np.zeros(7, dtype=np.float64) + peg_qpos = np.zeros(7, dtype=np.float64) class _FakeJoint: def __init__(self, qpos): @@ -115,40 +130,40 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase): class _FakeData: def joint(self, name): - if name == "ring_block_joint": - return _FakeJoint(ring_qpos) - if name == "bar_block_joint": - return _FakeJoint(bar_qpos) + if name == "blue_socket_joint": + return _FakeJoint(socket_qpos) + if name == "red_peg_joint": + return _FakeJoint(peg_qpos) raise AssertionError(f"Unexpected joint name: {name}") task_state = { - "ring_pos": np.array([-0.12, 0.90, 0.47], dtype=np.float64), - "ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64), - "bar_pos": np.array([0.12, 0.91, 0.47], dtype=np.float64), - "bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64), + "socket_pos": np.array([-0.12, 0.90, 0.472], dtype=np.float64), + "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64), + "peg_pos": np.array([0.12, 0.91, 0.46], dtype=np.float64), + "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64), } setter(_FakeData(), task_state) np.testing.assert_array_equal( - ring_qpos, - np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64), + socket_qpos, + np.array([-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0], dtype=np.float64), ) np.testing.assert_array_equal( - bar_qpos, - np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64), + peg_qpos, + np.array([0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0], dtype=np.float64), ) - def test_get_ring_bar_env_state_returns_stable_14d_vector(self): + def test_get_socket_peg_env_state_returns_stable_14d_vector(self): air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") - getter = getattr(air_insert_env, "get_ring_bar_env_state", None) + getter = getattr(air_insert_env, "get_socket_peg_env_state", None) self.assertIsNotNone( getter, - "Expected roboimi.envs.double_air_insert_env.get_ring_bar_env_state", + "Expected roboimi.envs.double_air_insert_env.get_socket_peg_env_state", ) - ring_qpos = np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64) - bar_qpos = np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64) + socket_qpos = np.array([-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0], dtype=np.float64) + peg_qpos = np.array([0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0], dtype=np.float64) class _FakeJoint: def __init__(self, qpos): @@ -156,10 +171,10 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase): class _FakeData: def joint(self, name): - if name == "ring_block_joint": - return _FakeJoint(ring_qpos) - if name == "bar_block_joint": - return _FakeJoint(bar_qpos) + if name == "blue_socket_joint": + return _FakeJoint(socket_qpos) + if name == "red_peg_joint": + return _FakeJoint(peg_qpos) raise AssertionError(f"Unexpected joint name: {name}") env_state = getter(_FakeData()) @@ -168,38 +183,78 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase): np.testing.assert_array_equal( env_state, np.array( - [-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], + [-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0], dtype=np.float64, ), ) + def test_air_insert_env_does_not_script_attach_or_assist_objects_after_reset(self): + air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") + env_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None) + self.assertIsNotNone(env_cls) + + source = inspect.getsource(env_cls) + + self.assertNotIn("_update_scripted_grasped_objects", source) + self.assertNotIn("_scripted_", source) + self.assertNotIn("_stabilize_ring_grasp", source) + self.assertNotIn("_ring_grasp_locked", source) + get_reward_source = inspect.getsource(env_cls._get_reward) + self.assertNotIn("ring_block", get_reward_source) + self.assertNotIn("bar_block", get_reward_source) + + def test_socket_peg_xml_defines_active_socket_and_peg_objects(self): + asset_dir = pathlib.Path(__file__).resolve().parents[1] / "roboimi/assets/models/manipulators/DianaMed" + xml_path = asset_dir / "socket_peg_objects.xml" + self.assertTrue(xml_path.exists(), "socket/peg objects should live in socket_peg_objects.xml") + self.assertFalse((asset_dir / "ring_bar_objects.xml").exists(), "old ring_bar_objects.xml should be renamed") + + root = ET.parse(xml_path).getroot() + body_names = {body.attrib.get("name") for body in root.findall(".//body")} + geom_names = {geom.attrib.get("name") for geom in root.findall(".//geom")} + joint_names = {joint.attrib.get("name") for joint in root.findall(".//joint")} + + self.assertIn("socket", body_names) + self.assertIn("peg", body_names) + self.assertNotIn("ring_block", body_names) + self.assertNotIn("bar_block", body_names) + self.assertIn("blue_socket_joint", joint_names) + self.assertIn("red_peg_joint", joint_names) + for geom_name in ("socket-1", "socket-2", "socket-3", "socket-4", "pin", "red_peg"): + self.assertIn(geom_name, geom_names) + + def test_socket_peg_wrapper_includes_socket_peg_objects(self): + xml_path = ( + pathlib.Path(__file__).resolve().parents[1] + / "roboimi/assets/models/manipulators/DianaMed/bi_diana_socket_peg_ee.xml" + ) + self.assertTrue(xml_path.exists(), "socket/peg wrapper XML should use the new task name") + root = ET.parse(xml_path).getroot() + includes = [include.attrib.get("file") for include in root.findall(".//include")] + self.assertIn("./socket_peg_objects.xml", includes) + self.assertNotIn("./ring_bar_objects.xml", includes) + class AirInsertRewardAndSuccessTest(unittest.TestCase): @staticmethod def _make_env_state( - ring_pos=(0.0, 0.0, 0.50), - ring_quat=(1.0, 0.0, 0.0, 0.0), - bar_pos=(0.0, 0.0, 0.50), - bar_quat=(0.70710678, 0.0, 0.70710678, 0.0), + socket_pos=(0.0, 0.0, 0.472), + socket_quat=(1.0, 0.0, 0.0, 0.0), + peg_pos=(0.0, 0.0, 0.46), + peg_quat=(1.0, 0.0, 0.0, 0.0), ): - return np.array( - [*ring_pos, *ring_quat, *bar_pos, *bar_quat], - dtype=np.float64, - ) + return np.array([*socket_pos, *socket_quat, *peg_pos, *peg_quat], dtype=np.float64) def test_compute_air_insert_reward_counts_left_contact_stage(self): air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None) - self.assertIsNotNone( - reward_fn, - "Expected roboimi.envs.double_air_insert_env.compute_air_insert_reward", - ) + self.assertIsNotNone(reward_fn) reward = reward_fn( contact_pairs=[ - ("ring_block_north", "l_finger_left"), - ("ring_block_north", "table"), - ("bar_block", "table"), + ("socket-1", "l_finger_left"), + ("socket-1", "table"), + ("red_peg", "table"), ], env_state=self._make_env_state(), ) @@ -212,10 +267,10 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase): reward = reward_fn( contact_pairs=[ - ("ring_block_north", "l_finger_left"), - ("bar_block", "l_finger_right"), - ("ring_block_north", "table"), - ("bar_block", "table"), + ("socket-1", "l_finger_left"), + ("red_peg", "l_finger_right"), + ("socket-1", "table"), + ("red_peg", "table"), ], env_state=self._make_env_state(), ) @@ -228,47 +283,43 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase): reward = reward_fn( contact_pairs=[ - ("ring_block_north", "l_finger_left"), - ("bar_block", "l_finger_right"), + ("socket-1", "l_finger_left"), + ("red_peg", "l_finger_right"), ], - env_state=self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)), + env_state=self._make_env_state(), ) self.assertEqual(reward, 4) - def test_bar_fully_inserted_through_ring_accepts_true_positive(self): + def test_compute_air_insert_reward_counts_visual_fingertip_contacts_as_gripper_contacts(self): air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") - success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None) + reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None) + + reward = reward_fn( + contact_pairs=[ + ("socket-3", "r_fingertip_g0_vis_left"), + ("red_peg", "l_fingertip_g0_vis_right"), + ], + env_state=self._make_env_state(), + ) + + self.assertEqual( + reward, + 4, + "visual fingertip geoms are collidable in the Diana XML and should count as gripper-object contacts", + ) + + def test_peg_inserted_into_socket_uses_pin_contact(self): + air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") + success_fn = getattr(air_insert_env, "peg_inserted_into_socket", None) self.assertIsNotNone( success_fn, - "Expected roboimi.envs.double_air_insert_env.bar_fully_inserted_through_ring", + "Expected roboimi.envs.double_air_insert_env.peg_inserted_into_socket", ) - self.assertTrue( - success_fn( - self._make_env_state(), - ) - ) - - def test_bar_fully_inserted_through_ring_rejects_centerline_only_false_positive(self): - air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") - success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None) - - self.assertFalse( - success_fn( - self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)), - ) - ) - - def test_bar_fully_inserted_through_ring_rejects_insufficient_depth(self): - air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") - success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None) - - self.assertFalse( - success_fn( - self._make_env_state(bar_pos=(0.0, 0.0, 0.56)), - ) - ) + self.assertTrue(success_fn([("red_peg", "pin")])) + self.assertTrue(success_fn([("pin", "red_peg")])) + self.assertFalse(success_fn([("red_peg", "socket-1")])) def test_compute_air_insert_reward_requires_airborne_success_for_final_point(self): air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") @@ -276,9 +327,10 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase): reward = reward_fn( contact_pairs=[ - ("ring_block_north", "l_finger_left"), - ("bar_block", "l_finger_right"), - ("ring_block_north", "table"), + ("socket-1", "l_finger_left"), + ("red_peg", "l_finger_right"), + ("socket-1", "table"), + ("red_peg", "pin"), ], env_state=self._make_env_state(), ) @@ -291,8 +343,9 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase): reward = reward_fn( contact_pairs=[ - ("ring_block_north", "l_finger_left"), - ("bar_block", "l_finger_right"), + ("socket-1", "l_finger_left"), + ("red_peg", "l_finger_right"), + ("red_peg", "pin"), ], env_state=self._make_env_state(), ) @@ -301,41 +354,129 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase): class AirInsertPolicyAndSmokeTest(unittest.TestCase): + @staticmethod + def _canonical_task_state(): + return { + "socket_pos": np.array([-0.12, 0.90, 0.472], dtype=np.float32), + "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + "peg_pos": np.array([0.12, 0.90, 0.46], dtype=np.float32), + "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + } + def test_air_insert_policy_emits_valid_16d_action(self): policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy") policy_cls = getattr(policy_module, "TestAirInsertPolicy", None) - self.assertIsNotNone( - policy_cls, - "Expected roboimi.demos.diana_air_insert_policy.TestAirInsertPolicy", - ) + self.assertIsNotNone(policy_cls) - task_state = act_ex_utils.sample_air_insert_ring_bar_state() + task_state = act_ex_utils.sample_air_insert_socket_peg_state() policy = policy_cls(inject_noise=False) action = policy.predict(task_state, 0) self.assertEqual(action.shape, (16,)) np.testing.assert_array_equal(action[-2:], np.array([100, 100])) - def test_scripted_rollout_entrypoint_selects_ring_bar_sampler_and_policy(self): + def test_air_insert_policy_inserts_peg_front_view_right_to_left_along_world_x(self): + policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy") + policy_cls = getattr(policy_module, "TestAirInsertPolicy", None) + self.assertIsNotNone(policy_cls) + + task_state = self._canonical_task_state() + policy = policy_cls(inject_noise=False) + policy.generate_trajectory(task_state) + + start_waypoint = next(wp for wp in policy.right_trajectory if wp["t"] == policy.INSERT_START_T) + end_waypoint = next(wp for wp in policy.right_trajectory if wp["t"] == policy.INSERT_END_T) + + self.assertLess( + end_waypoint["xyz"][0], + start_waypoint["xyz"][0] - 0.10, + "front-view right-to-left peg insertion should decrease world x substantially", + ) + self.assertAlmostEqual(float(end_waypoint["xyz"][1]), float(start_waypoint["xyz"][1]), delta=0.02) + expected_insert_end_x = float(task_state["socket_pos"][0] + 0.168) + self.assertAlmostEqual(float(end_waypoint["xyz"][0]), expected_insert_end_x, delta=0.02) + self.assertGreater(float(start_waypoint["xyz"][2]), 0.70) + + def test_air_insert_policy_default_left_grasps_socket_and_right_grasps_peg(self): + policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy") + policy_cls = getattr(policy_module, "TestAirInsertPolicy", None) + self.assertIsNotNone(policy_cls) + + task_state = { + "socket_pos": np.array([-0.18, 0.78, 0.472], dtype=np.float32), + "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + "peg_pos": np.array([0.16, 0.98, 0.46], dtype=np.float32), + "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + } + + policy = policy_cls(inject_noise=False) + policy.generate_trajectory(task_state) + left_close = next(wp for wp in policy.left_trajectory if wp["t"] == 180) + right_close = next(wp for wp in policy.right_trajectory if wp["t"] == 180) + action_z_offset = getattr(policy_cls, "ACTION_OBJECT_Z_OFFSET", 0.11) + expected_socket_pick = task_state["socket_pos"] + np.array([-0.078, 0.0, action_z_offset]) + expected_peg_pick = task_state["peg_pos"] + np.array([0.078, 0.0, action_z_offset + 0.01]) + + np.testing.assert_allclose(left_close["xyz"], expected_socket_pick, atol=1e-6) + np.testing.assert_allclose(right_close["xyz"], expected_peg_pick, atol=1e-6) + self.assertLess(left_close["gripper"], 0, "default policy should close the left gripper on the socket") + self.assertLess(right_close["gripper"], 0, "default policy should close the right gripper on the peg") + + def test_air_insert_policy_socket_hold_tracks_socket_xy_without_sweeping_laterally(self): + policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy") + policy_cls = getattr(policy_module, "TestAirInsertPolicy", None) + self.assertIsNotNone(policy_cls) + + base_state = { + "socket_pos": np.array([-0.20, 0.72, 0.472], dtype=np.float32), + "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + "peg_pos": np.array([0.14, 0.76, 0.46], dtype=np.float32), + "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + } + shifted_state = dict(base_state) + shifted_state["socket_pos"] = np.array([-0.06, 0.99, 0.472], dtype=np.float32) + + base_policy = policy_cls(inject_noise=False) + base_policy.generate_trajectory(base_state) + shifted_policy = policy_cls(inject_noise=False) + shifted_policy.generate_trajectory(shifted_state) + + base_hold = next(wp for wp in base_policy.left_trajectory if wp["t"] == 450) + shifted_hold = next(wp for wp in shifted_policy.left_trajectory if wp["t"] == 450) + np.testing.assert_allclose( + base_hold["xyz"][:2], + base_state["socket_pos"][:2] + np.array([-0.078, 0.0]), + atol=1e-6, + ) + np.testing.assert_allclose( + shifted_hold["xyz"][:2], + shifted_state["socket_pos"][:2] + np.array([-0.078, 0.0]), + atol=1e-6, + ) + + def test_air_insert_policy_predicts_through_full_episode_without_exhausting_waypoints(self): + policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy") + policy_cls = getattr(policy_module, "TestAirInsertPolicy", None) + self.assertIsNotNone(policy_cls) + + task_state = self._canonical_task_state() + policy = policy_cls(inject_noise=False) + + for step in range(SIM_TASK_CONFIGS[TASK_NAME]["episode_len"]): + action = policy.predict(task_state, step) + self.assertEqual(action.shape, (16,)) + + def test_scripted_rollout_entrypoint_selects_socket_peg_sampler_and_policy(self): rollout_module = importlib.import_module("roboimi.demos.diana_record_sim_episodes") sampler_fn = getattr(rollout_module, "sample_task_state", None) policy_factory = getattr(rollout_module, "make_policy", None) - self.assertIsNotNone( - sampler_fn, - "Expected roboimi.demos.diana_record_sim_episodes.sample_task_state", - ) - self.assertIsNotNone( - policy_factory, - "Expected roboimi.demos.diana_record_sim_episodes.make_policy", - ) + self.assertIsNotNone(sampler_fn) + self.assertIsNotNone(policy_factory) - task_state = sampler_fn("sim_air_insert_ring_bar") - self.assertEqual( - list(task_state.keys()), - ["ring_pos", "ring_quat", "bar_pos", "bar_quat"], - ) + task_state = sampler_fn(TASK_NAME) + self.assertEqual(list(task_state.keys()), ["socket_pos", "socket_quat", "peg_pos", "peg_quat"]) - policy = policy_factory("sim_air_insert_ring_bar", inject_noise=False) + policy = policy_factory(TASK_NAME, inject_noise=False) self.assertEqual(policy.__class__.__name__, "TestAirInsertPolicy") def test_real_headless_smoke_instantiates_resets_and_steps_new_task_once(self): @@ -343,8 +484,8 @@ class AirInsertPolicyAndSmokeTest(unittest.TestCase): policy_cls = getattr(policy_module, "TestAirInsertPolicy", None) self.assertIsNotNone(policy_cls) - task_state = act_ex_utils.sample_air_insert_ring_bar_state() - env = make_sim_env("sim_air_insert_ring_bar", headless=True) + task_state = act_ex_utils.sample_air_insert_socket_peg_state() + env = make_sim_env(TASK_NAME, headless=True) policy = policy_cls(inject_noise=False) try: @@ -363,115 +504,6 @@ class AirInsertPolicyAndSmokeTest(unittest.TestCase): if viewer is not None: viewer.close() - def test_scripted_policy_avoids_cross_arm_contact_on_canonical_insert_case(self): - policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy") - policy_cls = getattr(policy_module, "TestAirInsertPolicy", None) - self.assertIsNotNone(policy_cls) - - task_state = { - "ring_pos": np.array([-0.06658807, 0.93985176, 0.47], dtype=np.float32), - "ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), - "bar_pos": np.array([0.12421221, 0.77605027, 0.47], dtype=np.float32), - "bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), - } - - env = make_sim_env("sim_air_insert_ring_bar", headless=True) - policy = policy_cls(inject_noise=False) - - def is_cross_arm_pair(a, b): - return ("_left" in a and "_right" in b) or ("_right" in a and "_left" in b) - - try: - env.reset(task_state) - for step in range(460): - action = policy.predict(task_state, step) - env.step(action) - pairs = [] - for i in range(env.mj_data.ncon): - geom1 = env.getID2Name("geom", env.mj_data.contact[i].geom1) - geom2 = env.getID2Name("geom", env.mj_data.contact[i].geom2) - if geom1 and geom2 and is_cross_arm_pair(geom1, geom2): - pairs.append((geom1, geom2)) - self.assertFalse( - pairs, - f"cross-arm contact detected at step {step}: {pairs[:5]}", - ) - finally: - env.exit_flag = True - cam_thread = getattr(env, "cam_thread", None) - if cam_thread is not None: - cam_thread.join(timeout=1.0) - viewer = getattr(env, "viewer", None) - if viewer is not None: - viewer.close() - - def test_scripted_policy_keeps_ring_airborne_through_hold_phase_on_canonical_case(self): - policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy") - policy_cls = getattr(policy_module, "TestAirInsertPolicy", None) - self.assertIsNotNone(policy_cls) - - task_state = { - "ring_pos": np.array([-0.11884121, 0.800019, 0.47], dtype=np.float32), - "ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), - "bar_pos": np.array([0.12783867, 0.73399246, 0.47], dtype=np.float32), - "bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), - } - - env = make_sim_env("sim_air_insert_ring_bar", headless=True) - policy = policy_cls(inject_noise=False) - - try: - env.reset(task_state) - for step in range(400): - action = policy.predict(task_state, step) - env.step(action) - ring_z = float(env.get_env_state()[2]) - self.assertGreater( - ring_z, - 0.55, - f"ring dropped before hold phase completed, final z={ring_z:.4f}", - ) - finally: - env.exit_flag = True - cam_thread = getattr(env, "cam_thread", None) - if cam_thread is not None: - cam_thread.join(timeout=1.0) - viewer = getattr(env, "viewer", None) - if viewer is not None: - viewer.close() - - def test_scripted_policy_reaches_max_reward_on_canonical_case(self): - policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy") - policy_cls = getattr(policy_module, "TestAirInsertPolicy", None) - self.assertIsNotNone(policy_cls) - - task_state = { - "ring_pos": np.array([-0.11884121, 0.800019, 0.47], dtype=np.float32), - "ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), - "bar_pos": np.array([0.12783867, 0.73399246, 0.47], dtype=np.float32), - "bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), - } - - env = make_sim_env("sim_air_insert_ring_bar", headless=True) - policy = policy_cls(inject_noise=False) - max_reward = float("-inf") - - try: - env.reset(task_state) - for step in range(700): - action = policy.predict(task_state, step) - env.step(action) - max_reward = max(max_reward, float(env.rew)) - self.assertEqual(max_reward, 5.0, f"expected canonical rollout to reach reward 5, got {max_reward}") - finally: - env.exit_flag = True - cam_thread = getattr(env, "cam_thread", None) - if cam_thread is not None: - cam_thread.join(timeout=1.0) - viewer = getattr(env, "viewer", None) - if viewer is not None: - viewer.close() - if __name__ == "__main__": unittest.main() diff --git a/tests/test_eval_vla_headless.py b/tests/test_eval_vla_headless.py index da11bd2..befccfa 100644 --- a/tests/test_eval_vla_headless.py +++ b/tests/test_eval_vla_headless.py @@ -114,7 +114,7 @@ class EvalVLAHeadlessTest(unittest.TestCase): is_render=False, control_freq=30, is_interpolate=True, - cam_view="angle", + cam_view="left_side", ) def test_camera_viewer_headless_updates_images_without_gui_calls(self): @@ -123,11 +123,11 @@ class EvalVLAHeadlessTest(unittest.TestCase): env.mj_data = object() env.exit_flag = False env.is_render = False - env.cam = "angle" + env.cam = "left_side" env.r_vis = None env.l_vis = None env.top = None - env.angle = None + env.left_side = None env.front = None with mock.patch( @@ -144,7 +144,7 @@ class EvalVLAHeadlessTest(unittest.TestCase): self.assertIsNotNone(env.r_vis) self.assertIsNotNone(env.l_vis) self.assertIsNotNone(env.top) - self.assertIsNotNone(env.angle) + self.assertIsNotNone(env.left_side) self.assertIsNotNone(env.front) def test_eval_main_headless_skips_render_and_still_executes_policy(self): @@ -254,19 +254,19 @@ class EvalVLAHeadlessTest(unittest.TestCase): self.assertAlmostEqual(summary["avg_reward"], 3.75) self.assertEqual(summary["num_episodes"], 2) - def test_run_eval_uses_air_insert_sampler_for_ring_bar_task(self): + def test_run_eval_uses_air_insert_sampler_for_socket_peg_task(self): self.assertTrue( - hasattr(eval_vla, "sample_air_insert_ring_bar_state"), - "Expected eval_vla to expose the new ring/bar reset sampler", + hasattr(eval_vla, "sample_air_insert_socket_peg_state"), + "Expected eval_vla to expose the new socket/peg reset sampler", ) fake_env = _FakeEnv() fake_agent = _FakeAgent() sampled_task_state = { - "ring_pos": np.array([-0.10, 0.80, 0.47], dtype=np.float32), - "ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), - "bar_pos": np.array([0.10, 0.82, 0.47], dtype=np.float32), - "bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + "socket_pos": np.array([-0.10, 0.80, 0.47], dtype=np.float32), + "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), + "peg_pos": np.array([0.10, 0.82, 0.47], dtype=np.float32), + "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32), } cfg = OmegaConf.create( { @@ -276,7 +276,7 @@ class EvalVLAHeadlessTest(unittest.TestCase): "num_episodes": 1, "max_timesteps": 1, "device": "cpu", - "task_name": "sim_air_insert_ring_bar", + "task_name": "sim_air_insert_socket_peg", "camera_names": ["front"], "use_smoothing": False, "smooth_alpha": 0.3, @@ -296,12 +296,12 @@ class EvalVLAHeadlessTest(unittest.TestCase): return_value=fake_env, ) as make_env, mock.patch.object( eval_vla, - "sample_air_insert_ring_bar_state", + "sample_air_insert_socket_peg_state", return_value=sampled_task_state, - ) as ring_bar_sampler, mock.patch.object( + ) as socket_peg_sampler, mock.patch.object( eval_vla, "sample_transfer_pose", - side_effect=AssertionError("sample_transfer_pose should not be used for sim_air_insert_ring_bar"), + side_effect=AssertionError("sample_transfer_pose should not be used for sim_air_insert_socket_peg"), ), mock.patch.object( eval_vla, "execute_policy_action", @@ -312,8 +312,8 @@ class EvalVLAHeadlessTest(unittest.TestCase): ): eval_vla._run_eval(cfg) - make_env.assert_called_once_with("sim_air_insert_ring_bar", headless=True) - ring_bar_sampler.assert_called_once_with() + make_env.assert_called_once_with("sim_air_insert_socket_peg", headless=True) + socket_peg_sampler.assert_called_once_with() execute_policy_action.assert_called_once() self.assertEqual(fake_env.reset_calls, [sampled_task_state]) diff --git a/tests/test_robot_asset_paths.py b/tests/test_robot_asset_paths.py index 0a1e5de..5c2fd08 100644 --- a/tests/test_robot_asset_paths.py +++ b/tests/test_robot_asset_paths.py @@ -59,15 +59,15 @@ class RobotAssetPathResolutionTest(unittest.TestCase): self.assertEqual({Path(path) for path in _FakeKDL.init_calls}, {expected_urdf}) self.assertTrue(all(Path(path).is_absolute() for path in _FakeKDL.init_calls)) - def test_bidianamed_ring_bar_resolves_robot_asset_paths_independent_of_cwd(self): - BiDianaMedRingBar = getattr(diana_med, 'BiDianaMedRingBar', None) + def test_bidianamed_socket_peg_resolves_robot_asset_paths_independent_of_cwd(self): + BiDianaMedSocketPeg = getattr(diana_med, 'BiDianaMedSocketPeg', None) self.assertIsNotNone( - BiDianaMedRingBar, - 'Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar', + BiDianaMedSocketPeg, + 'Expected roboimi.assets.robots.diana_med.BiDianaMedSocketPeg', ) repo_root = Path(__file__).resolve().parents[1] - expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml' + expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_socket_peg_ee.xml' expected_urdf = repo_root / 'roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf' xml_calls = [] @@ -89,7 +89,7 @@ class RobotAssetPathResolutionTest(unittest.TestCase): 'roboimi.assets.robots.arm_base.KDL_utils', _FakeKDL, ): - BiDianaMedRingBar() + BiDianaMedSocketPeg() finally: os.chdir(previous_cwd)