fix(policy): avoid cross-arm collision in air insert rollout

This commit is contained in:
Logic
2026-04-23 18:04:54 +08:00
parent 8145c9eb62
commit d245d64def
2 changed files with 51 additions and 7 deletions

View File

@@ -42,9 +42,10 @@ class TestAirInsertPolicy(PolicyBase):
right_insert_quat = (right_init_quat * Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)).elements right_insert_quat = (right_init_quat * Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)).elements
meet_xyz = np.array([0.0, 1.0, 1.30], dtype=np.float64) meet_xyz = np.array([0.0, 1.0, 1.30], dtype=np.float64)
left_hold_xyz = meet_xyz + np.array([-0.02, 0.0, 0.08], dtype=np.float64) left_hold_xyz = meet_xyz + np.array([-0.16, 0.06, 0.14], dtype=np.float64)
right_insert_start_xyz = meet_xyz + np.array([0.0, 0.0, 0.10], dtype=np.float64) right_wait_xyz = meet_xyz + np.array([0.24, -0.08, 0.18], dtype=np.float64)
right_insert_end_xyz = meet_xyz + np.array([0.0, 0.0, -0.02], dtype=np.float64) right_insert_start_xyz = meet_xyz + np.array([0.08, -0.02, 0.14], dtype=np.float64)
right_insert_end_xyz = meet_xyz + np.array([0.02, 0.02, 0.10], dtype=np.float64)
self.left_trajectory = [ self.left_trajectory = [
{"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 100}, {"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 100},
@@ -52,8 +53,8 @@ class TestAirInsertPolicy(PolicyBase):
{"t": 150, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": 100}, {"t": 150, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": 100},
{"t": 180, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": -100}, {"t": 180, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": -100},
{"t": 260, "xyz": ring_xyz + np.array([0.0, 0.0, 0.24]), "quat": left_pick_quat, "gripper": -100}, {"t": 260, "xyz": ring_xyz + np.array([0.0, 0.0, 0.24]), "quat": left_pick_quat, "gripper": -100},
{"t": 420, "xyz": left_hold_xyz, "quat": init_mocap_pose_left[3:], "gripper": -100}, {"t": 360, "xyz": left_hold_xyz, "quat": left_pick_quat, "gripper": -100},
{"t": 700, "xyz": left_hold_xyz, "quat": init_mocap_pose_left[3:], "gripper": -100}, {"t": 700, "xyz": left_hold_xyz, "quat": left_pick_quat, "gripper": -100},
] ]
self.right_trajectory = [ self.right_trajectory = [
@@ -62,7 +63,8 @@ class TestAirInsertPolicy(PolicyBase):
{"t": 150, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": 100}, {"t": 150, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": 100},
{"t": 180, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": -100}, {"t": 180, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": -100},
{"t": 260, "xyz": bar_xyz + np.array([0.0, 0.0, 0.26]), "quat": right_pick_quat, "gripper": -100}, {"t": 260, "xyz": bar_xyz + np.array([0.0, 0.0, 0.26]), "quat": right_pick_quat, "gripper": -100},
{"t": 420, "xyz": right_insert_start_xyz, "quat": right_insert_quat, "gripper": -100}, {"t": 420, "xyz": right_wait_xyz, "quat": right_pick_quat, "gripper": -100},
{"t": 580, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100}, {"t": 560, "xyz": right_insert_start_xyz, "quat": right_insert_quat, "gripper": -100},
{"t": 640, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100},
{"t": 700, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100}, {"t": 700, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100},
] ]

View File

@@ -363,6 +363,48 @@ class AirInsertPolicyAndSmokeTest(unittest.TestCase):
if viewer is not None: if viewer is not None:
viewer.close() viewer.close()
def test_scripted_policy_avoids_cross_arm_contact_on_canonical_insert_case(self):
    """Roll out the scripted policy and fail on any left/right arm contact.

    Resets a headless ``sim_air_insert_ring_bar`` environment to a fixed
    ring/bar placement, steps the noise-free ``TestAirInsertPolicy`` for
    460 steps, and after every step inspects the active contacts. A contact
    counts as "cross-arm" when one geom name contains ``_left`` and the
    other contains ``_right``; any such contact fails the test immediately.
    """
    module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
    policy_cls = getattr(module, "TestAirInsertPolicy", None)
    self.assertIsNotNone(policy_cls)

    # Fixed object placement used as the canonical insert case.
    task_state = {
        "ring_pos": np.array([-0.06658807, 0.93985176, 0.47], dtype=np.float32),
        "ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
        "bar_pos": np.array([0.12421221, 0.77605027, 0.47], dtype=np.float32),
        "bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
    }
    env = make_sim_env("sim_air_insert_ring_bar", headless=True)
    policy = policy_cls(inject_noise=False)

    def is_cross_arm_pair(a, b):
        # True when the two geom names belong to opposite arms, in either order.
        if "_left" in a and "_right" in b:
            return True
        return "_right" in a and "_left" in b

    try:
        env.reset(task_state)
        for step in range(460):
            env.step(policy.predict(task_state, step))

            pairs = []
            for contact_index in range(env.mj_data.ncon):
                contact = env.mj_data.contact[contact_index]
                first_name = env.getID2Name("geom", contact.geom1)
                second_name = env.getID2Name("geom", contact.geom2)
                # Skip contacts where either geom has no usable name.
                if not (first_name and second_name):
                    continue
                if is_cross_arm_pair(first_name, second_name):
                    pairs.append((first_name, second_name))

            self.assertFalse(
                pairs,
                f"cross-arm contact detected at step {step}: {pairs[:5]}",
            )
    finally:
        # Signal the environment to stop, then release optional resources
        # (camera thread, viewer) only if this env instance actually has them.
        env.exit_flag = True
        camera_worker = getattr(env, "cam_thread", None)
        if camera_worker is not None:
            camera_worker.join(timeout=1.0)
        active_viewer = getattr(env, "viewer", None)
        if active_viewer is not None:
            active_viewer.close()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()