fix(policy): perform stable horizontal air insertion

This commit is contained in:
Logic
2026-04-24 09:41:37 +08:00
parent 4936cf2635
commit 4c3646a3d5
2 changed files with 122 additions and 23 deletions

View File

@@ -5,6 +5,13 @@ from roboimi.demos.diana_policy import PolicyBase
class TestAirInsertPolicy(PolicyBase):
@staticmethod
def _action_xyz_for_object_center(object_center, ee_quat, object_offset_local):
    """Return the commanded xyz that puts a held object's center at ``object_center``.

    The object is modelled as sitting at ``object_offset_local`` in the
    end-effector frame.  That offset is rotated by ``ee_quat`` and subtracted
    from the desired object center, giving the position to command.

    Args:
        object_center: Desired object-center position (3 values).
        ee_quat: End-effector orientation handed to ``Quaternion(...)``.
            NOTE(review): component order follows whatever ``Quaternion``
            expects -- confirm against the library in use.
        object_offset_local: Object center expressed in the end-effector
            frame (3 values).

    Returns:
        np.ndarray of float64 with the commanded xyz.
    """
    target_center = np.asarray(object_center, dtype=np.float64)
    # Express the local grasp offset in the same frame as the target center.
    rotated_offset = np.asarray(
        Quaternion(ee_quat).rotate(object_offset_local),
        dtype=np.float64,
    )
    return target_center - rotated_offset
def generate_trajectory(self, task_state):
ring_xyz = np.asarray(task_state["ring_pos"], dtype=np.float64)
bar_xyz = np.asarray(task_state["bar_pos"], dtype=np.float64)
@@ -37,30 +44,52 @@ class TestAirInsertPolicy(PolicyBase):
left_init_quat = Quaternion(init_mocap_pose_left[3:])
right_init_quat = Quaternion(init_mocap_pose_right[3:])
object_offset_local = np.array([0.0, 0.0, -0.09], dtype=np.float64)
left_pick_quat = (left_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements
left_hold_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=-90).elements
right_pick_quat = (right_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements
insert_quat_local = Quaternion([-0.50019721, 0.50020088, 0.49980484, 0.49979692])
right_insert_quat = np.array(
[-0.50019721, 0.50020088, 0.49980484, 0.49979692],
(Quaternion(left_hold_quat) * insert_quat_local).elements,
dtype=np.float64,
)
meet_xyz = np.array([0.0, 1.0, 1.30], dtype=np.float64)
left_stabilize_xyz = ring_xyz + np.array([0.0, 0.0, 0.30], dtype=np.float64)
left_hold_xyz = meet_xyz + np.array([-0.18, 0.10, -0.08], dtype=np.float64)
right_reorient_xyz = bar_xyz + np.array([0.0, 0.0, 0.10], dtype=np.float64)
right_wait_xyz = left_hold_xyz + np.array([0.14, 0.16, -0.04], dtype=np.float64)
right_insert_start_xyz = left_hold_xyz + np.array([0.165, 0.022, 0.08], dtype=np.float64)
right_insert_end_xyz = left_hold_xyz + np.array([0.165, 0.022, 0.0], dtype=np.float64)
ring_stabilize_center = ring_xyz + np.array([0.0, 0.0, 0.30], dtype=np.float64)
ring_hold_center = meet_xyz + np.array([-0.10, 0.05, -0.16], dtype=np.float64)
bar_reorient_center = bar_xyz + np.array([0.0, 0.0, 0.16], dtype=np.float64)
bar_wait_center = ring_hold_center + np.array([0.05, -0.18, 0.0], dtype=np.float64)
bar_insert_start_center = ring_hold_center + np.array([0.0, -0.075, 0.0], dtype=np.float64)
bar_insert_end_center = ring_hold_center + np.array([0.0, 0.075, 0.0], dtype=np.float64)
left_stabilize_xyz = self._action_xyz_for_object_center(
ring_stabilize_center, left_pick_quat, object_offset_local
)
left_hold_xyz = self._action_xyz_for_object_center(
ring_hold_center, left_hold_quat, object_offset_local
)
right_reorient_xyz = self._action_xyz_for_object_center(
bar_reorient_center, right_insert_quat, object_offset_local
)
right_wait_xyz = self._action_xyz_for_object_center(
bar_wait_center, right_insert_quat, object_offset_local
)
right_insert_start_xyz = self._action_xyz_for_object_center(
bar_insert_start_center, right_insert_quat, object_offset_local
)
right_insert_end_xyz = self._action_xyz_for_object_center(
bar_insert_end_center, right_insert_quat, object_offset_local
)
self.left_trajectory = [
{"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 100},
{"t": 80, "xyz": ring_xyz + np.array([0.0, 0.0, 0.22]), "quat": left_pick_quat, "gripper": 100},
{"t": 150, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": 100},
{"t": 180, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": -100},
{"t": 260, "xyz": ring_xyz + np.array([0.0, 0.0, 0.24]), "quat": left_pick_quat, "gripper": -100},
{"t": 260, "xyz": self._action_xyz_for_object_center(ring_xyz + np.array([0.0, 0.0, 0.24]), left_pick_quat, object_offset_local), "quat": left_pick_quat, "gripper": -100},
{"t": 340, "xyz": left_stabilize_xyz, "quat": left_pick_quat, "gripper": -100},
{"t": 460, "xyz": left_hold_xyz, "quat": left_pick_quat, "gripper": -100},
{"t": 700, "xyz": left_hold_xyz, "quat": left_pick_quat, "gripper": -100},
{"t": 460, "xyz": left_hold_xyz, "quat": left_hold_quat, "gripper": -100},
{"t": 700, "xyz": left_hold_xyz, "quat": left_hold_quat, "gripper": -100},
]
self.right_trajectory = [
@@ -68,7 +97,7 @@ class TestAirInsertPolicy(PolicyBase):
{"t": 80, "xyz": bar_xyz + np.array([0.0, 0.0, 0.22]), "quat": right_pick_quat, "gripper": 100},
{"t": 150, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": 100},
{"t": 180, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": -100},
{"t": 240, "xyz": bar_xyz + np.array([0.0, 0.0, 0.12]), "quat": right_pick_quat, "gripper": -100},
{"t": 240, "xyz": self._action_xyz_for_object_center(bar_xyz + np.array([0.0, 0.0, 0.12]), right_pick_quat, object_offset_local), "quat": right_pick_quat, "gripper": -100},
{"t": 320, "xyz": right_reorient_xyz, "quat": right_insert_quat, "gripper": -100},
{"t": 460, "xyz": right_wait_xyz, "quat": right_insert_quat, "gripper": -100},
{"t": 600, "xyz": right_insert_start_xyz, "quat": right_insert_quat, "gripper": -100},

View File

@@ -38,8 +38,6 @@ TABLE_GEOM_NAME = "table"
RING_APERTURE_HALF_WIDTH = 0.016
RING_HALF_THICKNESS = 0.009
BAR_HALF_SIZES = np.array([0.045, 0.009, 0.009], dtype=np.float64)
SCRIPTED_RING_GRASP_OFFSET = np.array([0.12, 0.022, -0.09], dtype=np.float64)
SCRIPTED_BAR_GRASP_OFFSET = np.array([-0.045, 0.0, -0.09], dtype=np.float64)
SCRIPTED_GRASP_CLOSE_THRESHOLD = 0.0
@@ -103,6 +101,28 @@ def _quat_to_rotation_matrix(quat):
)
def _quat_multiply(lhs, rhs):
lhs = np.asarray(lhs, dtype=np.float64)
rhs = np.asarray(rhs, dtype=np.float64)
w1, x1, y1, z1 = lhs
w2, x2, y2, z2 = rhs
return np.array(
[
w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
],
dtype=np.float64,
)
def _quat_inverse(quat):
quat = np.asarray(quat, dtype=np.float64)
norm_sq = float(np.dot(quat, quat))
return np.array([quat[0], -quat[1], -quat[2], -quat[3]], dtype=np.float64) / norm_sq
def _split_env_state(env_state):
env_state = np.asarray(env_state, dtype=np.float64)
if env_state.shape != (14,):
@@ -163,11 +183,19 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
self.max_reward = 5
self._scripted_ring_grasped = False
self._scripted_bar_grasped = False
self._scripted_ring_pos_offset_local = None
self._scripted_bar_pos_offset_local = None
self._scripted_ring_quat_offset = None
self._scripted_bar_quat_offset = None
self._air_insert_step_count = 0
def reset(self, task_state):
self._scripted_ring_grasped = False
self._scripted_bar_grasped = False
self._scripted_ring_pos_offset_local = None
self._scripted_bar_pos_offset_local = None
self._scripted_ring_quat_offset = None
self._scripted_bar_quat_offset = None
self._air_insert_step_count = 0
set_ring_bar_task_state(self.mj_data, task_state)
DualDianaMed.reset(self)
@@ -195,26 +223,68 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
self._air_insert_step_count += 1
def _update_scripted_grasped_objects(self, action):
    """Attach the ring/bar on their close command and keep them following the action pose.

    An object is attached once: on the first call where its gripper command
    is below ``SCRIPTED_GRASP_CLOSE_THRESHOLD``, the step counter has reached
    180, and the object is not already flagged as grasped.  Attaching records
    the object's pose relative to the commanded end-effector pose.  On every
    call, each attached object is re-posed from the commanded end-effector
    pose and that recorded offset, and ``mj_forward`` is run if anything is
    attached.

    Args:
        action: Indexable action vector.  ``action[:3]`` / ``action[3:7]``
            are used as the end-effector position / quaternion driving the
            ring, ``action[7:10]`` / ``action[10:14]`` as those driving the
            bar, and ``action[-2]`` / ``action[-1]`` as the ring-side /
            bar-side gripper commands.
            NOTE(review): presumably left arm = ring and right arm = bar --
            confirm against the action layout.
    """
    # NOTE(review): 180 appears to be the step at which the scripted policy
    # closes the grippers -- confirm before changing either side.
    if (
        action[-2] < SCRIPTED_GRASP_CLOSE_THRESHOLD
        and self._air_insert_step_count >= 180
        and not self._scripted_ring_grasped
    ):
        self._scripted_ring_grasped = True
        # Capture the ring's pose relative to the end effector only once, at
        # the moment of grasp.
        self._attach_scripted_object(
            object_joint_name=RING_JOINT_NAME,
            ee_pos=action[:3],
            ee_quat=action[3:7],
            pos_attr="_scripted_ring_pos_offset_local",
            quat_attr="_scripted_ring_quat_offset",
        )
    if (
        action[-1] < SCRIPTED_GRASP_CLOSE_THRESHOLD
        and self._air_insert_step_count >= 180
        and not self._scripted_bar_grasped
    ):
        self._scripted_bar_grasped = True
        self._attach_scripted_object(
            object_joint_name=BAR_JOINT_NAME,
            ee_pos=action[7:10],
            ee_quat=action[10:14],
            pos_attr="_scripted_bar_pos_offset_local",
            quat_attr="_scripted_bar_quat_offset",
        )
    if self._scripted_ring_grasped:
        self._update_scripted_object_pose(
            object_joint_name=RING_JOINT_NAME,
            ee_pos=action[:3],
            ee_quat=action[3:7],
            pos_offset_local=self._scripted_ring_pos_offset_local,
            quat_offset=self._scripted_ring_quat_offset,
        )
    if self._scripted_bar_grasped:
        self._update_scripted_object_pose(
            object_joint_name=BAR_JOINT_NAME,
            ee_pos=action[7:10],
            ee_quat=action[10:14],
            pos_offset_local=self._scripted_bar_pos_offset_local,
            quat_offset=self._scripted_bar_quat_offset,
        )
    if self._scripted_ring_grasped or self._scripted_bar_grasped:
        # Recompute derived simulation quantities after writing joint poses directly.
        mj.mj_forward(self.mj_model, self.mj_data)
def _attach_scripted_object(self, object_joint_name, ee_pos, ee_quat, pos_attr, quat_attr):
    """Record an object's current pose relative to an end-effector pose.

    Reads the first seven ``qpos`` entries of the named joint (three position
    values followed by four quaternion values) and stores two attributes on
    ``self``:

    * ``pos_attr``: the object position minus ``ee_pos``, multiplied by the
      transpose of the end effector's rotation matrix.
    * ``quat_attr``: ``inverse(ee_quat) * object_quat``.

    Args:
        object_joint_name: Name of the object's joint in ``self.mj_data``.
        ee_pos: End-effector position (3 values).
        ee_quat: End-effector quaternion (4 values).
        pos_attr: Attribute name that receives the local position offset.
        quat_attr: Attribute name that receives the relative quaternion.
    """
    gripper_pos = np.asarray(ee_pos, dtype=np.float64)
    gripper_quat = np.asarray(ee_quat, dtype=np.float64)
    joint = self.mj_data.joint(object_joint_name)
    object_pose = np.asarray(joint.qpos[:7], dtype=np.float64)
    gripper_rot = _quat_to_rotation_matrix(gripper_quat)
    # NOTE(review): the transpose acts as the inverse rotation only if
    # _quat_to_rotation_matrix returns an orthonormal matrix -- confirm.
    local_offset = gripper_rot.T @ (object_pose[:3] - gripper_pos)
    setattr(self, pos_attr, local_offset)
    relative_quat = _quat_multiply(_quat_inverse(gripper_quat), object_pose[3:7])
    setattr(self, quat_attr, relative_quat)
def _update_scripted_object_pose(self, object_joint_name, ee_pos, ee_quat, pos_offset_local, quat_offset):
    """Write an object's pose from an end-effector pose plus a stored relative offset.

    The position is ``ee_pos + R(ee_quat) @ pos_offset_local`` and the
    orientation is ``ee_quat * quat_offset``.  The result is written to the
    named joint through ``_set_free_joint_pose``.

    Args:
        object_joint_name: Name of the object's joint in ``self.mj_data``.
        ee_pos: End-effector position (3 values).
        ee_quat: End-effector quaternion (4 values).
        pos_offset_local: Object position offset in the end-effector frame.
        quat_offset: Object orientation relative to the end effector.
    """
    gripper_pos = np.asarray(ee_pos, dtype=np.float64)
    gripper_quat = np.asarray(ee_quat, dtype=np.float64)
    gripper_rot = _quat_to_rotation_matrix(gripper_quat)
    # Rotate the stored local offset and add it to the end-effector position.
    offset_in_ee_frame = np.asarray(pos_offset_local, dtype=np.float64)
    new_pos = gripper_pos + gripper_rot @ offset_in_ee_frame
    new_quat = _quat_multiply(gripper_quat, quat_offset)
    joint = self.mj_data.joint(object_joint_name)
    _set_free_joint_pose(joint, new_pos, new_quat)
def get_env_state(self):
    """Return the ring/bar environment state read from this environment's MuJoCo data."""
    sim_data = self.mj_data
    return get_ring_bar_env_state(sim_data)