feat(policy): add scripted air insertion policy

This commit is contained in:
Logic
2026-04-23 17:44:53 +08:00
parent a837a982f7
commit 8145c9eb62
3 changed files with 163 additions and 13 deletions

View File

@@ -300,5 +300,69 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
self.assertEqual(reward, 5)
class AirInsertPolicyAndSmokeTest(unittest.TestCase):
def test_air_insert_policy_emits_valid_16d_action(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(
policy_cls,
"Expected roboimi.demos.diana_air_insert_policy.TestAirInsertPolicy",
)
task_state = act_ex_utils.sample_air_insert_ring_bar_state()
policy = policy_cls(inject_noise=False)
action = policy.predict(task_state, 0)
self.assertEqual(action.shape, (16,))
np.testing.assert_array_equal(action[-2:], np.array([100, 100]))
def test_scripted_rollout_entrypoint_selects_ring_bar_sampler_and_policy(self):
rollout_module = importlib.import_module("roboimi.demos.diana_record_sim_episodes")
sampler_fn = getattr(rollout_module, "sample_task_state", None)
policy_factory = getattr(rollout_module, "make_policy", None)
self.assertIsNotNone(
sampler_fn,
"Expected roboimi.demos.diana_record_sim_episodes.sample_task_state",
)
self.assertIsNotNone(
policy_factory,
"Expected roboimi.demos.diana_record_sim_episodes.make_policy",
)
task_state = sampler_fn("sim_air_insert_ring_bar")
self.assertEqual(
list(task_state.keys()),
["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
)
policy = policy_factory("sim_air_insert_ring_bar", inject_noise=False)
self.assertEqual(policy.__class__.__name__, "TestAirInsertPolicy")
def test_real_headless_smoke_instantiates_resets_and_steps_new_task_once(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = act_ex_utils.sample_air_insert_ring_bar_state()
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
policy = policy_cls(inject_noise=False)
try:
env.reset(task_state)
action = policy.predict(task_state, 0)
env.step(action)
self.assertIsNotNone(env.obs)
self.assertIn("qpos", env.obs)
self.assertIn("images", env.obs)
finally:
env.exit_flag = True
cam_thread = getattr(env, "cam_thread", None)
if cam_thread is not None:
cam_thread.join(timeout=1.0)
viewer = getattr(env, "viewer", None)
if viewer is not None:
viewer.close()
if __name__ == "__main__":
unittest.main()