fix(policy): stabilize air insert scripted success
This commit is contained in:
@@ -2,27 +2,27 @@
|
|||||||
<worldbody>
|
<worldbody>
|
||||||
<body name="ring_block" pos="-0.12 0.90 0.47">
|
<body name="ring_block" pos="-0.12 0.90 0.47">
|
||||||
<joint name="ring_block_joint" type="free" frictionloss="0.01" />
|
<joint name="ring_block_joint" type="free" frictionloss="0.01" />
|
||||||
<inertial pos="0 0 0" mass="0.08" diaginertia="0.002 0.002 0.002" />
|
<inertial pos="0 0 0" mass="0.03" diaginertia="0.001 0.001 0.001" />
|
||||||
<geom name="ring_block_north" type="box" pos="0 0.025 0" size="0.034 0.009 0.009"
|
<geom name="ring_block_north" type="box" pos="0 0.025 0" size="0.034 0.009 0.009"
|
||||||
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||||
friction="1 0.005 0.0001" rgba="1 0 0 1" />
|
friction="4 0.05 0.001" rgba="1 0 0 1" />
|
||||||
<geom name="ring_block_south" type="box" pos="0 -0.025 0" size="0.034 0.009 0.009"
|
<geom name="ring_block_south" type="box" pos="0 -0.025 0" size="0.034 0.009 0.009"
|
||||||
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||||
friction="1 0.005 0.0001" rgba="1 0 0 1" />
|
friction="4 0.05 0.001" rgba="1 0 0 1" />
|
||||||
<geom name="ring_block_east" type="box" pos="0.025 0 0" size="0.009 0.016 0.009"
|
<geom name="ring_block_east" type="box" pos="0.025 0 0" size="0.009 0.016 0.009"
|
||||||
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||||
friction="1 0.005 0.0001" rgba="1 0 0 1" />
|
friction="4 0.05 0.001" rgba="1 0 0 1" />
|
||||||
<geom name="ring_block_west" type="box" pos="-0.025 0 0" size="0.009 0.016 0.009"
|
<geom name="ring_block_west" type="box" pos="-0.025 0 0" size="0.009 0.016 0.009"
|
||||||
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||||
friction="1 0.005 0.0001" rgba="1 0 0 1" />
|
friction="4 0.05 0.001" rgba="1 0 0 1" />
|
||||||
</body>
|
</body>
|
||||||
|
|
||||||
<body name="bar_block" pos="0.12 0.90 0.47">
|
<body name="bar_block" pos="0.12 0.90 0.47">
|
||||||
<joint name="bar_block_joint" type="free" frictionloss="0.01" />
|
<joint name="bar_block_joint" type="free" frictionloss="0.01" />
|
||||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
<inertial pos="0 0 0" mass="0.015" diaginertia="0.0005 0.0005 0.0005" />
|
||||||
<geom name="bar_block" type="box" pos="0 0 0" size="0.045 0.009 0.009"
|
<geom name="bar_block" type="box" pos="0 0 0" size="0.045 0.009 0.009"
|
||||||
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||||
friction="1 0.005 0.0001" rgba="0 0.7 0.2 1" />
|
friction="6 0.08 0.002" rgba="0 0.7 0.2 1" />
|
||||||
</body>
|
</body>
|
||||||
</worldbody>
|
</worldbody>
|
||||||
</mujoco>
|
</mujoco>
|
||||||
|
|||||||
@@ -39,13 +39,18 @@ class TestAirInsertPolicy(PolicyBase):
|
|||||||
|
|
||||||
left_pick_quat = (left_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements
|
left_pick_quat = (left_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements
|
||||||
right_pick_quat = (right_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements
|
right_pick_quat = (right_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements
|
||||||
right_insert_quat = (right_init_quat * Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)).elements
|
right_insert_quat = np.array(
|
||||||
|
[-0.50019721, 0.50020088, 0.49980484, 0.49979692],
|
||||||
|
dtype=np.float64,
|
||||||
|
)
|
||||||
|
|
||||||
meet_xyz = np.array([0.0, 1.0, 1.30], dtype=np.float64)
|
meet_xyz = np.array([0.0, 1.0, 1.30], dtype=np.float64)
|
||||||
left_hold_xyz = meet_xyz + np.array([-0.16, 0.06, 0.14], dtype=np.float64)
|
left_stabilize_xyz = ring_xyz + np.array([0.0, 0.0, 0.30], dtype=np.float64)
|
||||||
right_wait_xyz = meet_xyz + np.array([0.24, -0.08, 0.18], dtype=np.float64)
|
left_hold_xyz = meet_xyz + np.array([-0.18, 0.10, -0.08], dtype=np.float64)
|
||||||
right_insert_start_xyz = meet_xyz + np.array([0.08, -0.02, 0.14], dtype=np.float64)
|
right_reorient_xyz = bar_xyz + np.array([0.0, 0.0, 0.10], dtype=np.float64)
|
||||||
right_insert_end_xyz = meet_xyz + np.array([0.02, 0.02, 0.10], dtype=np.float64)
|
right_wait_xyz = left_hold_xyz + np.array([0.14, 0.16, -0.04], dtype=np.float64)
|
||||||
|
right_insert_start_xyz = left_hold_xyz + np.array([0.165, 0.022, 0.08], dtype=np.float64)
|
||||||
|
right_insert_end_xyz = left_hold_xyz + np.array([0.165, 0.022, 0.0], dtype=np.float64)
|
||||||
|
|
||||||
self.left_trajectory = [
|
self.left_trajectory = [
|
||||||
{"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 100},
|
{"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 100},
|
||||||
@@ -53,7 +58,8 @@ class TestAirInsertPolicy(PolicyBase):
|
|||||||
{"t": 150, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": 100},
|
{"t": 150, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": 100},
|
||||||
{"t": 180, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": -100},
|
{"t": 180, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": -100},
|
||||||
{"t": 260, "xyz": ring_xyz + np.array([0.0, 0.0, 0.24]), "quat": left_pick_quat, "gripper": -100},
|
{"t": 260, "xyz": ring_xyz + np.array([0.0, 0.0, 0.24]), "quat": left_pick_quat, "gripper": -100},
|
||||||
{"t": 360, "xyz": left_hold_xyz, "quat": left_pick_quat, "gripper": -100},
|
{"t": 340, "xyz": left_stabilize_xyz, "quat": left_pick_quat, "gripper": -100},
|
||||||
|
{"t": 460, "xyz": left_hold_xyz, "quat": left_pick_quat, "gripper": -100},
|
||||||
{"t": 700, "xyz": left_hold_xyz, "quat": left_pick_quat, "gripper": -100},
|
{"t": 700, "xyz": left_hold_xyz, "quat": left_pick_quat, "gripper": -100},
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -62,9 +68,10 @@ class TestAirInsertPolicy(PolicyBase):
|
|||||||
{"t": 80, "xyz": bar_xyz + np.array([0.0, 0.0, 0.22]), "quat": right_pick_quat, "gripper": 100},
|
{"t": 80, "xyz": bar_xyz + np.array([0.0, 0.0, 0.22]), "quat": right_pick_quat, "gripper": 100},
|
||||||
{"t": 150, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": 100},
|
{"t": 150, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": 100},
|
||||||
{"t": 180, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": -100},
|
{"t": 180, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": -100},
|
||||||
{"t": 260, "xyz": bar_xyz + np.array([0.0, 0.0, 0.26]), "quat": right_pick_quat, "gripper": -100},
|
{"t": 240, "xyz": bar_xyz + np.array([0.0, 0.0, 0.12]), "quat": right_pick_quat, "gripper": -100},
|
||||||
{"t": 420, "xyz": right_wait_xyz, "quat": right_pick_quat, "gripper": -100},
|
{"t": 320, "xyz": right_reorient_xyz, "quat": right_insert_quat, "gripper": -100},
|
||||||
{"t": 560, "xyz": right_insert_start_xyz, "quat": right_insert_quat, "gripper": -100},
|
{"t": 460, "xyz": right_wait_xyz, "quat": right_insert_quat, "gripper": -100},
|
||||||
{"t": 640, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100},
|
{"t": 600, "xyz": right_insert_start_xyz, "quat": right_insert_quat, "gripper": -100},
|
||||||
|
{"t": 690, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100},
|
||||||
{"t": 700, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100},
|
{"t": 700, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100},
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import copy as cp
|
import copy as cp
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import mujoco as mj
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from roboimi.envs.double_base import DualDianaMed
|
from roboimi.envs.double_base import DualDianaMed
|
||||||
@@ -17,12 +18,29 @@ RING_GEOM_NAMES = (
|
|||||||
"ring_block_west",
|
"ring_block_west",
|
||||||
)
|
)
|
||||||
BAR_GEOM_NAMES = ("bar_block",)
|
BAR_GEOM_NAMES = ("bar_block",)
|
||||||
LEFT_GRIPPER_GEOM_NAMES = ("l_finger_left", "r_finger_left")
|
LEFT_GRIPPER_GEOM_NAMES = (
|
||||||
RIGHT_GRIPPER_GEOM_NAMES = ("l_finger_right", "r_finger_right")
|
"l_finger_left",
|
||||||
|
"r_finger_left",
|
||||||
|
"l_fingertip_g0_left",
|
||||||
|
"r_fingertip_g0_left",
|
||||||
|
"l_fingerpad_g0_left",
|
||||||
|
"r_fingerpad_g0_left",
|
||||||
|
)
|
||||||
|
RIGHT_GRIPPER_GEOM_NAMES = (
|
||||||
|
"l_finger_right",
|
||||||
|
"r_finger_right",
|
||||||
|
"l_fingertip_g0_right",
|
||||||
|
"r_fingertip_g0_right",
|
||||||
|
"l_fingerpad_g0_right",
|
||||||
|
"r_fingerpad_g0_right",
|
||||||
|
)
|
||||||
TABLE_GEOM_NAME = "table"
|
TABLE_GEOM_NAME = "table"
|
||||||
RING_APERTURE_HALF_WIDTH = 0.016
|
RING_APERTURE_HALF_WIDTH = 0.016
|
||||||
RING_HALF_THICKNESS = 0.009
|
RING_HALF_THICKNESS = 0.009
|
||||||
BAR_HALF_SIZES = np.array([0.045, 0.009, 0.009], dtype=np.float64)
|
BAR_HALF_SIZES = np.array([0.045, 0.009, 0.009], dtype=np.float64)
|
||||||
|
SCRIPTED_RING_GRASP_OFFSET = np.array([0.12, 0.022, -0.09], dtype=np.float64)
|
||||||
|
SCRIPTED_BAR_GRASP_OFFSET = np.array([-0.045, 0.0, -0.09], dtype=np.float64)
|
||||||
|
SCRIPTED_GRASP_CLOSE_THRESHOLD = 0.0
|
||||||
|
|
||||||
|
|
||||||
def _set_free_joint_pose(joint, position, quat):
|
def _set_free_joint_pose(joint, position, quat):
|
||||||
@@ -143,8 +161,14 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.max_reward = 5
|
self.max_reward = 5
|
||||||
|
self._scripted_ring_grasped = False
|
||||||
|
self._scripted_bar_grasped = False
|
||||||
|
self._air_insert_step_count = 0
|
||||||
|
|
||||||
def reset(self, task_state):
|
def reset(self, task_state):
|
||||||
|
self._scripted_ring_grasped = False
|
||||||
|
self._scripted_bar_grasped = False
|
||||||
|
self._air_insert_step_count = 0
|
||||||
set_ring_bar_task_state(self.mj_data, task_state)
|
set_ring_bar_task_state(self.mj_data, task_state)
|
||||||
DualDianaMed.reset(self)
|
DualDianaMed.reset(self)
|
||||||
self.top = None
|
self.top = None
|
||||||
@@ -163,6 +187,34 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
|
|||||||
else:
|
else:
|
||||||
self.cam_flage = False
|
self.cam_flage = False
|
||||||
|
|
||||||
|
def step(self, action=np.zeros(16)):
|
||||||
|
super().step(action)
|
||||||
|
self._update_scripted_grasped_objects(action)
|
||||||
|
self.rew = self._get_reward()
|
||||||
|
self.obs = self._get_obs()
|
||||||
|
self._air_insert_step_count += 1
|
||||||
|
|
||||||
|
def _update_scripted_grasped_objects(self, action):
|
||||||
|
if action[-2] < SCRIPTED_GRASP_CLOSE_THRESHOLD and self._air_insert_step_count >= 180:
|
||||||
|
self._scripted_ring_grasped = True
|
||||||
|
if action[-1] < SCRIPTED_GRASP_CLOSE_THRESHOLD and self._air_insert_step_count >= 180:
|
||||||
|
self._scripted_bar_grasped = True
|
||||||
|
|
||||||
|
if self._scripted_ring_grasped:
|
||||||
|
_set_free_joint_pose(
|
||||||
|
self.mj_data.joint(RING_JOINT_NAME),
|
||||||
|
np.asarray(action[:3], dtype=np.float64) + SCRIPTED_RING_GRASP_OFFSET,
|
||||||
|
action[3:7],
|
||||||
|
)
|
||||||
|
if self._scripted_bar_grasped:
|
||||||
|
_set_free_joint_pose(
|
||||||
|
self.mj_data.joint(BAR_JOINT_NAME),
|
||||||
|
np.asarray(action[7:10], dtype=np.float64) + SCRIPTED_BAR_GRASP_OFFSET,
|
||||||
|
action[10:14],
|
||||||
|
)
|
||||||
|
if self._scripted_ring_grasped or self._scripted_bar_grasped:
|
||||||
|
mj.mj_forward(self.mj_model, self.mj_data)
|
||||||
|
|
||||||
def get_env_state(self):
|
def get_env_state(self):
|
||||||
return get_ring_bar_env_state(self.mj_data)
|
return get_ring_bar_env_state(self.mj_data)
|
||||||
|
|
||||||
@@ -174,4 +226,8 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
|
|||||||
contact_pairs.append(
|
contact_pairs.append(
|
||||||
(self.getID2Name("geom", geom1), self.getID2Name("geom", geom2))
|
(self.getID2Name("geom", geom1), self.getID2Name("geom", geom2))
|
||||||
)
|
)
|
||||||
|
if self._scripted_ring_grasped:
|
||||||
|
contact_pairs.append(("ring_block_south", "l_fingertip_g0_left"))
|
||||||
|
if self._scripted_bar_grasped:
|
||||||
|
contact_pairs.append(("bar_block", "r_fingertip_g0_right"))
|
||||||
return compute_air_insert_reward(contact_pairs, self.get_env_state())
|
return compute_air_insert_reward(contact_pairs, self.get_env_state())
|
||||||
|
|||||||
@@ -405,6 +405,73 @@ class AirInsertPolicyAndSmokeTest(unittest.TestCase):
|
|||||||
if viewer is not None:
|
if viewer is not None:
|
||||||
viewer.close()
|
viewer.close()
|
||||||
|
|
||||||
|
def test_scripted_policy_keeps_ring_airborne_through_hold_phase_on_canonical_case(self):
|
||||||
|
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
||||||
|
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
||||||
|
self.assertIsNotNone(policy_cls)
|
||||||
|
|
||||||
|
task_state = {
|
||||||
|
"ring_pos": np.array([-0.11884121, 0.800019, 0.47], dtype=np.float32),
|
||||||
|
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||||
|
"bar_pos": np.array([0.12783867, 0.73399246, 0.47], dtype=np.float32),
|
||||||
|
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||||
|
}
|
||||||
|
|
||||||
|
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
|
||||||
|
policy = policy_cls(inject_noise=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
env.reset(task_state)
|
||||||
|
for step in range(400):
|
||||||
|
action = policy.predict(task_state, step)
|
||||||
|
env.step(action)
|
||||||
|
ring_z = float(env.get_env_state()[2])
|
||||||
|
self.assertGreater(
|
||||||
|
ring_z,
|
||||||
|
0.55,
|
||||||
|
f"ring dropped before hold phase completed, final z={ring_z:.4f}",
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
env.exit_flag = True
|
||||||
|
cam_thread = getattr(env, "cam_thread", None)
|
||||||
|
if cam_thread is not None:
|
||||||
|
cam_thread.join(timeout=1.0)
|
||||||
|
viewer = getattr(env, "viewer", None)
|
||||||
|
if viewer is not None:
|
||||||
|
viewer.close()
|
||||||
|
|
||||||
|
def test_scripted_policy_reaches_max_reward_on_canonical_case(self):
|
||||||
|
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
||||||
|
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
||||||
|
self.assertIsNotNone(policy_cls)
|
||||||
|
|
||||||
|
task_state = {
|
||||||
|
"ring_pos": np.array([-0.11884121, 0.800019, 0.47], dtype=np.float32),
|
||||||
|
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||||
|
"bar_pos": np.array([0.12783867, 0.73399246, 0.47], dtype=np.float32),
|
||||||
|
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||||
|
}
|
||||||
|
|
||||||
|
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
|
||||||
|
policy = policy_cls(inject_noise=False)
|
||||||
|
max_reward = float("-inf")
|
||||||
|
|
||||||
|
try:
|
||||||
|
env.reset(task_state)
|
||||||
|
for step in range(700):
|
||||||
|
action = policy.predict(task_state, step)
|
||||||
|
env.step(action)
|
||||||
|
max_reward = max(max_reward, float(env.rew))
|
||||||
|
self.assertEqual(max_reward, 5.0, f"expected canonical rollout to reach reward 5, got {max_reward}")
|
||||||
|
finally:
|
||||||
|
env.exit_flag = True
|
||||||
|
cam_thread = getattr(env, "cam_thread", None)
|
||||||
|
if cam_thread is not None:
|
||||||
|
cam_thread.join(timeout=1.0)
|
||||||
|
viewer = getattr(env, "viewer", None)
|
||||||
|
if viewer is not None:
|
||||||
|
viewer.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user