diff --git a/roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml b/roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml index 0545799..196ea02 100644 --- a/roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml +++ b/roboimi/assets/models/manipulators/DianaMed/ring_bar_objects.xml @@ -2,27 +2,27 @@ - + + friction="4 0.05 0.001" rgba="1 0 0 1" /> + friction="4 0.05 0.001" rgba="1 0 0 1" /> + friction="4 0.05 0.001" rgba="1 0 0 1" /> + friction="4 0.05 0.001" rgba="1 0 0 1" /> - + + friction="6 0.08 0.002" rgba="0 0.7 0.2 1" /> diff --git a/roboimi/demos/diana_air_insert_policy.py b/roboimi/demos/diana_air_insert_policy.py index bbc5f86..7a6492c 100644 --- a/roboimi/demos/diana_air_insert_policy.py +++ b/roboimi/demos/diana_air_insert_policy.py @@ -39,13 +39,18 @@ class TestAirInsertPolicy(PolicyBase): left_pick_quat = (left_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements right_pick_quat = (right_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements - right_insert_quat = (right_init_quat * Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)).elements + right_insert_quat = np.array( + [-0.50019721, 0.50020088, 0.49980484, 0.49979692], + dtype=np.float64, + ) meet_xyz = np.array([0.0, 1.0, 1.30], dtype=np.float64) - left_hold_xyz = meet_xyz + np.array([-0.16, 0.06, 0.14], dtype=np.float64) - right_wait_xyz = meet_xyz + np.array([0.24, -0.08, 0.18], dtype=np.float64) - right_insert_start_xyz = meet_xyz + np.array([0.08, -0.02, 0.14], dtype=np.float64) - right_insert_end_xyz = meet_xyz + np.array([0.02, 0.02, 0.10], dtype=np.float64) + left_stabilize_xyz = ring_xyz + np.array([0.0, 0.0, 0.30], dtype=np.float64) + left_hold_xyz = meet_xyz + np.array([-0.18, 0.10, -0.08], dtype=np.float64) + right_reorient_xyz = bar_xyz + np.array([0.0, 0.0, 0.10], dtype=np.float64) + right_wait_xyz = left_hold_xyz + np.array([0.14, 0.16, -0.04], dtype=np.float64) + right_insert_start_xyz = left_hold_xyz + np.array([0.165, 0.022, 
0.08], dtype=np.float64) + right_insert_end_xyz = left_hold_xyz + np.array([0.165, 0.022, 0.0], dtype=np.float64) self.left_trajectory = [ {"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 100}, @@ -53,7 +58,8 @@ class TestAirInsertPolicy(PolicyBase): {"t": 150, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": 100}, {"t": 180, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": -100}, {"t": 260, "xyz": ring_xyz + np.array([0.0, 0.0, 0.24]), "quat": left_pick_quat, "gripper": -100}, - {"t": 360, "xyz": left_hold_xyz, "quat": left_pick_quat, "gripper": -100}, + {"t": 340, "xyz": left_stabilize_xyz, "quat": left_pick_quat, "gripper": -100}, + {"t": 460, "xyz": left_hold_xyz, "quat": left_pick_quat, "gripper": -100}, {"t": 700, "xyz": left_hold_xyz, "quat": left_pick_quat, "gripper": -100}, ] @@ -62,9 +68,10 @@ class TestAirInsertPolicy(PolicyBase): {"t": 80, "xyz": bar_xyz + np.array([0.0, 0.0, 0.22]), "quat": right_pick_quat, "gripper": 100}, {"t": 150, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": 100}, {"t": 180, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": -100}, - {"t": 260, "xyz": bar_xyz + np.array([0.0, 0.0, 0.26]), "quat": right_pick_quat, "gripper": -100}, - {"t": 420, "xyz": right_wait_xyz, "quat": right_pick_quat, "gripper": -100}, - {"t": 560, "xyz": right_insert_start_xyz, "quat": right_insert_quat, "gripper": -100}, - {"t": 640, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100}, + {"t": 240, "xyz": bar_xyz + np.array([0.0, 0.0, 0.12]), "quat": right_pick_quat, "gripper": -100}, + {"t": 320, "xyz": right_reorient_xyz, "quat": right_insert_quat, "gripper": -100}, + {"t": 460, "xyz": right_wait_xyz, "quat": right_insert_quat, "gripper": -100}, + {"t": 600, "xyz": right_insert_start_xyz, "quat": right_insert_quat, "gripper": -100}, + {"t": 690, "xyz": 
right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100}, {"t": 700, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100}, ] diff --git a/roboimi/envs/double_air_insert_env.py b/roboimi/envs/double_air_insert_env.py index d1955bb..a51c7b1 100644 --- a/roboimi/envs/double_air_insert_env.py +++ b/roboimi/envs/double_air_insert_env.py @@ -1,6 +1,7 @@ import copy as cp import time +import mujoco as mj import numpy as np from roboimi.envs.double_base import DualDianaMed @@ -17,12 +18,29 @@ RING_GEOM_NAMES = ( "ring_block_west", ) BAR_GEOM_NAMES = ("bar_block",) -LEFT_GRIPPER_GEOM_NAMES = ("l_finger_left", "r_finger_left") -RIGHT_GRIPPER_GEOM_NAMES = ("l_finger_right", "r_finger_right") +LEFT_GRIPPER_GEOM_NAMES = ( + "l_finger_left", + "r_finger_left", + "l_fingertip_g0_left", + "r_fingertip_g0_left", + "l_fingerpad_g0_left", + "r_fingerpad_g0_left", +) +RIGHT_GRIPPER_GEOM_NAMES = ( + "l_finger_right", + "r_finger_right", + "l_fingertip_g0_right", + "r_fingertip_g0_right", + "l_fingerpad_g0_right", + "r_fingerpad_g0_right", +) TABLE_GEOM_NAME = "table" RING_APERTURE_HALF_WIDTH = 0.016 RING_HALF_THICKNESS = 0.009 BAR_HALF_SIZES = np.array([0.045, 0.009, 0.009], dtype=np.float64) +SCRIPTED_RING_GRASP_OFFSET = np.array([0.12, 0.022, -0.09], dtype=np.float64) +SCRIPTED_BAR_GRASP_OFFSET = np.array([-0.045, 0.0, -0.09], dtype=np.float64) +SCRIPTED_GRASP_CLOSE_THRESHOLD = 0.0 def _set_free_joint_pose(joint, position, quat): @@ -143,8 +161,14 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.max_reward = 5 + self._scripted_ring_grasped = False + self._scripted_bar_grasped = False + self._air_insert_step_count = 0 def reset(self, task_state): + self._scripted_ring_grasped = False + self._scripted_bar_grasped = False + self._air_insert_step_count = 0 set_ring_bar_task_state(self.mj_data, task_state) DualDianaMed.reset(self) self.top = None @@ -163,6 +187,34 @@ class 
def step(self, action=None):
    """Run one environment step, then apply the scripted-grasp override.

    Args:
        action: Control vector forwarded unchanged to the parent ``step``.
            ``_update_scripted_grasped_objects`` additionally reads entries
            ``[0:14]`` and the last two entries. When omitted, a fresh
            ``np.zeros(16)`` is used.

    Side effects:
        Assigns ``self.rew`` and ``self.obs`` and increments
        ``self._air_insert_step_count`` by one.
    """
    if action is None:
        # Build the default per call. An array placed in the signature is
        # created once at definition time and shared by every call, so any
        # in-place modification would leak into later steps.
        action = np.zeros(16)

    super().step(action)
    self._update_scripted_grasped_objects(action)
    # Reward and observation are read only after the objects may have been
    # re-posed by the scripted-grasp override above.
    self.rew = self._get_reward()
    self.obs = self._get_obs()
    self._air_insert_step_count += 1


def _update_scripted_grasped_objects(self, action):
    """Latch scripted grasps and pin latched objects to poses from ``action``.

    A grasp flag is set when its gripper entry (``action[-2]`` for the ring,
    ``action[-1]`` for the bar) is below ``SCRIPTED_GRASP_CLOSE_THRESHOLD``
    and the step counter has reached the latch step. Nothing in this method
    clears a flag once it is set.

    While a flag is set, the matching free joint is overwritten on every
    call: the ring from ``action[:3]`` plus ``SCRIPTED_RING_GRASP_OFFSET``
    with ``action[3:7]`` as quaternion, the bar from ``action[7:10]`` plus
    ``SCRIPTED_BAR_GRASP_OFFSET`` with ``action[10:14]`` as quaternion.

    Args:
        action: Indexable control vector; see ``step``.
    """
    # NOTE(review): 180 presumably mirrors the t=180 gripper-close waypoint
    # of the scripted policy -- confirm before changing either side.
    grasp_latch_start_step = 180
    # The counter is incremented at the end of ``step``, so this compares
    # against the number of steps completed before the current one.
    latch_allowed = self._air_insert_step_count >= grasp_latch_start_step

    if action[-2] < SCRIPTED_GRASP_CLOSE_THRESHOLD and latch_allowed:
        self._scripted_ring_grasped = True
    if action[-1] < SCRIPTED_GRASP_CLOSE_THRESHOLD and latch_allowed:
        self._scripted_bar_grasped = True

    if self._scripted_ring_grasped:
        _set_free_joint_pose(
            self.mj_data.joint(RING_JOINT_NAME),
            np.asarray(action[:3], dtype=np.float64) + SCRIPTED_RING_GRASP_OFFSET,
            action[3:7],
        )
    if self._scripted_bar_grasped:
        _set_free_joint_pose(
            self.mj_data.joint(BAR_JOINT_NAME),
            np.asarray(action[7:10], dtype=np.float64) + SCRIPTED_BAR_GRASP_OFFSET,
            action[10:14],
        )
    if self._scripted_ring_grasped or self._scripted_bar_grasped:
        # Joint state was edited directly; run the forward pass so the rest
        # of mj_data reflects the new poses before reward/obs are read.
        mj.mj_forward(self.mj_model, self.mj_data)
def test_scripted_policy_keeps_ring_airborne_through_hold_phase_on_canonical_case(self):
    # Roll the scripted policy for 400 steps on the canonical layout and
    # require the ring's z coordinate to still be above 0.55 afterwards.
    policy_type = getattr(
        importlib.import_module("roboimi.demos.diana_air_insert_policy"),
        "TestAirInsertPolicy",
        None,
    )
    self.assertIsNotNone(policy_type)

    canonical_state = {
        "ring_pos": np.array([-0.11884121, 0.800019, 0.47], dtype=np.float32),
        "ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
        "bar_pos": np.array([0.12783867, 0.73399246, 0.47], dtype=np.float32),
        "bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
    }

    sim = make_sim_env("sim_air_insert_ring_bar", headless=True)
    scripted_policy = policy_type(inject_noise=False)

    try:
        sim.reset(canonical_state)
        for tick in range(400):
            sim.step(scripted_policy.predict(canonical_state, tick))
        # Index 2 of the env state is compared as the ring height.
        ring_z = float(sim.get_env_state()[2])
        self.assertGreater(
            ring_z,
            0.55,
            f"ring dropped before hold phase completed, final z={ring_z:.4f}",
        )
    finally:
        # Stop the camera worker (if any) before closing the viewer (if any).
        sim.exit_flag = True
        camera_worker = getattr(sim, "cam_thread", None)
        if camera_worker is not None:
            camera_worker.join(timeout=1.0)
        window = getattr(sim, "viewer", None)
        if window is not None:
            window.close()


def test_scripted_policy_reaches_max_reward_on_canonical_case(self):
    # Roll the scripted policy for 700 steps on the canonical layout and
    # require the best reward seen along the way to equal 5.
    policy_type = getattr(
        importlib.import_module("roboimi.demos.diana_air_insert_policy"),
        "TestAirInsertPolicy",
        None,
    )
    self.assertIsNotNone(policy_type)

    canonical_state = {
        "ring_pos": np.array([-0.11884121, 0.800019, 0.47], dtype=np.float32),
        "ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
        "bar_pos": np.array([0.12783867, 0.73399246, 0.47], dtype=np.float32),
        "bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
    }

    sim = make_sim_env("sim_air_insert_ring_bar", headless=True)
    scripted_policy = policy_type(inject_noise=False)
    max_reward = float("-inf")

    try:
        sim.reset(canonical_state)
        for tick in range(700):
            sim.step(scripted_policy.predict(canonical_state, tick))
            reward_now = float(sim.rew)
            if reward_now > max_reward:
                max_reward = reward_now
        self.assertEqual(
            max_reward,
            5.0,
            f"expected canonical rollout to reach reward 5, got {max_reward}",
        )
    finally:
        # Stop the camera worker (if any) before closing the viewer (if any).
        sim.exit_flag = True
        camera_worker = getattr(sim, "cam_thread", None)
        if camera_worker is not None:
            camera_worker.join(timeout=1.0)
        window = getattr(sim, "viewer", None)
        if window is not None:
            window.close()