feat(policy): add scripted air insertion policy
This commit is contained in:
@@ -300,5 +300,69 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
|
||||
self.assertEqual(reward, 5)
|
||||
|
||||
|
||||
class AirInsertPolicyAndSmokeTest(unittest.TestCase):
|
||||
def test_air_insert_policy_emits_valid_16d_action(self):
|
||||
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
||||
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
||||
self.assertIsNotNone(
|
||||
policy_cls,
|
||||
"Expected roboimi.demos.diana_air_insert_policy.TestAirInsertPolicy",
|
||||
)
|
||||
|
||||
task_state = act_ex_utils.sample_air_insert_ring_bar_state()
|
||||
policy = policy_cls(inject_noise=False)
|
||||
action = policy.predict(task_state, 0)
|
||||
|
||||
self.assertEqual(action.shape, (16,))
|
||||
np.testing.assert_array_equal(action[-2:], np.array([100, 100]))
|
||||
|
||||
def test_scripted_rollout_entrypoint_selects_ring_bar_sampler_and_policy(self):
|
||||
rollout_module = importlib.import_module("roboimi.demos.diana_record_sim_episodes")
|
||||
sampler_fn = getattr(rollout_module, "sample_task_state", None)
|
||||
policy_factory = getattr(rollout_module, "make_policy", None)
|
||||
self.assertIsNotNone(
|
||||
sampler_fn,
|
||||
"Expected roboimi.demos.diana_record_sim_episodes.sample_task_state",
|
||||
)
|
||||
self.assertIsNotNone(
|
||||
policy_factory,
|
||||
"Expected roboimi.demos.diana_record_sim_episodes.make_policy",
|
||||
)
|
||||
|
||||
task_state = sampler_fn("sim_air_insert_ring_bar")
|
||||
self.assertEqual(
|
||||
list(task_state.keys()),
|
||||
["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
|
||||
)
|
||||
|
||||
policy = policy_factory("sim_air_insert_ring_bar", inject_noise=False)
|
||||
self.assertEqual(policy.__class__.__name__, "TestAirInsertPolicy")
|
||||
|
||||
def test_real_headless_smoke_instantiates_resets_and_steps_new_task_once(self):
|
||||
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
||||
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
||||
self.assertIsNotNone(policy_cls)
|
||||
|
||||
task_state = act_ex_utils.sample_air_insert_ring_bar_state()
|
||||
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
|
||||
policy = policy_cls(inject_noise=False)
|
||||
|
||||
try:
|
||||
env.reset(task_state)
|
||||
action = policy.predict(task_state, 0)
|
||||
env.step(action)
|
||||
self.assertIsNotNone(env.obs)
|
||||
self.assertIn("qpos", env.obs)
|
||||
self.assertIn("images", env.obs)
|
||||
finally:
|
||||
env.exit_flag = True
|
||||
cam_thread = getattr(env, "cam_thread", None)
|
||||
if cam_thread is not None:
|
||||
cam_thread.join(timeout=1.0)
|
||||
viewer = getattr(env, "viewer", None)
|
||||
if viewer is not None:
|
||||
viewer.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user