fix(policy): avoid cross-arm collision in air insert rollout
This commit is contained in:
@@ -363,6 +363,48 @@ class AirInsertPolicyAndSmokeTest(unittest.TestCase):
|
||||
if viewer is not None:
|
||||
viewer.close()
|
||||
|
||||
def test_scripted_policy_avoids_cross_arm_contact_on_canonical_insert_case(self):
|
||||
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
||||
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
||||
self.assertIsNotNone(policy_cls)
|
||||
|
||||
task_state = {
|
||||
"ring_pos": np.array([-0.06658807, 0.93985176, 0.47], dtype=np.float32),
|
||||
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
"bar_pos": np.array([0.12421221, 0.77605027, 0.47], dtype=np.float32),
|
||||
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
}
|
||||
|
||||
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
|
||||
policy = policy_cls(inject_noise=False)
|
||||
|
||||
def is_cross_arm_pair(a, b):
|
||||
return ("_left" in a and "_right" in b) or ("_right" in a and "_left" in b)
|
||||
|
||||
try:
|
||||
env.reset(task_state)
|
||||
for step in range(460):
|
||||
action = policy.predict(task_state, step)
|
||||
env.step(action)
|
||||
pairs = []
|
||||
for i in range(env.mj_data.ncon):
|
||||
geom1 = env.getID2Name("geom", env.mj_data.contact[i].geom1)
|
||||
geom2 = env.getID2Name("geom", env.mj_data.contact[i].geom2)
|
||||
if geom1 and geom2 and is_cross_arm_pair(geom1, geom2):
|
||||
pairs.append((geom1, geom2))
|
||||
self.assertFalse(
|
||||
pairs,
|
||||
f"cross-arm contact detected at step {step}: {pairs[:5]}",
|
||||
)
|
||||
finally:
|
||||
env.exit_flag = True
|
||||
cam_thread = getattr(env, "cam_thread", None)
|
||||
if cam_thread is not None:
|
||||
cam_thread.join(timeout=1.0)
|
||||
viewer = getattr(env, "viewer", None)
|
||||
if viewer is not None:
|
||||
viewer.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user