fix(policy): stabilize air insert scripted success

This commit is contained in:
Logic
2026-04-24 09:20:50 +08:00
parent d245d64def
commit 4936cf2635
4 changed files with 149 additions and 19 deletions

View File

@@ -405,6 +405,73 @@ class AirInsertPolicyAndSmokeTest(unittest.TestCase):
if viewer is not None:
viewer.close()
def test_scripted_policy_keeps_ring_airborne_through_hold_phase_on_canonical_case(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = {
"ring_pos": np.array([-0.11884121, 0.800019, 0.47], dtype=np.float32),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"bar_pos": np.array([0.12783867, 0.73399246, 0.47], dtype=np.float32),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
policy = policy_cls(inject_noise=False)
try:
env.reset(task_state)
for step in range(400):
action = policy.predict(task_state, step)
env.step(action)
ring_z = float(env.get_env_state()[2])
self.assertGreater(
ring_z,
0.55,
f"ring dropped before hold phase completed, final z={ring_z:.4f}",
)
finally:
env.exit_flag = True
cam_thread = getattr(env, "cam_thread", None)
if cam_thread is not None:
cam_thread.join(timeout=1.0)
viewer = getattr(env, "viewer", None)
if viewer is not None:
viewer.close()
def test_scripted_policy_reaches_max_reward_on_canonical_case(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = {
"ring_pos": np.array([-0.11884121, 0.800019, 0.47], dtype=np.float32),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"bar_pos": np.array([0.12783867, 0.73399246, 0.47], dtype=np.float32),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
policy = policy_cls(inject_noise=False)
max_reward = float("-inf")
try:
env.reset(task_state)
for step in range(700):
action = policy.predict(task_state, step)
env.step(action)
max_reward = max(max_reward, float(env.rew))
self.assertEqual(max_reward, 5.0, f"expected canonical rollout to reach reward 5, got {max_reward}")
finally:
env.exit_flag = True
cam_thread = getattr(env, "cam_thread", None)
if cam_thread is not None:
cam_thread.join(timeout=1.0)
viewer = getattr(env, "viewer", None)
if viewer is not None:
viewer.close()
if __name__ == "__main__":
unittest.main()