From a837a982f7b4f63a6256750ee8fe97c6ace7b262 Mon Sep 17 00:00:00 2001 From: Logic Date: Thu, 23 Apr 2026 17:40:46 +0800 Subject: [PATCH] feat(env): add strict air insertion reward and success logic --- roboimi/envs/double_air_insert_env.py | 112 ++++++++++++++++++++++- tests/test_air_insert_env.py | 126 ++++++++++++++++++++++++++ 2 files changed, 235 insertions(+), 3 deletions(-) diff --git a/roboimi/envs/double_air_insert_env.py b/roboimi/envs/double_air_insert_env.py index 63f489f..d1955bb 100644 --- a/roboimi/envs/double_air_insert_env.py +++ b/roboimi/envs/double_air_insert_env.py @@ -10,6 +10,19 @@ from roboimi.envs.double_pos_ctrl_env import DualDianaMed_Pos_Ctrl RING_JOINT_NAME = "ring_block_joint" BAR_JOINT_NAME = "bar_block_joint" REQUIRED_TASK_STATE_KEYS = ("ring_pos", "ring_quat", "bar_pos", "bar_quat") +RING_GEOM_NAMES = ( + "ring_block_north", + "ring_block_south", + "ring_block_east", + "ring_block_west", +) +BAR_GEOM_NAMES = ("bar_block",) +LEFT_GRIPPER_GEOM_NAMES = ("l_finger_left", "r_finger_left") +RIGHT_GRIPPER_GEOM_NAMES = ("l_finger_right", "r_finger_right") +TABLE_GEOM_NAME = "table" +RING_APERTURE_HALF_WIDTH = 0.016 +RING_HALF_THICKNESS = 0.009 +BAR_HALF_SIZES = np.array([0.045, 0.009, 0.009], dtype=np.float64) def _set_free_joint_pose(joint, position, quat): @@ -42,7 +55,95 @@ def get_ring_bar_env_state(mj_data): return np.concatenate([ring_qpos, bar_qpos], dtype=np.float64) +def _normalize_contact_pairs(contact_pairs): + return {frozenset(pair) for pair in contact_pairs} + + +def _has_any_object_contact(contact_set, object_geom_names, other_geom_names): + return any( + frozenset((object_geom_name, other_geom_name)) in contact_set + for object_geom_name in object_geom_names + for other_geom_name in other_geom_names + ) + + +def _object_is_airborne(contact_set, object_geom_names): + return not _has_any_object_contact(contact_set, object_geom_names, (TABLE_GEOM_NAME,)) + + +def _quat_to_rotation_matrix(quat): + quat = np.asarray(quat, dtype=np.float64) + quat /= np.linalg.norm(quat) + w, x, y, z = quat + return np.array( + [ + [1.0 - 2.0 * (y * y + z * z), 2.0 * (x * y - z * w), 2.0 * (x * z + y * w)], + [2.0 * (x * y + z * w), 1.0 - 2.0 * (x * x + z * z), 2.0 * (y * z - x * w)], + [2.0 * (x * z - y * w), 2.0 * (y * z + x * w), 1.0 - 2.0 * (x * x + y * y)], + ], + dtype=np.float64, + ) + + +def _split_env_state(env_state): + env_state = np.asarray(env_state, dtype=np.float64) + if env_state.shape != (14,): + raise ValueError(f"env_state must have shape (14,), got {env_state.shape}") + return ( + env_state[:3], + env_state[3:7], + env_state[7:10], + env_state[10:14], + ) + + +def bar_fully_inserted_through_ring(env_state): + ring_pos, ring_quat, bar_pos, bar_quat = _split_env_state(env_state) + ring_rot = _quat_to_rotation_matrix(ring_quat) + bar_rot = _quat_to_rotation_matrix(bar_quat) + + bar_center_in_ring = ring_rot.T @ (bar_pos - ring_pos) + bar_rot_in_ring = ring_rot.T @ bar_rot + projected_half_extents = np.abs(bar_rot_in_ring) @ BAR_HALF_SIZES + + spans_ring_thickness = ( + bar_center_in_ring[2] - projected_half_extents[2] <= -RING_HALF_THICKNESS + and bar_center_in_ring[2] + projected_half_extents[2] >= RING_HALF_THICKNESS + ) + fits_aperture = ( + abs(bar_center_in_ring[0]) + projected_half_extents[0] <= RING_APERTURE_HALF_WIDTH + and abs(bar_center_in_ring[1]) + projected_half_extents[1] <= RING_APERTURE_HALF_WIDTH + ) + return bool(spans_ring_thickness and fits_aperture) + + +def compute_air_insert_reward(contact_pairs, env_state): + contact_set = _normalize_contact_pairs(contact_pairs) + reward = 0 + + if _has_any_object_contact(contact_set, RING_GEOM_NAMES, LEFT_GRIPPER_GEOM_NAMES): + reward += 1 + if _has_any_object_contact(contact_set, BAR_GEOM_NAMES, RIGHT_GRIPPER_GEOM_NAMES): + reward += 1 + + ring_airborne = _object_is_airborne(contact_set, RING_GEOM_NAMES) + bar_airborne = _object_is_airborne(contact_set, BAR_GEOM_NAMES) + if ring_airborne: + reward += 1 + if bar_airborne: + reward += 1 + + if ring_airborne and bar_airborne and bar_fully_inserted_through_ring(env_state): + reward += 1 + + return reward + + class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_reward = 5 + def reset(self, task_state): set_ring_bar_task_state(self.mj_data, task_state) DualDianaMed.reset(self) @@ -66,6 +167,11 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl): return get_ring_bar_env_state(self.mj_data) def _get_reward(self): - raise NotImplementedError( - "Task 2 wires reset/state only; reward logic is implemented in a later task." - ) + contact_pairs = [] + for collision_num in range(self.mj_data.ncon): + geom1 = self.mj_data.contact[collision_num].geom1 + geom2 = self.mj_data.contact[collision_num].geom2 + contact_pairs.append( + (self.getID2Name("geom", geom1), self.getID2Name("geom", geom2)) + ) + return compute_air_insert_reward(contact_pairs, self.get_env_state()) diff --git a/tests/test_air_insert_env.py b/tests/test_air_insert_env.py index 3f5237c..8811ba9 100644 --- a/tests/test_air_insert_env.py +++ b/tests/test_air_insert_env.py @@ -174,5 +174,131 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase): ) +class AirInsertRewardAndSuccessTest(unittest.TestCase): + @staticmethod + def _make_env_state( + ring_pos=(0.0, 0.0, 0.50), + ring_quat=(1.0, 0.0, 0.0, 0.0), + bar_pos=(0.0, 0.0, 0.50), + bar_quat=(0.70710678, 0.0, 0.70710678, 0.0), + ): + return np.array( + [*ring_pos, *ring_quat, *bar_pos, *bar_quat], + dtype=np.float64, + ) + + def test_compute_air_insert_reward_counts_left_contact_stage(self): + air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") + reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None) + self.assertIsNotNone( + reward_fn, + "Expected roboimi.envs.double_air_insert_env.compute_air_insert_reward", + ) + + reward = reward_fn( + contact_pairs=[ + ("ring_block_north", "l_finger_left"), + ("ring_block_north", "table"), + ("bar_block", "table"), + ], + env_state=self._make_env_state(), + ) + + self.assertEqual(reward, 1) + + def test_compute_air_insert_reward_counts_right_contact_stage(self): + air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") + reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None) + + reward = reward_fn( + contact_pairs=[ + ("ring_block_north", "l_finger_left"), + ("bar_block", "l_finger_right"), + ("ring_block_north", "table"), + ("bar_block", "table"), + ], + env_state=self._make_env_state(), + ) + + self.assertEqual(reward, 2) + + def test_compute_air_insert_reward_counts_lift_stages(self): + air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") + reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None) + + reward = reward_fn( + contact_pairs=[ + ("ring_block_north", "l_finger_left"), + ("bar_block", "l_finger_right"), + ], + env_state=self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)), + ) + + self.assertEqual(reward, 4) + + def test_bar_fully_inserted_through_ring_accepts_true_positive(self): + air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") + success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None) + self.assertIsNotNone( + success_fn, + "Expected roboimi.envs.double_air_insert_env.bar_fully_inserted_through_ring", + ) + + self.assertTrue( + success_fn( + self._make_env_state(), + ) + ) + + def test_bar_fully_inserted_through_ring_rejects_centerline_only_false_positive(self): + air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") + success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None) + + self.assertFalse( + success_fn( + self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)), + ) + ) + + def test_bar_fully_inserted_through_ring_rejects_insufficient_depth(self): + air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") + success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None) + + self.assertFalse( + success_fn( + self._make_env_state(bar_pos=(0.0, 0.0, 0.56)), + ) + ) + + def test_compute_air_insert_reward_requires_airborne_success_for_final_point(self): + air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") + reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None) + + reward = reward_fn( + contact_pairs=[ + ("ring_block_north", "l_finger_left"), + ("bar_block", "l_finger_right"), + ("ring_block_north", "table"), + ], + env_state=self._make_env_state(), + ) + + self.assertEqual(reward, 3) + + def test_compute_air_insert_reward_returns_full_score_on_true_airborne_insert(self): + air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env") + reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None) + + reward = reward_fn( + contact_pairs=[ + ("ring_block_north", "l_finger_left"), + ("bar_block", "l_finger_right"), + ], + env_state=self._make_env_state(), + ) + + self.assertEqual(reward, 5) + + if __name__ == "__main__": unittest.main()