From 4c3646a3d56060655144b57e17e9dd24be7c9eb9 Mon Sep 17 00:00:00 2001 From: Logic Date: Fri, 24 Apr 2026 09:41:37 +0800 Subject: [PATCH] fix(policy): perform stable horizontal air insertion --- roboimi/demos/diana_air_insert_policy.py | 51 ++++++++++--- roboimi/envs/double_air_insert_env.py | 94 +++++++++++++++++++++--- 2 files changed, 122 insertions(+), 23 deletions(-) diff --git a/roboimi/demos/diana_air_insert_policy.py b/roboimi/demos/diana_air_insert_policy.py index 7a6492c..30511bb 100644 --- a/roboimi/demos/diana_air_insert_policy.py +++ b/roboimi/demos/diana_air_insert_policy.py @@ -5,6 +5,13 @@ from roboimi.demos.diana_policy import PolicyBase class TestAirInsertPolicy(PolicyBase): + @staticmethod + def _action_xyz_for_object_center(object_center, ee_quat, object_offset_local): + return ( + np.asarray(object_center, dtype=np.float64) + - np.asarray(Quaternion(ee_quat).rotate(object_offset_local), dtype=np.float64) + ) + def generate_trajectory(self, task_state): ring_xyz = np.asarray(task_state["ring_pos"], dtype=np.float64) bar_xyz = np.asarray(task_state["bar_pos"], dtype=np.float64) @@ -37,30 +44,52 @@ class TestAirInsertPolicy(PolicyBase): left_init_quat = Quaternion(init_mocap_pose_left[3:]) right_init_quat = Quaternion(init_mocap_pose_right[3:]) + object_offset_local = np.array([0.0, 0.0, -0.09], dtype=np.float64) left_pick_quat = (left_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements + left_hold_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=-90).elements right_pick_quat = (right_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements + insert_quat_local = Quaternion([-0.50019721, 0.50020088, 0.49980484, 0.49979692]) right_insert_quat = np.array( - [-0.50019721, 0.50020088, 0.49980484, 0.49979692], + (Quaternion(left_hold_quat) * insert_quat_local).elements, dtype=np.float64, ) meet_xyz = np.array([0.0, 1.0, 1.30], dtype=np.float64) - left_stabilize_xyz = ring_xyz + np.array([0.0, 0.0, 0.30], dtype=np.float64) - left_hold_xyz = meet_xyz + np.array([-0.18, 0.10, -0.08], dtype=np.float64) - right_reorient_xyz = bar_xyz + np.array([0.0, 0.0, 0.10], dtype=np.float64) - right_wait_xyz = left_hold_xyz + np.array([0.14, 0.16, -0.04], dtype=np.float64) - right_insert_start_xyz = left_hold_xyz + np.array([0.165, 0.022, 0.08], dtype=np.float64) - right_insert_end_xyz = left_hold_xyz + np.array([0.165, 0.022, 0.0], dtype=np.float64) + ring_stabilize_center = ring_xyz + np.array([0.0, 0.0, 0.30], dtype=np.float64) + ring_hold_center = meet_xyz + np.array([-0.10, 0.05, -0.16], dtype=np.float64) + bar_reorient_center = bar_xyz + np.array([0.0, 0.0, 0.16], dtype=np.float64) + bar_wait_center = ring_hold_center + np.array([0.05, -0.18, 0.0], dtype=np.float64) + bar_insert_start_center = ring_hold_center + np.array([0.0, -0.075, 0.0], dtype=np.float64) + bar_insert_end_center = ring_hold_center + np.array([0.0, 0.075, 0.0], dtype=np.float64) + + left_stabilize_xyz = self._action_xyz_for_object_center( + ring_stabilize_center, left_pick_quat, object_offset_local + ) + left_hold_xyz = self._action_xyz_for_object_center( + ring_hold_center, left_hold_quat, object_offset_local + ) + right_reorient_xyz = self._action_xyz_for_object_center( + bar_reorient_center, right_insert_quat, object_offset_local + ) + right_wait_xyz = self._action_xyz_for_object_center( + bar_wait_center, right_insert_quat, object_offset_local + ) + right_insert_start_xyz = self._action_xyz_for_object_center( + bar_insert_start_center, right_insert_quat, object_offset_local + ) + right_insert_end_xyz = self._action_xyz_for_object_center( + bar_insert_end_center, right_insert_quat, object_offset_local + ) self.left_trajectory = [ {"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 100}, {"t": 80, "xyz": ring_xyz + np.array([0.0, 0.0, 0.22]), "quat": left_pick_quat, "gripper": 100}, {"t": 150, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": 100}, {"t": 180, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": -100}, - {"t": 260, "xyz": ring_xyz + np.array([0.0, 0.0, 0.24]), "quat": left_pick_quat, "gripper": -100}, + {"t": 260, "xyz": self._action_xyz_for_object_center(ring_xyz + np.array([0.0, 0.0, 0.24]), left_pick_quat, object_offset_local), "quat": left_pick_quat, "gripper": -100}, {"t": 340, "xyz": left_stabilize_xyz, "quat": left_pick_quat, "gripper": -100}, - {"t": 460, "xyz": left_hold_xyz, "quat": left_pick_quat, "gripper": -100}, - {"t": 700, "xyz": left_hold_xyz, "quat": left_pick_quat, "gripper": -100}, + {"t": 460, "xyz": left_hold_xyz, "quat": left_hold_quat, "gripper": -100}, + {"t": 700, "xyz": left_hold_xyz, "quat": left_hold_quat, "gripper": -100}, ] self.right_trajectory = [ @@ -68,7 +97,7 @@ class TestAirInsertPolicy(PolicyBase): {"t": 80, "xyz": bar_xyz + np.array([0.0, 0.0, 0.22]), "quat": right_pick_quat, "gripper": 100}, {"t": 150, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": 100}, {"t": 180, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": -100}, - {"t": 240, "xyz": bar_xyz + np.array([0.0, 0.0, 0.12]), "quat": right_pick_quat, "gripper": -100}, + {"t": 240, "xyz": self._action_xyz_for_object_center(bar_xyz + np.array([0.0, 0.0, 0.12]), right_pick_quat, object_offset_local), "quat": right_pick_quat, "gripper": -100}, {"t": 320, "xyz": right_reorient_xyz, "quat": right_insert_quat, "gripper": -100}, {"t": 460, "xyz": right_wait_xyz, "quat": right_insert_quat, "gripper": -100}, {"t": 600, "xyz": right_insert_start_xyz, "quat": right_insert_quat, "gripper": -100}, diff --git a/roboimi/envs/double_air_insert_env.py b/roboimi/envs/double_air_insert_env.py index a51c7b1..1050fdf 100644 --- a/roboimi/envs/double_air_insert_env.py +++ b/roboimi/envs/double_air_insert_env.py @@ -38,8 +38,6 @@ TABLE_GEOM_NAME = "table" RING_APERTURE_HALF_WIDTH = 0.016 RING_HALF_THICKNESS = 0.009 BAR_HALF_SIZES = np.array([0.045, 0.009, 0.009], dtype=np.float64) -SCRIPTED_RING_GRASP_OFFSET = np.array([0.12, 0.022, -0.09], dtype=np.float64) -SCRIPTED_BAR_GRASP_OFFSET = np.array([-0.045, 0.0, -0.09], dtype=np.float64) SCRIPTED_GRASP_CLOSE_THRESHOLD = 0.0 @@ -103,6 +101,28 @@ def _quat_to_rotation_matrix(quat): ) +def _quat_multiply(lhs, rhs): + lhs = np.asarray(lhs, dtype=np.float64) + rhs = np.asarray(rhs, dtype=np.float64) + w1, x1, y1, z1 = lhs + w2, x2, y2, z2 = rhs + return np.array( + [ + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, + ], + dtype=np.float64, + ) + + +def _quat_inverse(quat): + quat = np.asarray(quat, dtype=np.float64) + norm_sq = float(np.dot(quat, quat)) + return np.array([quat[0], -quat[1], -quat[2], -quat[3]], dtype=np.float64) / norm_sq + + def _split_env_state(env_state): env_state = np.asarray(env_state, dtype=np.float64) if env_state.shape != (14,): @@ -163,11 +183,19 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl): self.max_reward = 5 self._scripted_ring_grasped = False self._scripted_bar_grasped = False + self._scripted_ring_pos_offset_local = None + self._scripted_bar_pos_offset_local = None + self._scripted_ring_quat_offset = None + self._scripted_bar_quat_offset = None self._air_insert_step_count = 0 def reset(self, task_state): self._scripted_ring_grasped = False self._scripted_bar_grasped = False + self._scripted_ring_pos_offset_local = None + self._scripted_bar_pos_offset_local = None + self._scripted_ring_quat_offset = None + self._scripted_bar_quat_offset = None self._air_insert_step_count = 0 set_ring_bar_task_state(self.mj_data, task_state) DualDianaMed.reset(self) @@ -195,26 +223,68 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl): self._air_insert_step_count += 1 def _update_scripted_grasped_objects(self, action): - if action[-2] < SCRIPTED_GRASP_CLOSE_THRESHOLD and self._air_insert_step_count >= 180: + if ( + action[-2] < SCRIPTED_GRASP_CLOSE_THRESHOLD + and self._air_insert_step_count >= 180 + and not self._scripted_ring_grasped + ): self._scripted_ring_grasped = True - if action[-1] < SCRIPTED_GRASP_CLOSE_THRESHOLD and self._air_insert_step_count >= 180: + self._attach_scripted_object( + object_joint_name=RING_JOINT_NAME, + ee_pos=action[:3], + ee_quat=action[3:7], + pos_attr="_scripted_ring_pos_offset_local", + quat_attr="_scripted_ring_quat_offset", + ) + if ( + action[-1] < SCRIPTED_GRASP_CLOSE_THRESHOLD + and self._air_insert_step_count >= 180 + and not self._scripted_bar_grasped + ): self._scripted_bar_grasped = True + self._attach_scripted_object( + object_joint_name=BAR_JOINT_NAME, + ee_pos=action[7:10], + ee_quat=action[10:14], + pos_attr="_scripted_bar_pos_offset_local", + quat_attr="_scripted_bar_quat_offset", + ) if self._scripted_ring_grasped: - _set_free_joint_pose( - self.mj_data.joint(RING_JOINT_NAME), - np.asarray(action[:3], dtype=np.float64) + SCRIPTED_RING_GRASP_OFFSET, - action[3:7], + self._update_scripted_object_pose( + object_joint_name=RING_JOINT_NAME, + ee_pos=action[:3], + ee_quat=action[3:7], + pos_offset_local=self._scripted_ring_pos_offset_local, + quat_offset=self._scripted_ring_quat_offset, ) if self._scripted_bar_grasped: - _set_free_joint_pose( - self.mj_data.joint(BAR_JOINT_NAME), - np.asarray(action[7:10], dtype=np.float64) + SCRIPTED_BAR_GRASP_OFFSET, - action[10:14], + self._update_scripted_object_pose( + object_joint_name=BAR_JOINT_NAME, + ee_pos=action[7:10], + ee_quat=action[10:14], + pos_offset_local=self._scripted_bar_pos_offset_local, + quat_offset=self._scripted_bar_quat_offset, ) if self._scripted_ring_grasped or self._scripted_bar_grasped: mj.mj_forward(self.mj_model, self.mj_data) + def _attach_scripted_object(self, object_joint_name, ee_pos, ee_quat, pos_attr, quat_attr): + ee_pos = np.asarray(ee_pos, dtype=np.float64) + ee_quat = np.asarray(ee_quat, dtype=np.float64) + object_qpos = np.asarray(self.mj_data.joint(object_joint_name).qpos[:7], dtype=np.float64) + ee_rot = _quat_to_rotation_matrix(ee_quat) + setattr(self, pos_attr, ee_rot.T @ (object_qpos[:3] - ee_pos)) + setattr(self, quat_attr, _quat_multiply(_quat_inverse(ee_quat), object_qpos[3:7])) + + def _update_scripted_object_pose(self, object_joint_name, ee_pos, ee_quat, pos_offset_local, quat_offset): + ee_pos = np.asarray(ee_pos, dtype=np.float64) + ee_quat = np.asarray(ee_quat, dtype=np.float64) + ee_rot = _quat_to_rotation_matrix(ee_quat) + object_pos = ee_pos + ee_rot @ np.asarray(pos_offset_local, dtype=np.float64) + object_quat = _quat_multiply(ee_quat, quat_offset) + _set_free_joint_pose(self.mj_data.joint(object_joint_name), object_pos, object_quat) + def get_env_state(self): return get_ring_bar_env_state(self.mj_data)