feat(env): add strict air insertion reward and success logic

This commit is contained in:
Logic
2026-04-23 17:40:46 +08:00
parent f1ede7690f
commit a837a982f7
2 changed files with 235 additions and 3 deletions

View File

@@ -10,6 +10,19 @@ from roboimi.envs.double_pos_ctrl_env import DualDianaMed_Pos_Ctrl
RING_JOINT_NAME = "ring_block_joint" RING_JOINT_NAME = "ring_block_joint"
BAR_JOINT_NAME = "bar_block_joint" BAR_JOINT_NAME = "bar_block_joint"
REQUIRED_TASK_STATE_KEYS = ("ring_pos", "ring_quat", "bar_pos", "bar_quat") REQUIRED_TASK_STATE_KEYS = ("ring_pos", "ring_quat", "bar_pos", "bar_quat")
# Geom names checked when looking for contacts that involve the ring.
RING_GEOM_NAMES = (
    "ring_block_north",
    "ring_block_south",
    "ring_block_east",
    "ring_block_west",
)
# Geom names checked when looking for contacts that involve the bar.
BAR_GEOM_NAMES = ("bar_block",)
# Finger geoms; the reward pairs the left gripper with the ring and the right
# gripper with the bar.
LEFT_GRIPPER_GEOM_NAMES = ("l_finger_left", "r_finger_left")
RIGHT_GRIPPER_GEOM_NAMES = ("l_finger_right", "r_finger_right")
# An object counts as airborne when none of its geoms touch this geom.
TABLE_GEOM_NAME = "table"
# Ring-frame limits used by the insertion check: the opening is bounded along
# the ring's local x/y axes and the ring thickness lies along its local z axis.
# NOTE(review): presumably metres, matching the simulation model -- confirm.
RING_APERTURE_HALF_WIDTH = 0.016
RING_HALF_THICKNESS = 0.009
# Half sizes of the bar box along its own local x, y, z axes.
BAR_HALF_SIZES = np.array([0.045, 0.009, 0.009], dtype=np.float64)
def _set_free_joint_pose(joint, position, quat): def _set_free_joint_pose(joint, position, quat):
@@ -42,7 +55,95 @@ def get_ring_bar_env_state(mj_data):
return np.concatenate([ring_qpos, bar_qpos], dtype=np.float64) return np.concatenate([ring_qpos, bar_qpos], dtype=np.float64)
def _normalize_contact_pairs(contact_pairs):
return {frozenset(pair) for pair in contact_pairs}
def _has_any_object_contact(contact_set, object_geom_names, other_geom_names):
return any(
frozenset((object_geom_name, other_geom_name)) in contact_set
for object_geom_name in object_geom_names
for other_geom_name in other_geom_names
)
def _object_is_airborne(contact_set, object_geom_names):
    """Return True when none of the object's geoms is in contact with the table."""
    touching_table = _has_any_object_contact(
        contact_set,
        object_geom_names,
        (TABLE_GEOM_NAME,),
    )
    return not touching_table
def _quat_to_rotation_matrix(quat):
quat = np.asarray(quat, dtype=np.float64)
quat /= np.linalg.norm(quat)
w, x, y, z = quat
return np.array(
[
[1.0 - 2.0 * (y * y + z * z), 2.0 * (x * y - z * w), 2.0 * (x * z + y * w)],
[2.0 * (x * y + z * w), 1.0 - 2.0 * (x * x + z * z), 2.0 * (y * z - x * w)],
[2.0 * (x * z - y * w), 2.0 * (y * z + x * w), 1.0 - 2.0 * (x * x + y * y)],
],
dtype=np.float64,
)
def _split_env_state(env_state):
env_state = np.asarray(env_state, dtype=np.float64)
if env_state.shape != (14,):
raise ValueError(f"env_state must have shape (14,), got {env_state.shape}")
return (
env_state[:3],
env_state[3:7],
env_state[7:10],
env_state[10:14],
)
def bar_fully_inserted_through_ring(env_state):
    """Check whether the bar passes completely through the ring opening.

    The bar is expressed in the ring's local frame and replaced by the
    axis-aligned bounding box of its oriented box. Insertion requires that box
    to reach past both faces of the ring along the ring's z axis while staying
    inside the aperture along the ring's x and y axes.

    Args:
        env_state: Array-like of shape ``(14,)`` holding the ring pose followed
            by the bar pose (position, then quaternion, for each).

    Returns:
        True when both the depth and the aperture conditions hold.

    Raises:
        ValueError: If ``env_state`` does not have shape ``(14,)``.
    """
    ring_pos, ring_quat, bar_pos, bar_quat = _split_env_state(env_state)
    ring_frame = _quat_to_rotation_matrix(ring_quat)
    bar_frame = _quat_to_rotation_matrix(bar_quat)

    world_to_ring = ring_frame.T
    center = world_to_ring @ (bar_pos - ring_pos)
    # |R| @ half_sizes gives the half extents of the rotated box's bounding box.
    half_extents = np.abs(world_to_ring @ bar_frame) @ BAR_HALF_SIZES

    lowest_z = center[2] - half_extents[2]
    highest_z = center[2] + half_extents[2]
    if not (lowest_z <= -RING_HALF_THICKNESS and highest_z >= RING_HALF_THICKNESS):
        return False

    for axis in (0, 1):
        # Written as `not (<=)` so a NaN extent is rejected, as in a plain `<=` test.
        if not (abs(center[axis]) + half_extents[axis] <= RING_APERTURE_HALF_WIDTH):
            return False
    return True
def compute_air_insert_reward(contact_pairs, env_state):
    """Compute the staged reward (0 to 5) for the air-insertion task.

    One point is granted for each satisfied stage:
      1. a left-gripper finger touches the ring,
      2. a right-gripper finger touches the bar,
      3. the ring is not touching the table,
      4. the bar is not touching the table,
      5. both are off the table and the bar is fully inserted through the ring.

    Args:
        contact_pairs: Iterable of ``(geom_name, geom_name)`` contact pairs.
        env_state: Array-like of shape ``(14,)`` with the ring and bar poses;
            only inspected when both objects are airborne.

    Returns:
        Integer reward equal to the number of satisfied stages.
    """
    contacts = _normalize_contact_pairs(contact_pairs)

    ring_held = _has_any_object_contact(contacts, RING_GEOM_NAMES, LEFT_GRIPPER_GEOM_NAMES)
    bar_held = _has_any_object_contact(contacts, BAR_GEOM_NAMES, RIGHT_GRIPPER_GEOM_NAMES)
    ring_airborne = _object_is_airborne(contacts, RING_GEOM_NAMES)
    bar_airborne = _object_is_airborne(contacts, BAR_GEOM_NAMES)
    # Short-circuit keeps the geometric check from running unless both objects are lifted.
    inserted_in_air = (
        ring_airborne and bar_airborne and bar_fully_inserted_through_ring(env_state)
    )

    stages = (ring_held, bar_held, ring_airborne, bar_airborne, inserted_in_air)
    return sum(1 for achieved in stages if achieved)
class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl): class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
def __init__(self, *args, **kwargs):
    """Initialise the parent environment and record the maximum task reward.

    All positional and keyword arguments are forwarded unchanged to the
    parent class constructor.
    """
    super().__init__(*args, **kwargs)
    # compute_air_insert_reward awards at most five one-point stages.
    self.max_reward = 5
def reset(self, task_state): def reset(self, task_state):
set_ring_bar_task_state(self.mj_data, task_state) set_ring_bar_task_state(self.mj_data, task_state)
DualDianaMed.reset(self) DualDianaMed.reset(self)
@@ -66,6 +167,11 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
return get_ring_bar_env_state(self.mj_data) return get_ring_bar_env_state(self.mj_data)
def _get_reward(self):
    """Return the staged air-insertion reward for the current simulation state.

    Collects the geom-name pair of each of the first ``ncon`` contacts in
    ``self.mj_data`` and scores them together with the current ring/bar state.
    """
    active_contacts = (
        self.mj_data.contact[index] for index in range(self.mj_data.ncon)
    )
    contact_pairs = [
        (
            self.getID2Name("geom", contact.geom1),
            self.getID2Name("geom", contact.geom2),
        )
        for contact in active_contacts
    ]
    return compute_air_insert_reward(contact_pairs, self.get_env_state())

View File

@@ -174,5 +174,131 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase):
) )
class AirInsertRewardAndSuccessTest(unittest.TestCase):
    # Tests for the staged reward and the geometric insertion check exposed by
    # roboimi.envs.double_air_insert_env.

    @staticmethod
    def _make_env_state(
        ring_pos=(0.0, 0.0, 0.50),
        ring_quat=(1.0, 0.0, 0.0, 0.0),
        bar_pos=(0.0, 0.0, 0.50),
        bar_quat=(0.70710678, 0.0, 0.70710678, 0.0),
    ):
        # Flat 14-element state: ring position, ring quaternion, bar position,
        # bar quaternion. The defaults put the bar centred in the ring, with
        # the bar rotated a quarter turn about the y axis.
        values = list(ring_pos) + list(ring_quat) + list(bar_pos) + list(bar_quat)
        return np.array(values, dtype=np.float64)

    @staticmethod
    def _lookup(attribute_name):
        # Import lazily inside each test; a missing attribute comes back as None.
        module = importlib.import_module("roboimi.envs.double_air_insert_env")
        return getattr(module, attribute_name, None)

    def test_compute_air_insert_reward_counts_left_contact_stage(self):
        # Only the left gripper touches the ring; both objects rest on the table.
        reward_fn = self._lookup("compute_air_insert_reward")
        self.assertIsNotNone(
            reward_fn,
            "Expected roboimi.envs.double_air_insert_env.compute_air_insert_reward",
        )

        contacts = [
            ("ring_block_north", "l_finger_left"),
            ("ring_block_north", "table"),
            ("bar_block", "table"),
        ]
        reward = reward_fn(contact_pairs=contacts, env_state=self._make_env_state())

        self.assertEqual(reward, 1)

    def test_compute_air_insert_reward_counts_right_contact_stage(self):
        # Both grippers hold their object while both objects are still on the table.
        reward_fn = self._lookup("compute_air_insert_reward")

        contacts = [
            ("ring_block_north", "l_finger_left"),
            ("bar_block", "l_finger_right"),
            ("ring_block_north", "table"),
            ("bar_block", "table"),
        ]
        reward = reward_fn(contact_pairs=contacts, env_state=self._make_env_state())

        self.assertEqual(reward, 2)

    def test_compute_air_insert_reward_counts_lift_stages(self):
        # Both objects are lifted, but the bar is offset so it does not fit the aperture.
        reward_fn = self._lookup("compute_air_insert_reward")

        contacts = [
            ("ring_block_north", "l_finger_left"),
            ("bar_block", "l_finger_right"),
        ]
        offset_state = self._make_env_state(bar_pos=(0.0085, 0.0, 0.50))
        reward = reward_fn(contact_pairs=contacts, env_state=offset_state)

        self.assertEqual(reward, 4)

    def test_bar_fully_inserted_through_ring_accepts_true_positive(self):
        # The default state is a centred, fully inserted bar.
        success_fn = self._lookup("bar_fully_inserted_through_ring")
        self.assertIsNotNone(
            success_fn,
            "Expected roboimi.envs.double_air_insert_env.bar_fully_inserted_through_ring",
        )

        inserted = success_fn(self._make_env_state())

        self.assertTrue(inserted)

    def test_bar_fully_inserted_through_ring_rejects_centerline_only_false_positive(self):
        # The bar centre is inside the aperture, but its cross-section is not.
        success_fn = self._lookup("bar_fully_inserted_through_ring")

        inserted = success_fn(self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)))

        self.assertFalse(inserted)

    def test_bar_fully_inserted_through_ring_rejects_insufficient_depth(self):
        # The bar is shifted along the ring axis so it no longer spans the ring.
        success_fn = self._lookup("bar_fully_inserted_through_ring")

        inserted = success_fn(self._make_env_state(bar_pos=(0.0, 0.0, 0.56)))

        self.assertFalse(inserted)

    def test_compute_air_insert_reward_requires_airborne_success_for_final_point(self):
        # The bar is inserted, but the ring still touches the table.
        reward_fn = self._lookup("compute_air_insert_reward")

        contacts = [
            ("ring_block_north", "l_finger_left"),
            ("bar_block", "l_finger_right"),
            ("ring_block_north", "table"),
        ]
        reward = reward_fn(contact_pairs=contacts, env_state=self._make_env_state())

        self.assertEqual(reward, 3)

    def test_compute_air_insert_reward_returns_full_score_on_true_airborne_insert(self):
        # Both objects are held in the air with the bar fully inserted.
        reward_fn = self._lookup("compute_air_insert_reward")

        contacts = [
            ("ring_block_north", "l_finger_left"),
            ("bar_block", "l_finger_right"),
        ]
        reward = reward_fn(contact_pairs=contacts, env_state=self._make_env_state())

        self.assertEqual(reward, 5)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()