feat(env): add strict air insertion reward and success logic

This commit is contained in:
Logic
2026-04-23 17:40:46 +08:00
parent f1ede7690f
commit a837a982f7
2 changed files with 235 additions and 3 deletions

View File

@@ -174,5 +174,131 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase):
)
class AirInsertRewardAndSuccessTest(unittest.TestCase):
@staticmethod
def _make_env_state(
ring_pos=(0.0, 0.0, 0.50),
ring_quat=(1.0, 0.0, 0.0, 0.0),
bar_pos=(0.0, 0.0, 0.50),
bar_quat=(0.70710678, 0.0, 0.70710678, 0.0),
):
return np.array(
[*ring_pos, *ring_quat, *bar_pos, *bar_quat],
dtype=np.float64,
)
def test_compute_air_insert_reward_counts_left_contact_stage(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
self.assertIsNotNone(
reward_fn,
"Expected roboimi.envs.double_air_insert_env.compute_air_insert_reward",
)
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("ring_block_north", "table"),
("bar_block", "table"),
],
env_state=self._make_env_state(),
)
self.assertEqual(reward, 1)
def test_compute_air_insert_reward_counts_right_contact_stage(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("bar_block", "l_finger_right"),
("ring_block_north", "table"),
("bar_block", "table"),
],
env_state=self._make_env_state(),
)
self.assertEqual(reward, 2)
def test_compute_air_insert_reward_counts_lift_stages(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("bar_block", "l_finger_right"),
],
env_state=self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)),
)
self.assertEqual(reward, 4)
def test_bar_fully_inserted_through_ring_accepts_true_positive(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
self.assertIsNotNone(
success_fn,
"Expected roboimi.envs.double_air_insert_env.bar_fully_inserted_through_ring",
)
self.assertTrue(
success_fn(
self._make_env_state(),
)
)
def test_bar_fully_inserted_through_ring_rejects_centerline_only_false_positive(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
self.assertFalse(
success_fn(
self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)),
)
)
def test_bar_fully_inserted_through_ring_rejects_insufficient_depth(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
self.assertFalse(
success_fn(
self._make_env_state(bar_pos=(0.0, 0.0, 0.56)),
)
)
def test_compute_air_insert_reward_requires_airborne_success_for_final_point(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("bar_block", "l_finger_right"),
("ring_block_north", "table"),
],
env_state=self._make_env_state(),
)
self.assertEqual(reward, 3)
def test_compute_air_insert_reward_returns_full_score_on_true_airborne_insert(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("bar_block", "l_finger_right"),
],
env_state=self._make_env_state(),
)
self.assertEqual(reward, 5)
if __name__ == "__main__":
unittest.main()