feat(env): add strict air insertion reward and success logic
This commit is contained in:
@@ -10,6 +10,19 @@ from roboimi.envs.double_pos_ctrl_env import DualDianaMed_Pos_Ctrl
|
|||||||
RING_JOINT_NAME = "ring_block_joint"
|
RING_JOINT_NAME = "ring_block_joint"
|
||||||
BAR_JOINT_NAME = "bar_block_joint"
|
BAR_JOINT_NAME = "bar_block_joint"
|
||||||
REQUIRED_TASK_STATE_KEYS = ("ring_pos", "ring_quat", "bar_pos", "bar_quat")
|
REQUIRED_TASK_STATE_KEYS = ("ring_pos", "ring_quat", "bar_pos", "bar_quat")
|
||||||
|
# Names of the geoms that make up the ring. compute_air_insert_reward looks
# these up in the contact list against the left gripper and the table.
RING_GEOM_NAMES = (
    "ring_block_north",
    "ring_block_south",
    "ring_block_east",
    "ring_block_west",
)
# Names of the geoms that make up the bar; checked against the right gripper
# and the table.
BAR_GEOM_NAMES = ("bar_block",)
# Finger geoms of each gripper, used for the two "object is touched" stages.
LEFT_GRIPPER_GEOM_NAMES = ("l_finger_left", "r_finger_left")
RIGHT_GRIPPER_GEOM_NAMES = ("l_finger_right", "r_finger_right")
# An object counts as airborne when none of its geoms touches this geom.
TABLE_GEOM_NAME = "table"
# Geometry used by bar_fully_inserted_through_ring, all expressed as
# half-extents (presumably metres, matching the model file -- TODO confirm).
# Half-width of the ring opening along the ring's local x and y axes.
RING_APERTURE_HALF_WIDTH = 0.016
# Half-extent of the ring along its local z axis, which the bar must span.
RING_HALF_THICKNESS = 0.009
# Half-sizes of the bar box along the bar's own local x, y and z axes.
BAR_HALF_SIZES = np.array([0.045, 0.009, 0.009], dtype=np.float64)
|
||||||
|
|
||||||
|
|
||||||
def _set_free_joint_pose(joint, position, quat):
|
def _set_free_joint_pose(joint, position, quat):
|
||||||
@@ -42,7 +55,95 @@ def get_ring_bar_env_state(mj_data):
|
|||||||
return np.concatenate([ring_qpos, bar_qpos], dtype=np.float64)
|
return np.concatenate([ring_qpos, bar_qpos], dtype=np.float64)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_contact_pairs(contact_pairs):
|
||||||
|
return {frozenset(pair) for pair in contact_pairs}
|
||||||
|
|
||||||
|
|
||||||
|
def _has_any_object_contact(contact_set, object_geom_names, other_geom_names):
|
||||||
|
return any(
|
||||||
|
frozenset((object_geom_name, other_geom_name)) in contact_set
|
||||||
|
for object_geom_name in object_geom_names
|
||||||
|
for other_geom_name in other_geom_names
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _object_is_airborne(contact_set, object_geom_names):
    """Return True when none of the object's geoms touches the table geom."""
    touches_table = _has_any_object_contact(
        contact_set,
        object_geom_names,
        (TABLE_GEOM_NAME,),
    )
    return not touches_table
|
||||||
|
|
||||||
|
|
||||||
|
def _quat_to_rotation_matrix(quat):
|
||||||
|
quat = np.asarray(quat, dtype=np.float64)
|
||||||
|
quat /= np.linalg.norm(quat)
|
||||||
|
w, x, y, z = quat
|
||||||
|
return np.array(
|
||||||
|
[
|
||||||
|
[1.0 - 2.0 * (y * y + z * z), 2.0 * (x * y - z * w), 2.0 * (x * z + y * w)],
|
||||||
|
[2.0 * (x * y + z * w), 1.0 - 2.0 * (x * x + z * z), 2.0 * (y * z - x * w)],
|
||||||
|
[2.0 * (x * z - y * w), 2.0 * (y * z + x * w), 1.0 - 2.0 * (x * x + y * y)],
|
||||||
|
],
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _split_env_state(env_state):
|
||||||
|
env_state = np.asarray(env_state, dtype=np.float64)
|
||||||
|
if env_state.shape != (14,):
|
||||||
|
raise ValueError(f"env_state must have shape (14,), got {env_state.shape}")
|
||||||
|
return (
|
||||||
|
env_state[:3],
|
||||||
|
env_state[3:7],
|
||||||
|
env_state[7:10],
|
||||||
|
env_state[10:14],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def bar_fully_inserted_through_ring(env_state):
    """Decide whether the bar passes completely through the ring opening.

    The bar is treated as a box with half-sizes ``BAR_HALF_SIZES`` and is
    expressed in the ring's local frame. Success requires both that

    * the box reaches past both faces of the ring along the ring's local z
      axis (``RING_HALF_THICKNESS``), and
    * the box's footprint along the ring's local x and y axes stays inside
      the opening (``RING_APERTURE_HALF_WIDTH``).

    Args:
        env_state: Flat 14-element state, split by ``_split_env_state``.

    Returns:
        A plain ``bool``.
    """
    ring_pos, ring_quat, bar_pos, bar_quat = _split_env_state(env_state)
    world_to_ring = _quat_to_rotation_matrix(ring_quat).T
    bar_axes_in_ring = world_to_ring @ _quat_to_rotation_matrix(bar_quat)

    offset_x, offset_y, offset_z = world_to_ring @ (bar_pos - ring_pos)
    # Half-extent of the rotated box along each ring axis.
    reach_x, reach_y, reach_z = np.abs(bar_axes_in_ring) @ BAR_HALF_SIZES

    # Each guard is written as `not (a <= b)` rather than `a > b` so that a
    # NaN comparison still yields False overall.
    if not (offset_z - reach_z <= -RING_HALF_THICKNESS):
        return False
    if not (offset_z + reach_z >= RING_HALF_THICKNESS):
        return False
    if not (abs(offset_x) + reach_x <= RING_APERTURE_HALF_WIDTH):
        return False
    return bool(abs(offset_y) + reach_y <= RING_APERTURE_HALF_WIDTH)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_air_insert_reward(contact_pairs, env_state):
    """Score the air-insertion task as the number of satisfied stages (0-5).

    One point is awarded for each of:

    1. a left-gripper finger touching the ring,
    2. a right-gripper finger touching the bar,
    3. the ring not touching the table,
    4. the bar not touching the table,
    5. both objects off the table and the bar fully through the ring.

    Args:
        contact_pairs: Iterable of ``(geom_name, geom_name)`` contacts.
        env_state: Flat 14-element ring/bar state; only inspected when both
            objects are off the table.

    Returns:
        The integer stage count.
    """
    contacts = _normalize_contact_pairs(contact_pairs)

    ring_grasped = _has_any_object_contact(contacts, RING_GEOM_NAMES, LEFT_GRIPPER_GEOM_NAMES)
    bar_grasped = _has_any_object_contact(contacts, BAR_GEOM_NAMES, RIGHT_GRIPPER_GEOM_NAMES)
    ring_lifted = _object_is_airborne(contacts, RING_GEOM_NAMES)
    bar_lifted = _object_is_airborne(contacts, BAR_GEOM_NAMES)
    # Short-circuit so the geometric check only runs once both are lifted.
    inserted_in_air = ring_lifted and bar_lifted and bar_fully_inserted_through_ring(env_state)

    stages = (ring_grasped, bar_grasped, ring_lifted, bar_lifted, inserted_in_air)
    return sum(1 for achieved in stages if achieved)
|
||||||
|
|
||||||
|
|
||||||
class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
|
class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
|
||||||
|
def __init__(self, *args, **kwargs):
    """Forward all arguments to the parent environment and set the score cap."""
    super().__init__(*args, **kwargs)
    # compute_air_insert_reward adds at most five unit increments, so 5 is the
    # full score. NOTE(review): how the parent class consumes max_reward is
    # not visible here -- confirm.
    self.max_reward = 5
|
||||||
|
|
||||||
def reset(self, task_state):
|
def reset(self, task_state):
|
||||||
set_ring_bar_task_state(self.mj_data, task_state)
|
set_ring_bar_task_state(self.mj_data, task_state)
|
||||||
DualDianaMed.reset(self)
|
DualDianaMed.reset(self)
|
||||||
@@ -66,6 +167,11 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
|
|||||||
return get_ring_bar_env_state(self.mj_data)
|
return get_ring_bar_env_state(self.mj_data)
|
||||||
|
|
||||||
def _get_reward(self):
    """Compute the staged reward from the current contacts and object state.

    Reads the first ``ncon`` entries of ``mj_data.contact``, resolves both
    geom ids of each entry to names, and passes the name pairs together with
    the current env state to ``compute_air_insert_reward``.
    """

    def geom_name(geom_id):
        return self.getID2Name("geom", geom_id)

    active_contacts = (
        self.mj_data.contact[index] for index in range(self.mj_data.ncon)
    )
    contact_pairs = [
        (geom_name(contact.geom1), geom_name(contact.geom2))
        for contact in active_contacts
    ]
    return compute_air_insert_reward(contact_pairs, self.get_env_state())
|
||||||
|
|||||||
@@ -174,5 +174,131 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AirInsertRewardAndSuccessTest(unittest.TestCase):
|
||||||
|
@staticmethod
|
||||||
|
def _make_env_state(
|
||||||
|
ring_pos=(0.0, 0.0, 0.50),
|
||||||
|
ring_quat=(1.0, 0.0, 0.0, 0.0),
|
||||||
|
bar_pos=(0.0, 0.0, 0.50),
|
||||||
|
bar_quat=(0.70710678, 0.0, 0.70710678, 0.0),
|
||||||
|
):
|
||||||
|
return np.array(
|
||||||
|
[*ring_pos, *ring_quat, *bar_pos, *bar_quat],
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_compute_air_insert_reward_counts_left_contact_stage(self):
|
||||||
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||||
|
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
|
||||||
|
self.assertIsNotNone(
|
||||||
|
reward_fn,
|
||||||
|
"Expected roboimi.envs.double_air_insert_env.compute_air_insert_reward",
|
||||||
|
)
|
||||||
|
|
||||||
|
reward = reward_fn(
|
||||||
|
contact_pairs=[
|
||||||
|
("ring_block_north", "l_finger_left"),
|
||||||
|
("ring_block_north", "table"),
|
||||||
|
("bar_block", "table"),
|
||||||
|
],
|
||||||
|
env_state=self._make_env_state(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(reward, 1)
|
||||||
|
|
||||||
|
def test_compute_air_insert_reward_counts_right_contact_stage(self):
|
||||||
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||||
|
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
|
||||||
|
|
||||||
|
reward = reward_fn(
|
||||||
|
contact_pairs=[
|
||||||
|
("ring_block_north", "l_finger_left"),
|
||||||
|
("bar_block", "l_finger_right"),
|
||||||
|
("ring_block_north", "table"),
|
||||||
|
("bar_block", "table"),
|
||||||
|
],
|
||||||
|
env_state=self._make_env_state(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(reward, 2)
|
||||||
|
|
||||||
|
def test_compute_air_insert_reward_counts_lift_stages(self):
|
||||||
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||||
|
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
|
||||||
|
|
||||||
|
reward = reward_fn(
|
||||||
|
contact_pairs=[
|
||||||
|
("ring_block_north", "l_finger_left"),
|
||||||
|
("bar_block", "l_finger_right"),
|
||||||
|
],
|
||||||
|
env_state=self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(reward, 4)
|
||||||
|
|
||||||
|
def test_bar_fully_inserted_through_ring_accepts_true_positive(self):
|
||||||
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||||
|
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
|
||||||
|
self.assertIsNotNone(
|
||||||
|
success_fn,
|
||||||
|
"Expected roboimi.envs.double_air_insert_env.bar_fully_inserted_through_ring",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
success_fn(
|
||||||
|
self._make_env_state(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_bar_fully_inserted_through_ring_rejects_centerline_only_false_positive(self):
|
||||||
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||||
|
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
|
||||||
|
|
||||||
|
self.assertFalse(
|
||||||
|
success_fn(
|
||||||
|
self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_bar_fully_inserted_through_ring_rejects_insufficient_depth(self):
|
||||||
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||||
|
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
|
||||||
|
|
||||||
|
self.assertFalse(
|
||||||
|
success_fn(
|
||||||
|
self._make_env_state(bar_pos=(0.0, 0.0, 0.56)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_compute_air_insert_reward_requires_airborne_success_for_final_point(self):
|
||||||
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||||
|
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
|
||||||
|
|
||||||
|
reward = reward_fn(
|
||||||
|
contact_pairs=[
|
||||||
|
("ring_block_north", "l_finger_left"),
|
||||||
|
("bar_block", "l_finger_right"),
|
||||||
|
("ring_block_north", "table"),
|
||||||
|
],
|
||||||
|
env_state=self._make_env_state(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(reward, 3)
|
||||||
|
|
||||||
|
def test_compute_air_insert_reward_returns_full_score_on_true_airborne_insert(self):
|
||||||
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||||
|
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
|
||||||
|
|
||||||
|
reward = reward_fn(
|
||||||
|
contact_pairs=[
|
||||||
|
("ring_block_north", "l_finger_left"),
|
||||||
|
("bar_block", "l_finger_right"),
|
||||||
|
],
|
||||||
|
env_state=self._make_env_state(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(reward, 5)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user