feat(sim): switch air insert task to socket peg
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
<mujoco model="bi_diana_ring_bar">
|
||||
<mujoco model="bi_diana_socket_peg">
|
||||
<include file="./empty_world.xml" />
|
||||
<include file="./table_square.xml" />
|
||||
<include file="./ring_bar_objects.xml" />
|
||||
<include file="./socket_peg_objects.xml" />
|
||||
<include file="./BiDianaMed_rethink.xml" />
|
||||
</mujoco>
|
||||
@@ -1,28 +0,0 @@
|
||||
<mujoco model="ring_bar_objects">
|
||||
<worldbody>
|
||||
<body name="ring_block" pos="-0.12 0.90 0.47">
|
||||
<joint name="ring_block_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.03" diaginertia="0.001 0.001 0.001" />
|
||||
<geom name="ring_block_north" type="box" pos="0 0.025 0" size="0.034 0.009 0.009"
|
||||
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||
friction="4 0.05 0.001" rgba="1 0 0 1" />
|
||||
<geom name="ring_block_south" type="box" pos="0 -0.025 0" size="0.034 0.009 0.009"
|
||||
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||
friction="4 0.05 0.001" rgba="1 0 0 1" />
|
||||
<geom name="ring_block_east" type="box" pos="0.025 0 0" size="0.009 0.016 0.009"
|
||||
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||
friction="4 0.05 0.001" rgba="1 0 0 1" />
|
||||
<geom name="ring_block_west" type="box" pos="-0.025 0 0" size="0.009 0.016 0.009"
|
||||
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||
friction="4 0.05 0.001" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
<body name="bar_block" pos="0.12 0.90 0.47">
|
||||
<joint name="bar_block_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.015" diaginertia="0.0005 0.0005 0.0005" />
|
||||
<geom name="bar_block" type="box" pos="0 0 0" size="0.045 0.009 0.009"
|
||||
contype="1" conaffinity="1" condim="4" solimp="2 1 0.01" solref="0.01 1"
|
||||
friction="6 0.08 0.002" rgba="0 0.7 0.2 1" />
|
||||
</body>
|
||||
</worldbody>
|
||||
</mujoco>
|
||||
@@ -0,0 +1,19 @@
|
||||
<mujoco model="socket_peg_objects">
|
||||
<worldbody>
|
||||
<body name="peg" pos="0.12 0.90 0.46">
|
||||
<joint name="red_peg_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.06 0.01 0.01" type="box" name="red_peg" rgba="1 0 0 1" />
|
||||
</body>
|
||||
|
||||
<body name="socket" pos="-0.12 0.90 0.472">
|
||||
<joint name="blue_socket_joint" type="free" frictionloss="0.01" />
|
||||
<inertial pos="0 0 0" mass="0.05" diaginertia="0.002 0.002 0.002" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 -0.02" size="0.06 0.018 0.002" type="box" name="socket-1" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0 0.02" size="0.06 0.018 0.002" type="box" name="socket-2" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 0.02 0" size="0.06 0.002 0.018" type="box" name="socket-3" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.05 0.001" pos="0 -0.02 0" size="0.06 0.002 0.018" type="box" name="socket-4" rgba="0 0 1 1" />
|
||||
<geom condim="4" solimp="2 1 0.01" solref="0.01 1" friction="1 0.005 0.0001" pos="0 0 0" size="0.04 0.01 0.01" type="box" name="pin" rgba="1 0 0 1" />
|
||||
</body>
|
||||
</worldbody>
|
||||
</mujoco>
|
||||
@@ -7,7 +7,7 @@
|
||||
<geom name="table" condim="4" contype="1" conaffinity="1" type="box" rgba="0.4 0.4 0.4 1" size="0.62 0.62 0.01" density="1500" friction="0.9 0.9 0.9"/>
|
||||
</body>
|
||||
<camera name="top" pos="0.0 1.0 2.0" fovy="44" mode="targetbody" target="table"/>
|
||||
<camera name="angle" pos="0.0 0.0 2.0" fovy="37" mode="targetbody" target="table"/>
|
||||
<camera name="left_side" pos="-0.55 0.85 0.85" fovy="65" mode="targetbody" target="table"/>
|
||||
<camera name="front" pos="0 0 0.8" fovy="65" mode="fixed" quat="0.7071 0.7071 0 0"/>
|
||||
</worldbody>
|
||||
</mujoco>
|
||||
|
||||
@@ -92,12 +92,12 @@ class BiDianaMed(ArmBase):
|
||||
return np.array([0.0, 0.0, 0.0, 1.57, 0.0, 0.0, 0.0])
|
||||
|
||||
|
||||
class BiDianaMedRingBar(ArmBase):
|
||||
class BiDianaMedSocketPeg(ArmBase):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
name="Bidiana_ring_bar",
|
||||
name="Bidiana_socket_peg",
|
||||
urdf_path="roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf",
|
||||
xml_path="roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml",
|
||||
xml_path="roboimi/assets/models/manipulators/DianaMed/bi_diana_socket_peg_ee.xml",
|
||||
gripper=None
|
||||
)
|
||||
self.left_arm = self.Arm(self, 'single', self.urdf_path)
|
||||
|
||||
@@ -5,16 +5,39 @@ from roboimi.demos.diana_policy import PolicyBase
|
||||
|
||||
|
||||
class TestAirInsertPolicy(PolicyBase):
|
||||
@staticmethod
|
||||
def _action_xyz_for_object_center(object_center, ee_quat, object_offset_local):
|
||||
return (
|
||||
np.asarray(object_center, dtype=np.float64)
|
||||
- np.asarray(Quaternion(ee_quat).rotate(object_offset_local), dtype=np.float64)
|
||||
)
|
||||
ACTION_OBJECT_Z_OFFSET = 0.078
|
||||
SOCKET_GRASP_OFFSET = np.array([0.0, 0.0, 0.0], dtype=np.float64)
|
||||
PEG_GRASP_OFFSET = np.array([0.0, 0.0, 0.0], dtype=np.float64)
|
||||
SOCKET_OUTER_GRASP_STRATEGY = "socket_outer"
|
||||
LEGACY_GRASP_STRATEGY = "legacy"
|
||||
SOCKET_HOLD_Z = 0.85
|
||||
PEG_INSERT_START_OFFSET = np.array([0.105, 0.0, 0.0], dtype=np.float64)
|
||||
INSERT_START_T = 650
|
||||
INSERT_END_T = 700
|
||||
LEFT_SOCKET_GRIPPER_CLOSED = -70
|
||||
RIGHT_PEG_GRIPPER_CLOSED = -100
|
||||
SOCKET_APPROACH_Z = 1.05
|
||||
EPISODE_END_T = 1000
|
||||
|
||||
def __init__(self, inject_noise=False, grasp_strategy=SOCKET_OUTER_GRASP_STRATEGY):
|
||||
super().__init__(inject_noise=inject_noise)
|
||||
valid_strategies = {
|
||||
self.SOCKET_OUTER_GRASP_STRATEGY,
|
||||
self.LEGACY_GRASP_STRATEGY,
|
||||
}
|
||||
if grasp_strategy not in valid_strategies:
|
||||
raise ValueError(
|
||||
f"Unsupported air insert grasp_strategy={grasp_strategy!r}; "
|
||||
f"expected one of {sorted(valid_strategies)}"
|
||||
)
|
||||
self.grasp_strategy = grasp_strategy
|
||||
|
||||
def generate_trajectory(self, task_state):
|
||||
ring_xyz = np.asarray(task_state["ring_pos"], dtype=np.float64)
|
||||
bar_xyz = np.asarray(task_state["bar_pos"], dtype=np.float64)
|
||||
return self._generate_socket_peg_trajectory(task_state)
|
||||
|
||||
def _generate_socket_peg_trajectory(self, task_state):
|
||||
socket_xyz = np.asarray(task_state["socket_pos"], dtype=np.float64)
|
||||
peg_xyz = np.asarray(task_state["peg_pos"], dtype=np.float64)
|
||||
|
||||
init_mocap_pose_left = np.array(
|
||||
[
|
||||
@@ -44,63 +67,137 @@ class TestAirInsertPolicy(PolicyBase):
|
||||
left_init_quat = Quaternion(init_mocap_pose_left[3:])
|
||||
right_init_quat = Quaternion(init_mocap_pose_right[3:])
|
||||
|
||||
object_offset_local = np.array([0.0, 0.0, -0.09], dtype=np.float64)
|
||||
left_pick_quat = (left_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements
|
||||
left_hold_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=-90).elements
|
||||
right_pick_quat = (right_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=90)).elements
|
||||
insert_quat_local = Quaternion([-0.50019721, 0.50020088, 0.49980484, 0.49979692])
|
||||
right_insert_quat = np.array(
|
||||
(Quaternion(left_hold_quat) * insert_quat_local).elements,
|
||||
left_pick_quat = (
|
||||
left_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=45)
|
||||
).elements
|
||||
right_pick_quat = (
|
||||
right_init_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=45)
|
||||
).elements
|
||||
|
||||
socket_hold_action = np.array(
|
||||
[socket_xyz[0] - 0.078, socket_xyz[1], self.SOCKET_HOLD_Z], dtype=np.float64
|
||||
)
|
||||
|
||||
peg_init_xyz = peg_xyz + np.array(
|
||||
[0.078, 0.0, self.ACTION_OBJECT_Z_OFFSET + 0.01]
|
||||
)
|
||||
peg_lift_center = np.array(
|
||||
[peg_xyz[0] + 0.078, socket_hold_action[1], self.SOCKET_HOLD_Z - 0.01],
|
||||
dtype=np.float64,
|
||||
)
|
||||
# The front camera looks along +Y, so visual right-to-left insertion is
|
||||
# world +X -> -X. With the socket XML in identity orientation, its
|
||||
# tunnel axis is local/world X, so the peg approaches from +X and stops
|
||||
# when its leading face reaches the socket's internal pin.
|
||||
peg_insert_end_center = np.array(
|
||||
[
|
||||
socket_hold_action[0] + 0.078 * 2 + 0.04 + 0.06 - 0.01,
|
||||
socket_hold_action[1],
|
||||
self.SOCKET_HOLD_Z - 0.01,
|
||||
],
|
||||
dtype=np.float64,
|
||||
)
|
||||
|
||||
meet_xyz = np.array([0.0, 1.0, 1.30], dtype=np.float64)
|
||||
ring_stabilize_center = ring_xyz + np.array([0.0, 0.0, 0.30], dtype=np.float64)
|
||||
ring_hold_center = meet_xyz + np.array([-0.10, 0.05, -0.16], dtype=np.float64)
|
||||
bar_reorient_center = bar_xyz + np.array([0.0, 0.0, 0.16], dtype=np.float64)
|
||||
bar_wait_center = ring_hold_center + np.array([0.05, -0.18, 0.0], dtype=np.float64)
|
||||
bar_insert_start_center = ring_hold_center + np.array([0.0, -0.075, 0.0], dtype=np.float64)
|
||||
bar_insert_end_center = ring_hold_center + np.array([0.0, 0.075, 0.0], dtype=np.float64)
|
||||
|
||||
left_stabilize_xyz = self._action_xyz_for_object_center(
|
||||
ring_stabilize_center, left_pick_quat, object_offset_local
|
||||
)
|
||||
left_hold_xyz = self._action_xyz_for_object_center(
|
||||
ring_hold_center, left_hold_quat, object_offset_local
|
||||
)
|
||||
right_reorient_xyz = self._action_xyz_for_object_center(
|
||||
bar_reorient_center, right_insert_quat, object_offset_local
|
||||
)
|
||||
right_wait_xyz = self._action_xyz_for_object_center(
|
||||
bar_wait_center, right_insert_quat, object_offset_local
|
||||
)
|
||||
right_insert_start_xyz = self._action_xyz_for_object_center(
|
||||
bar_insert_start_center, right_insert_quat, object_offset_local
|
||||
)
|
||||
right_insert_end_xyz = self._action_xyz_for_object_center(
|
||||
bar_insert_end_center, right_insert_quat, object_offset_local
|
||||
)
|
||||
|
||||
self.left_trajectory = [
|
||||
{"t": 1, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 100},
|
||||
{"t": 80, "xyz": ring_xyz + np.array([0.0, 0.0, 0.22]), "quat": left_pick_quat, "gripper": 100},
|
||||
{"t": 150, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": 100},
|
||||
{"t": 180, "xyz": ring_xyz + np.array([0.0, 0.0, 0.08]), "quat": left_pick_quat, "gripper": -100},
|
||||
{"t": 260, "xyz": self._action_xyz_for_object_center(ring_xyz + np.array([0.0, 0.0, 0.24]), left_pick_quat, object_offset_local), "quat": left_pick_quat, "gripper": -100},
|
||||
{"t": 340, "xyz": left_stabilize_xyz, "quat": left_pick_quat, "gripper": -100},
|
||||
{"t": 460, "xyz": left_hold_xyz, "quat": left_hold_quat, "gripper": -100},
|
||||
{"t": 700, "xyz": left_hold_xyz, "quat": left_hold_quat, "gripper": -100},
|
||||
{
|
||||
"t": 1,
|
||||
"xyz": init_mocap_pose_left[:3],
|
||||
"quat": init_mocap_pose_left[3:],
|
||||
"gripper": 100,
|
||||
},
|
||||
{
|
||||
"t": 130,
|
||||
"xyz": socket_xyz
|
||||
+ np.array([-0.078, 0.0, self.ACTION_OBJECT_Z_OFFSET]),
|
||||
"quat": left_pick_quat,
|
||||
"gripper": 100,
|
||||
},
|
||||
{
|
||||
"t": 180,
|
||||
"xyz": socket_xyz
|
||||
+ np.array([-0.078, 0.0, self.ACTION_OBJECT_Z_OFFSET]),
|
||||
"quat": left_pick_quat,
|
||||
"gripper": self.LEFT_SOCKET_GRIPPER_CLOSED,
|
||||
},
|
||||
{
|
||||
"t": 450,
|
||||
"xyz": socket_hold_action,
|
||||
"quat": left_pick_quat,
|
||||
"gripper": self.LEFT_SOCKET_GRIPPER_CLOSED,
|
||||
},
|
||||
{
|
||||
"t": 750,
|
||||
"xyz": socket_hold_action,
|
||||
"quat": left_pick_quat,
|
||||
"gripper": self.LEFT_SOCKET_GRIPPER_CLOSED,
|
||||
},
|
||||
{
|
||||
"t": self.EPISODE_END_T,
|
||||
"xyz": socket_hold_action,
|
||||
"quat": left_pick_quat,
|
||||
"gripper": self.LEFT_SOCKET_GRIPPER_CLOSED,
|
||||
},
|
||||
]
|
||||
|
||||
self.right_trajectory = [
|
||||
{"t": 1, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 100},
|
||||
{"t": 80, "xyz": bar_xyz + np.array([0.0, 0.0, 0.22]), "quat": right_pick_quat, "gripper": 100},
|
||||
{"t": 150, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": 100},
|
||||
{"t": 180, "xyz": bar_xyz + np.array([0.0, 0.0, 0.08]), "quat": right_pick_quat, "gripper": -100},
|
||||
{"t": 240, "xyz": self._action_xyz_for_object_center(bar_xyz + np.array([0.0, 0.0, 0.12]), right_pick_quat, object_offset_local), "quat": right_pick_quat, "gripper": -100},
|
||||
{"t": 320, "xyz": right_reorient_xyz, "quat": right_insert_quat, "gripper": -100},
|
||||
{"t": 460, "xyz": right_wait_xyz, "quat": right_insert_quat, "gripper": -100},
|
||||
{"t": 600, "xyz": right_insert_start_xyz, "quat": right_insert_quat, "gripper": -100},
|
||||
{"t": 690, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100},
|
||||
{"t": 700, "xyz": right_insert_end_xyz, "quat": right_insert_quat, "gripper": -100},
|
||||
{
|
||||
"t": 1,
|
||||
"xyz": init_mocap_pose_right[:3],
|
||||
"quat": init_mocap_pose_right[3:],
|
||||
"gripper": 100,
|
||||
},
|
||||
{
|
||||
"t": 80,
|
||||
"xyz": peg_init_xyz,
|
||||
"quat": right_pick_quat,
|
||||
"gripper": 100,
|
||||
},
|
||||
{
|
||||
"t": 150,
|
||||
"xyz": peg_init_xyz,
|
||||
"quat": right_pick_quat,
|
||||
"gripper": 100,
|
||||
},
|
||||
{
|
||||
"t": 180,
|
||||
"xyz": peg_init_xyz,
|
||||
"quat": right_pick_quat,
|
||||
"gripper": self.RIGHT_PEG_GRIPPER_CLOSED,
|
||||
},
|
||||
{
|
||||
"t": 450,
|
||||
"xyz": peg_init_xyz,
|
||||
"quat": right_pick_quat,
|
||||
"gripper": self.RIGHT_PEG_GRIPPER_CLOSED,
|
||||
},
|
||||
{
|
||||
"t": 550,
|
||||
"xyz": peg_lift_center,
|
||||
"quat": right_pick_quat,
|
||||
"gripper": self.RIGHT_PEG_GRIPPER_CLOSED,
|
||||
},
|
||||
{
|
||||
"t": self.INSERT_START_T,
|
||||
"xyz": peg_lift_center,
|
||||
"quat": right_pick_quat,
|
||||
"gripper": self.RIGHT_PEG_GRIPPER_CLOSED,
|
||||
},
|
||||
{
|
||||
"t": self.INSERT_END_T,
|
||||
"xyz": peg_insert_end_center,
|
||||
"quat": right_pick_quat,
|
||||
"gripper": self.RIGHT_PEG_GRIPPER_CLOSED,
|
||||
},
|
||||
{
|
||||
"t": 750,
|
||||
"xyz": peg_insert_end_center,
|
||||
"quat": right_pick_quat,
|
||||
"gripper": self.RIGHT_PEG_GRIPPER_CLOSED,
|
||||
},
|
||||
{
|
||||
"t": self.EPISODE_END_T,
|
||||
"xyz": peg_insert_end_center,
|
||||
"quat": right_pick_quat,
|
||||
"gripper": self.RIGHT_PEG_GRIPPER_CLOSED,
|
||||
},
|
||||
]
|
||||
|
||||
@@ -5,7 +5,7 @@ from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||
from roboimi.demos.diana_air_insert_policy import TestAirInsertPolicy
|
||||
from roboimi.demos.diana_policy import TestPickAndTransferPolicy
|
||||
import cv2
|
||||
from roboimi.utils.act_ex_utils import sample_air_insert_ring_bar_state, sample_transfer_pose
|
||||
from roboimi.utils.act_ex_utils import sample_air_insert_socket_peg_state, sample_transfer_pose
|
||||
from roboimi.utils.constants import SIM_TASK_CONFIGS
|
||||
from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter
|
||||
|
||||
@@ -17,16 +17,18 @@ DATASET_DIR = HOME_PATH + '/dataset'
|
||||
def sample_task_state(task_name):
|
||||
if task_name == 'sim_transfer':
|
||||
return sample_transfer_pose()
|
||||
if task_name == 'sim_air_insert_ring_bar':
|
||||
return sample_air_insert_ring_bar_state()
|
||||
if task_name == 'sim_air_insert_socket_peg':
|
||||
return sample_air_insert_socket_peg_state()
|
||||
raise NotImplementedError(f'Unsupported scripted rollout task: {task_name}')
|
||||
|
||||
|
||||
def make_policy(task_name, inject_noise=False):
|
||||
def make_policy(task_name, inject_noise=False, grasp_strategy=None):
|
||||
if task_name == 'sim_transfer':
|
||||
return TestPickAndTransferPolicy(inject_noise)
|
||||
if task_name == 'sim_air_insert_ring_bar':
|
||||
return TestAirInsertPolicy(inject_noise)
|
||||
if task_name == 'sim_air_insert_socket_peg':
|
||||
if grasp_strategy is None:
|
||||
return TestAirInsertPolicy(inject_noise)
|
||||
return TestAirInsertPolicy(inject_noise, grasp_strategy=grasp_strategy)
|
||||
raise NotImplementedError(f'Unsupported scripted rollout task: {task_name}')
|
||||
|
||||
|
||||
@@ -37,9 +39,9 @@ def main(task_name='sim_transfer'):
|
||||
inject_noise = False
|
||||
|
||||
episode_len = task_cfg['episode_len']
|
||||
camera_names = ['angle', 'r_vis', 'top', 'front']
|
||||
camera_names = ['left_side', 'r_vis', 'top', 'front']
|
||||
image_size = (256, 256)
|
||||
if task_name in {'sim_transfer', 'sim_air_insert_ring_bar'}:
|
||||
if task_name in {'sim_transfer', 'sim_air_insert_socket_peg'}:
|
||||
print(task_name)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -27,7 +27,7 @@ from einops import rearrange
|
||||
|
||||
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
||||
from roboimi.utils.act_ex_utils import (
|
||||
sample_air_insert_ring_bar_state,
|
||||
sample_air_insert_socket_peg_state,
|
||||
sample_transfer_pose,
|
||||
)
|
||||
from roboimi.vla.eval_utils import execute_policy_action
|
||||
@@ -489,8 +489,8 @@ def _close_env(env):
|
||||
|
||||
|
||||
def _sample_task_reset_state(task_name: str):
|
||||
if task_name == 'sim_air_insert_ring_bar':
|
||||
return sample_air_insert_ring_bar_state()
|
||||
if task_name == 'sim_air_insert_socket_peg':
|
||||
return sample_air_insert_socket_peg_state()
|
||||
if 'sim_transfer' in task_name:
|
||||
return sample_transfer_pose()
|
||||
raise NotImplementedError(f'Unsupported eval task reset sampling: {task_name}')
|
||||
|
||||
@@ -1,23 +1,19 @@
|
||||
import copy as cp
|
||||
import time
|
||||
|
||||
import mujoco as mj
|
||||
import numpy as np
|
||||
|
||||
from roboimi.envs.double_base import DualDianaMed
|
||||
from roboimi.envs.double_pos_ctrl_env import DualDianaMed_Pos_Ctrl
|
||||
|
||||
|
||||
RING_JOINT_NAME = "ring_block_joint"
|
||||
BAR_JOINT_NAME = "bar_block_joint"
|
||||
REQUIRED_TASK_STATE_KEYS = ("ring_pos", "ring_quat", "bar_pos", "bar_quat")
|
||||
RING_GEOM_NAMES = (
|
||||
"ring_block_north",
|
||||
"ring_block_south",
|
||||
"ring_block_east",
|
||||
"ring_block_west",
|
||||
)
|
||||
BAR_GEOM_NAMES = ("bar_block",)
|
||||
SOCKET_JOINT_NAME = "blue_socket_joint"
|
||||
PEG_JOINT_NAME = "red_peg_joint"
|
||||
REQUIRED_TASK_STATE_KEYS = ("socket_pos", "socket_quat", "peg_pos", "peg_quat")
|
||||
SOCKET_GEOM_NAMES = ("socket-1", "socket-2", "socket-3", "socket-4")
|
||||
SOCKET_SUCCESS_GEOM_NAMES = ("pin",)
|
||||
SOCKET_BODY_GEOM_NAMES = SOCKET_GEOM_NAMES + SOCKET_SUCCESS_GEOM_NAMES
|
||||
PEG_GEOM_NAMES = ("red_peg",)
|
||||
LEFT_GRIPPER_GEOM_NAMES = (
|
||||
"l_finger_left",
|
||||
"r_finger_left",
|
||||
@@ -25,6 +21,8 @@ LEFT_GRIPPER_GEOM_NAMES = (
|
||||
"r_fingertip_g0_left",
|
||||
"l_fingerpad_g0_left",
|
||||
"r_fingerpad_g0_left",
|
||||
"l_fingertip_g0_vis_left",
|
||||
"r_fingertip_g0_vis_left",
|
||||
)
|
||||
RIGHT_GRIPPER_GEOM_NAMES = (
|
||||
"l_finger_right",
|
||||
@@ -33,12 +31,10 @@ RIGHT_GRIPPER_GEOM_NAMES = (
|
||||
"r_fingertip_g0_right",
|
||||
"l_fingerpad_g0_right",
|
||||
"r_fingerpad_g0_right",
|
||||
"l_fingertip_g0_vis_right",
|
||||
"r_fingertip_g0_vis_right",
|
||||
)
|
||||
TABLE_GEOM_NAME = "table"
|
||||
RING_APERTURE_HALF_WIDTH = 0.016
|
||||
RING_HALF_THICKNESS = 0.009
|
||||
BAR_HALF_SIZES = np.array([0.045, 0.009, 0.009], dtype=np.float64)
|
||||
SCRIPTED_GRASP_CLOSE_THRESHOLD = 0.0
|
||||
|
||||
|
||||
def _set_free_joint_pose(joint, position, quat):
|
||||
@@ -46,29 +42,29 @@ def _set_free_joint_pose(joint, position, quat):
|
||||
joint.qpos[3:7] = np.asarray(quat, dtype=np.float64)
|
||||
|
||||
|
||||
def set_ring_bar_task_state(mj_data, task_state):
|
||||
def set_socket_peg_task_state(mj_data, task_state):
|
||||
if not isinstance(task_state, dict) or tuple(task_state.keys()) != REQUIRED_TASK_STATE_KEYS:
|
||||
raise ValueError(
|
||||
"task_state must be an ordered dict-like mapping with keys "
|
||||
"ring_pos, ring_quat, bar_pos, bar_quat"
|
||||
"socket_pos, socket_quat, peg_pos, peg_quat"
|
||||
)
|
||||
|
||||
_set_free_joint_pose(
|
||||
mj_data.joint(RING_JOINT_NAME),
|
||||
task_state["ring_pos"],
|
||||
task_state["ring_quat"],
|
||||
mj_data.joint(SOCKET_JOINT_NAME),
|
||||
task_state["socket_pos"],
|
||||
task_state["socket_quat"],
|
||||
)
|
||||
_set_free_joint_pose(
|
||||
mj_data.joint(BAR_JOINT_NAME),
|
||||
task_state["bar_pos"],
|
||||
task_state["bar_quat"],
|
||||
mj_data.joint(PEG_JOINT_NAME),
|
||||
task_state["peg_pos"],
|
||||
task_state["peg_quat"],
|
||||
)
|
||||
|
||||
|
||||
def get_ring_bar_env_state(mj_data):
|
||||
ring_qpos = cp.deepcopy(np.asarray(mj_data.joint(RING_JOINT_NAME).qpos[:7], dtype=np.float64))
|
||||
bar_qpos = cp.deepcopy(np.asarray(mj_data.joint(BAR_JOINT_NAME).qpos[:7], dtype=np.float64))
|
||||
return np.concatenate([ring_qpos, bar_qpos], dtype=np.float64)
|
||||
def get_socket_peg_env_state(mj_data):
|
||||
socket_qpos = cp.deepcopy(np.asarray(mj_data.joint(SOCKET_JOINT_NAME).qpos[:7], dtype=np.float64))
|
||||
peg_qpos = cp.deepcopy(np.asarray(mj_data.joint(PEG_JOINT_NAME).qpos[:7], dtype=np.float64))
|
||||
return np.concatenate([socket_qpos, peg_qpos], dtype=np.float64)
|
||||
|
||||
|
||||
def _normalize_contact_pairs(contact_pairs):
|
||||
@@ -87,91 +83,29 @@ def _object_is_airborne(contact_set, object_geom_names):
|
||||
return not _has_any_object_contact(contact_set, object_geom_names, (TABLE_GEOM_NAME,))
|
||||
|
||||
|
||||
def _quat_to_rotation_matrix(quat):
|
||||
quat = np.asarray(quat, dtype=np.float64)
|
||||
quat /= np.linalg.norm(quat)
|
||||
w, x, y, z = quat
|
||||
return np.array(
|
||||
[
|
||||
[1.0 - 2.0 * (y * y + z * z), 2.0 * (x * y - z * w), 2.0 * (x * z + y * w)],
|
||||
[2.0 * (x * y + z * w), 1.0 - 2.0 * (x * x + z * z), 2.0 * (y * z - x * w)],
|
||||
[2.0 * (x * z - y * w), 2.0 * (y * z + x * w), 1.0 - 2.0 * (x * x + y * y)],
|
||||
],
|
||||
dtype=np.float64,
|
||||
)
|
||||
def peg_inserted_into_socket(contact_pairs):
|
||||
contact_set = _normalize_contact_pairs(contact_pairs)
|
||||
return frozenset((PEG_GEOM_NAMES[0], SOCKET_SUCCESS_GEOM_NAMES[0])) in contact_set
|
||||
|
||||
|
||||
def _quat_multiply(lhs, rhs):
|
||||
lhs = np.asarray(lhs, dtype=np.float64)
|
||||
rhs = np.asarray(rhs, dtype=np.float64)
|
||||
w1, x1, y1, z1 = lhs
|
||||
w2, x2, y2, z2 = rhs
|
||||
return np.array(
|
||||
[
|
||||
w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
|
||||
w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
|
||||
w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
|
||||
w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
|
||||
],
|
||||
dtype=np.float64,
|
||||
)
|
||||
|
||||
|
||||
def _quat_inverse(quat):
|
||||
quat = np.asarray(quat, dtype=np.float64)
|
||||
norm_sq = float(np.dot(quat, quat))
|
||||
return np.array([quat[0], -quat[1], -quat[2], -quat[3]], dtype=np.float64) / norm_sq
|
||||
|
||||
|
||||
def _split_env_state(env_state):
|
||||
env_state = np.asarray(env_state, dtype=np.float64)
|
||||
if env_state.shape != (14,):
|
||||
raise ValueError(f"env_state must have shape (14,), got {env_state.shape}")
|
||||
return (
|
||||
env_state[:3],
|
||||
env_state[3:7],
|
||||
env_state[7:10],
|
||||
env_state[10:14],
|
||||
)
|
||||
|
||||
|
||||
def bar_fully_inserted_through_ring(env_state):
|
||||
ring_pos, ring_quat, bar_pos, bar_quat = _split_env_state(env_state)
|
||||
ring_rot = _quat_to_rotation_matrix(ring_quat)
|
||||
bar_rot = _quat_to_rotation_matrix(bar_quat)
|
||||
|
||||
bar_center_in_ring = ring_rot.T @ (bar_pos - ring_pos)
|
||||
bar_rot_in_ring = ring_rot.T @ bar_rot
|
||||
projected_half_extents = np.abs(bar_rot_in_ring) @ BAR_HALF_SIZES
|
||||
|
||||
spans_ring_thickness = (
|
||||
bar_center_in_ring[2] - projected_half_extents[2] <= -RING_HALF_THICKNESS
|
||||
and bar_center_in_ring[2] + projected_half_extents[2] >= RING_HALF_THICKNESS
|
||||
)
|
||||
fits_aperture = (
|
||||
abs(bar_center_in_ring[0]) + projected_half_extents[0] <= RING_APERTURE_HALF_WIDTH
|
||||
and abs(bar_center_in_ring[1]) + projected_half_extents[1] <= RING_APERTURE_HALF_WIDTH
|
||||
)
|
||||
return bool(spans_ring_thickness and fits_aperture)
|
||||
|
||||
|
||||
def compute_air_insert_reward(contact_pairs, env_state):
|
||||
def compute_air_insert_reward(contact_pairs, env_state=None):
|
||||
del env_state # kept for API compatibility with rollout/eval code paths
|
||||
contact_set = _normalize_contact_pairs(contact_pairs)
|
||||
reward = 0
|
||||
|
||||
if _has_any_object_contact(contact_set, RING_GEOM_NAMES, LEFT_GRIPPER_GEOM_NAMES):
|
||||
if _has_any_object_contact(contact_set, SOCKET_GEOM_NAMES, LEFT_GRIPPER_GEOM_NAMES):
|
||||
reward += 1
|
||||
if _has_any_object_contact(contact_set, BAR_GEOM_NAMES, RIGHT_GRIPPER_GEOM_NAMES):
|
||||
if _has_any_object_contact(contact_set, PEG_GEOM_NAMES, RIGHT_GRIPPER_GEOM_NAMES):
|
||||
reward += 1
|
||||
|
||||
ring_airborne = _object_is_airborne(contact_set, RING_GEOM_NAMES)
|
||||
bar_airborne = _object_is_airborne(contact_set, BAR_GEOM_NAMES)
|
||||
if ring_airborne:
|
||||
socket_airborne = _object_is_airborne(contact_set, SOCKET_BODY_GEOM_NAMES)
|
||||
peg_airborne = _object_is_airborne(contact_set, PEG_GEOM_NAMES)
|
||||
if socket_airborne:
|
||||
reward += 1
|
||||
if bar_airborne:
|
||||
if peg_airborne:
|
||||
reward += 1
|
||||
|
||||
if ring_airborne and bar_airborne and bar_fully_inserted_through_ring(env_state):
|
||||
if socket_airborne and peg_airborne and peg_inserted_into_socket(contact_pairs):
|
||||
reward += 1
|
||||
|
||||
return reward
|
||||
@@ -181,33 +115,19 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.max_reward = 5
|
||||
self._scripted_ring_grasped = False
|
||||
self._scripted_bar_grasped = False
|
||||
self._scripted_ring_pos_offset_local = None
|
||||
self._scripted_bar_pos_offset_local = None
|
||||
self._scripted_ring_quat_offset = None
|
||||
self._scripted_bar_quat_offset = None
|
||||
self._air_insert_step_count = 0
|
||||
|
||||
def reset(self, task_state):
|
||||
self._scripted_ring_grasped = False
|
||||
self._scripted_bar_grasped = False
|
||||
self._scripted_ring_pos_offset_local = None
|
||||
self._scripted_bar_pos_offset_local = None
|
||||
self._scripted_ring_quat_offset = None
|
||||
self._scripted_bar_quat_offset = None
|
||||
self._air_insert_step_count = 0
|
||||
set_ring_bar_task_state(self.mj_data, task_state)
|
||||
set_socket_peg_task_state(self.mj_data, task_state)
|
||||
DualDianaMed.reset(self)
|
||||
self.top = None
|
||||
self.angle = None
|
||||
self.left_side = None
|
||||
self.r_vis = None
|
||||
self.front = None
|
||||
self.cam_flage = True
|
||||
while self.cam_flage:
|
||||
if (
|
||||
type(self.top) == type(None)
|
||||
or type(self.angle) == type(None)
|
||||
or type(self.left_side) == type(None)
|
||||
or type(self.r_vis) == type(None)
|
||||
or type(self.front) == type(None)
|
||||
):
|
||||
@@ -217,76 +137,11 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
|
||||
|
||||
def step(self, action=np.zeros(16)):
|
||||
super().step(action)
|
||||
self._update_scripted_grasped_objects(action)
|
||||
self.rew = self._get_reward()
|
||||
self.obs = self._get_obs()
|
||||
self._air_insert_step_count += 1
|
||||
|
||||
def _update_scripted_grasped_objects(self, action):
|
||||
if (
|
||||
action[-2] < SCRIPTED_GRASP_CLOSE_THRESHOLD
|
||||
and self._air_insert_step_count >= 180
|
||||
and not self._scripted_ring_grasped
|
||||
):
|
||||
self._scripted_ring_grasped = True
|
||||
self._attach_scripted_object(
|
||||
object_joint_name=RING_JOINT_NAME,
|
||||
ee_pos=action[:3],
|
||||
ee_quat=action[3:7],
|
||||
pos_attr="_scripted_ring_pos_offset_local",
|
||||
quat_attr="_scripted_ring_quat_offset",
|
||||
)
|
||||
if (
|
||||
action[-1] < SCRIPTED_GRASP_CLOSE_THRESHOLD
|
||||
and self._air_insert_step_count >= 180
|
||||
and not self._scripted_bar_grasped
|
||||
):
|
||||
self._scripted_bar_grasped = True
|
||||
self._attach_scripted_object(
|
||||
object_joint_name=BAR_JOINT_NAME,
|
||||
ee_pos=action[7:10],
|
||||
ee_quat=action[10:14],
|
||||
pos_attr="_scripted_bar_pos_offset_local",
|
||||
quat_attr="_scripted_bar_quat_offset",
|
||||
)
|
||||
|
||||
if self._scripted_ring_grasped:
|
||||
self._update_scripted_object_pose(
|
||||
object_joint_name=RING_JOINT_NAME,
|
||||
ee_pos=action[:3],
|
||||
ee_quat=action[3:7],
|
||||
pos_offset_local=self._scripted_ring_pos_offset_local,
|
||||
quat_offset=self._scripted_ring_quat_offset,
|
||||
)
|
||||
if self._scripted_bar_grasped:
|
||||
self._update_scripted_object_pose(
|
||||
object_joint_name=BAR_JOINT_NAME,
|
||||
ee_pos=action[7:10],
|
||||
ee_quat=action[10:14],
|
||||
pos_offset_local=self._scripted_bar_pos_offset_local,
|
||||
quat_offset=self._scripted_bar_quat_offset,
|
||||
)
|
||||
if self._scripted_ring_grasped or self._scripted_bar_grasped:
|
||||
mj.mj_forward(self.mj_model, self.mj_data)
|
||||
|
||||
def _attach_scripted_object(self, object_joint_name, ee_pos, ee_quat, pos_attr, quat_attr):
|
||||
ee_pos = np.asarray(ee_pos, dtype=np.float64)
|
||||
ee_quat = np.asarray(ee_quat, dtype=np.float64)
|
||||
object_qpos = np.asarray(self.mj_data.joint(object_joint_name).qpos[:7], dtype=np.float64)
|
||||
ee_rot = _quat_to_rotation_matrix(ee_quat)
|
||||
setattr(self, pos_attr, ee_rot.T @ (object_qpos[:3] - ee_pos))
|
||||
setattr(self, quat_attr, _quat_multiply(_quat_inverse(ee_quat), object_qpos[3:7]))
|
||||
|
||||
def _update_scripted_object_pose(self, object_joint_name, ee_pos, ee_quat, pos_offset_local, quat_offset):
|
||||
ee_pos = np.asarray(ee_pos, dtype=np.float64)
|
||||
ee_quat = np.asarray(ee_quat, dtype=np.float64)
|
||||
ee_rot = _quat_to_rotation_matrix(ee_quat)
|
||||
object_pos = ee_pos + ee_rot @ np.asarray(pos_offset_local, dtype=np.float64)
|
||||
object_quat = _quat_multiply(ee_quat, quat_offset)
|
||||
_set_free_joint_pose(self.mj_data.joint(object_joint_name), object_pos, object_quat)
|
||||
|
||||
def get_env_state(self):
|
||||
return get_ring_bar_env_state(self.mj_data)
|
||||
return get_socket_peg_env_state(self.mj_data)
|
||||
|
||||
def _get_reward(self):
|
||||
contact_pairs = []
|
||||
@@ -296,8 +151,4 @@ class DualDianaMed_Air_Insert(DualDianaMed_Pos_Ctrl):
|
||||
contact_pairs.append(
|
||||
(self.getID2Name("geom", geom1), self.getID2Name("geom", geom2))
|
||||
)
|
||||
if self._scripted_ring_grasped:
|
||||
contact_pairs.append(("ring_block_south", "l_fingertip_g0_left"))
|
||||
if self._scripted_bar_grasped:
|
||||
contact_pairs.append(("bar_block", "r_fingertip_g0_right"))
|
||||
return compute_air_insert_reward(contact_pairs, self.get_env_state())
|
||||
|
||||
@@ -52,7 +52,7 @@ class DualDianaMed(MujocoEnv):
|
||||
self.r_vis = None
|
||||
self.l_vis = None
|
||||
self.top = None
|
||||
self.angle = None
|
||||
self.left_side = None
|
||||
self.front = None
|
||||
self.obs = None
|
||||
|
||||
@@ -166,7 +166,7 @@ class DualDianaMed(MujocoEnv):
|
||||
obs['action'] = self.compute_qpos
|
||||
obs['images'] = dict()
|
||||
obs['images']['top'] = self.top
|
||||
obs['images']['angle'] = self.angle
|
||||
obs['images']['left_side'] = self.left_side
|
||||
obs['images']['r_vis'] = self.r_vis
|
||||
obs['images']['l_vis'] = self.l_vis
|
||||
obs['images']['front'] = self.front
|
||||
@@ -176,7 +176,7 @@ class DualDianaMed(MujocoEnv):
|
||||
obs = collections.OrderedDict()
|
||||
obs['images'] = dict()
|
||||
obs['images']['top'] = self.top
|
||||
obs['images']['angle'] = self.angle
|
||||
obs['images']['left_side'] = self.left_side
|
||||
obs['images']['r_vis'] = self.r_vis
|
||||
obs['images']['l_vis'] = self.l_vis
|
||||
obs['images']['front'] = self.front
|
||||
@@ -199,8 +199,8 @@ class DualDianaMed(MujocoEnv):
|
||||
def cam_view(self):
|
||||
if self.cam == 'top':
|
||||
return self.top
|
||||
elif self.cam == 'angle':
|
||||
return self.angle
|
||||
elif self.cam == 'left_side':
|
||||
return self.left_side
|
||||
elif self.cam == 'r_vis':
|
||||
return self.r_vis
|
||||
elif self.cam == 'l_vis':
|
||||
@@ -226,9 +226,9 @@ class DualDianaMed(MujocoEnv):
|
||||
img_renderer.update_scene(self.mj_data,camera="top")
|
||||
self.top = img_renderer.render()
|
||||
self.top = self.top[:, :, ::-1]
|
||||
img_renderer.update_scene(self.mj_data,camera="angle")
|
||||
self.angle = img_renderer.render()
|
||||
self.angle = self.angle[:, :, ::-1]
|
||||
img_renderer.update_scene(self.mj_data,camera="left_side")
|
||||
self.left_side = img_renderer.render()
|
||||
self.left_side = self.left_side[:, :, ::-1]
|
||||
img_renderer.update_scene(self.mj_data,camera="front")
|
||||
self.front = img_renderer.render()
|
||||
self.front = self.front[:, :, ::-1]
|
||||
|
||||
@@ -73,14 +73,14 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
|
||||
self.mj_data.joint('red_box_joint').qpos[6] = 0.0
|
||||
super().reset()
|
||||
self.top = None
|
||||
self.angle = None
|
||||
self.left_side = None
|
||||
self.r_vis = None
|
||||
self.front = None
|
||||
self.cam_flage = True
|
||||
t=0
|
||||
while self.cam_flage:
|
||||
if(type(self.top)==type(None)
|
||||
or type(self.angle)==type(None)
|
||||
or type(self.left_side)==type(None)
|
||||
or type(self.r_vis)==type(None)
|
||||
or type(self.front)==type(None)):
|
||||
time.sleep(0.001)
|
||||
@@ -134,16 +134,16 @@ class DualDianaMed_Pos_Ctrl(DualDianaMed):
|
||||
|
||||
|
||||
def make_sim_env(task_name, headless=False):
|
||||
if task_name == 'sim_air_insert_ring_bar':
|
||||
from roboimi.assets.robots.diana_med import BiDianaMedRingBar
|
||||
if task_name == 'sim_air_insert_socket_peg':
|
||||
from roboimi.assets.robots.diana_med import BiDianaMedSocketPeg
|
||||
from roboimi.envs.double_air_insert_env import DualDianaMed_Air_Insert
|
||||
|
||||
env = DualDianaMed_Air_Insert(
|
||||
robot=BiDianaMedRingBar(),
|
||||
robot=BiDianaMedSocketPeg(),
|
||||
is_render=not headless,
|
||||
control_freq=30,
|
||||
is_interpolate=True,
|
||||
cam_view='angle'
|
||||
cam_view='left_side'
|
||||
)
|
||||
return env
|
||||
if 'sim_transfer' in task_name:
|
||||
@@ -153,7 +153,7 @@ def make_sim_env(task_name, headless=False):
|
||||
is_render=not headless,
|
||||
control_freq=30,
|
||||
is_interpolate=True,
|
||||
cam_view='angle'
|
||||
cam_view='left_side'
|
||||
)
|
||||
return env
|
||||
else:
|
||||
|
||||
@@ -39,19 +39,20 @@ def sample_transfer_pose():
|
||||
return box_position
|
||||
|
||||
|
||||
def sample_air_insert_ring_bar_state():
|
||||
ring_position = np.random.uniform(
|
||||
low=np.array([-0.20, 0.70, 0.47], dtype=np.float32),
|
||||
high=np.array([-0.05, 1.00, 0.47], dtype=np.float32),
|
||||
def sample_air_insert_socket_peg_state():
|
||||
socket_position = np.random.uniform(
|
||||
low=np.array([-0.14, 0.89, 0.472], dtype=np.float32),
|
||||
high=np.array([-0.10, 0.94, 0.472], dtype=np.float32),
|
||||
)
|
||||
bar_position = np.random.uniform(
|
||||
low=np.array([0.05, 0.70, 0.47], dtype=np.float32),
|
||||
high=np.array([0.20, 1.00, 0.47], dtype=np.float32),
|
||||
peg_position = np.random.uniform(
|
||||
low=np.array([0.10, 0.85, 0.46], dtype=np.float32),
|
||||
high=np.array([0.16, 0.94, 0.46], dtype=np.float32),
|
||||
)
|
||||
fixed_quat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
|
||||
socket_quat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
|
||||
peg_quat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
|
||||
return {
|
||||
"ring_pos": ring_position.astype(np.float32, copy=False),
|
||||
"ring_quat": fixed_quat.copy(),
|
||||
"bar_pos": bar_position.astype(np.float32, copy=False),
|
||||
"bar_quat": fixed_quat.copy(),
|
||||
"socket_pos": socket_position.astype(np.float32, copy=False),
|
||||
"socket_quat": socket_quat,
|
||||
"peg_pos": peg_position.astype(np.float32, copy=False),
|
||||
"peg_quat": peg_quat,
|
||||
}
|
||||
|
||||
@@ -23,10 +23,10 @@ SIM_TASK_CONFIGS = {
|
||||
'camera_names': ['top','r_vis','front'],
|
||||
'xml_dir': HOME_PATH + '/assets'
|
||||
},
|
||||
'sim_air_insert_ring_bar': {
|
||||
'dataset_dir': DATASET_DIR + '/sim_air_insert_ring_bar',
|
||||
'sim_air_insert_socket_peg': {
|
||||
'dataset_dir': DATASET_DIR + '/sim_air_insert_socket_peg',
|
||||
'num_episodes': 20,
|
||||
'episode_len': 700,
|
||||
'episode_len': 1000,
|
||||
'camera_names': ['top', 'r_vis', 'front'],
|
||||
'xml_dir': HOME_PATH + '/assets'
|
||||
},
|
||||
@@ -59,13 +59,3 @@ PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) /
|
||||
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
||||
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
||||
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
|
||||
|
||||
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
||||
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
||||
|
||||
MASTER_POS2JOINT = lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
||||
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN((x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE))
|
||||
PUPPET_POS2JOINT = lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
||||
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN((x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE))
|
||||
|
||||
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE)/2
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import pathlib
|
||||
import unittest
|
||||
from unittest import mock
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -9,83 +12,80 @@ from roboimi.utils import act_ex_utils
|
||||
from roboimi.utils.constants import SIM_TASK_CONFIGS
|
||||
|
||||
|
||||
class AirInsertTaskRegistrationTest(unittest.TestCase):
|
||||
def test_sim_task_configs_registers_air_insert_ring_bar(self):
|
||||
self.assertIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS)
|
||||
TASK_NAME = "sim_air_insert_socket_peg"
|
||||
|
||||
def test_sample_air_insert_ring_bar_state_returns_explicit_named_mapping(self):
|
||||
sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None)
|
||||
|
||||
class AirInsertTaskRegistrationTest(unittest.TestCase):
|
||||
def test_sim_task_configs_registers_air_insert_socket_peg(self):
|
||||
self.assertIn(TASK_NAME, SIM_TASK_CONFIGS)
|
||||
self.assertNotIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS)
|
||||
self.assertGreaterEqual(SIM_TASK_CONFIGS[TASK_NAME]["episode_len"], 1000)
|
||||
self.assertTrue(SIM_TASK_CONFIGS[TASK_NAME]["dataset_dir"].endswith("/sim_air_insert_socket_peg"))
|
||||
|
||||
def test_sample_air_insert_socket_peg_state_returns_explicit_named_mapping(self):
|
||||
sampler = getattr(act_ex_utils, "sample_air_insert_socket_peg_state", None)
|
||||
self.assertIsNotNone(
|
||||
sampler,
|
||||
"Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()",
|
||||
"Expected roboimi.utils.act_ex_utils.sample_air_insert_socket_peg_state()",
|
||||
)
|
||||
self.assertFalse(
|
||||
hasattr(act_ex_utils, "sample_air_insert_ring_bar_state"),
|
||||
"air insert sampler should use socket/peg naming after the task rename",
|
||||
)
|
||||
|
||||
task_state = sampler()
|
||||
|
||||
self.assertEqual(
|
||||
list(task_state.keys()),
|
||||
["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
|
||||
["socket_pos", "socket_quat", "peg_pos", "peg_quat"],
|
||||
)
|
||||
self.assertEqual(task_state["ring_pos"].shape, (3,))
|
||||
self.assertEqual(task_state["ring_quat"].shape, (4,))
|
||||
self.assertEqual(task_state["bar_pos"].shape, (3,))
|
||||
self.assertEqual(task_state["bar_quat"].shape, (4,))
|
||||
self.assertEqual(task_state["socket_pos"].shape, (3,))
|
||||
self.assertEqual(task_state["socket_quat"].shape, (4,))
|
||||
self.assertEqual(task_state["peg_pos"].shape, (3,))
|
||||
self.assertEqual(task_state["peg_quat"].shape, (4,))
|
||||
|
||||
def test_sample_air_insert_ring_bar_state_uses_fixed_quats_and_left_right_planar_ranges(self):
|
||||
sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None)
|
||||
self.assertIsNotNone(
|
||||
sampler,
|
||||
"Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()",
|
||||
)
|
||||
def test_sample_air_insert_socket_peg_state_uses_fixed_quats_and_left_right_planar_ranges(self):
|
||||
sampler = getattr(act_ex_utils, "sample_air_insert_socket_peg_state", None)
|
||||
self.assertIsNotNone(sampler)
|
||||
|
||||
task_state = sampler()
|
||||
|
||||
np.testing.assert_array_equal(task_state["ring_quat"], np.array([1.0, 0.0, 0.0, 0.0]))
|
||||
np.testing.assert_array_equal(task_state["bar_quat"], np.array([1.0, 0.0, 0.0, 0.0]))
|
||||
self.assertGreaterEqual(task_state["ring_pos"][0], -0.20)
|
||||
self.assertLessEqual(task_state["ring_pos"][0], -0.05)
|
||||
self.assertGreaterEqual(task_state["ring_pos"][1], 0.70)
|
||||
self.assertLessEqual(task_state["ring_pos"][1], 1.00)
|
||||
self.assertAlmostEqual(float(task_state["ring_pos"][2]), 0.47)
|
||||
self.assertGreaterEqual(task_state["bar_pos"][0], 0.05)
|
||||
self.assertLessEqual(task_state["bar_pos"][0], 0.20)
|
||||
self.assertGreaterEqual(task_state["bar_pos"][1], 0.70)
|
||||
self.assertLessEqual(task_state["bar_pos"][1], 1.00)
|
||||
self.assertAlmostEqual(float(task_state["bar_pos"][2]), 0.47)
|
||||
|
||||
def test_make_sim_env_dispatches_air_insert_ring_bar_headless(self):
|
||||
try:
|
||||
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||
except Exception as exc:
|
||||
self.fail(f"Expected roboimi.envs.double_air_insert_env to be importable: {exc}")
|
||||
np.testing.assert_array_equal(task_state["socket_quat"], np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32))
|
||||
np.testing.assert_array_equal(task_state["peg_quat"], np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32))
|
||||
self.assertGreaterEqual(task_state["socket_pos"][0], -0.20)
|
||||
self.assertLessEqual(task_state["socket_pos"][0], -0.05)
|
||||
self.assertGreaterEqual(task_state["socket_pos"][1], 0.70)
|
||||
self.assertLessEqual(task_state["socket_pos"][1], 1.00)
|
||||
self.assertAlmostEqual(float(task_state["socket_pos"][2]), 0.472)
|
||||
self.assertGreaterEqual(task_state["peg_pos"][0], 0.05)
|
||||
self.assertLessEqual(task_state["peg_pos"][0], 0.20)
|
||||
self.assertGreaterEqual(task_state["peg_pos"][1], 0.70)
|
||||
self.assertLessEqual(task_state["peg_pos"][1], 1.00)
|
||||
self.assertAlmostEqual(float(task_state["peg_pos"][2]), 0.46)
|
||||
|
||||
def test_make_sim_env_dispatches_air_insert_socket_peg_headless(self):
|
||||
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||
air_insert_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None)
|
||||
self.assertIsNotNone(
|
||||
air_insert_cls,
|
||||
"Expected roboimi.envs.double_air_insert_env.DualDianaMed_Air_Insert",
|
||||
)
|
||||
self.assertIsNotNone(air_insert_cls)
|
||||
|
||||
diana_med = importlib.import_module("roboimi.assets.robots.diana_med")
|
||||
ring_bar_robot_cls = getattr(diana_med, "BiDianaMedRingBar", None)
|
||||
socket_peg_robot_cls = getattr(diana_med, "BiDianaMedSocketPeg", None)
|
||||
self.assertIsNotNone(
|
||||
ring_bar_robot_cls,
|
||||
"Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar",
|
||||
socket_peg_robot_cls,
|
||||
"Expected roboimi.assets.robots.diana_med.BiDianaMedSocketPeg",
|
||||
)
|
||||
|
||||
fake_env = object()
|
||||
with mock.patch.object(
|
||||
diana_med,
|
||||
"BiDianaMedRingBar",
|
||||
"BiDianaMedSocketPeg",
|
||||
return_value="robot",
|
||||
), mock.patch.object(
|
||||
air_insert_env,
|
||||
"DualDianaMed_Air_Insert",
|
||||
return_value=fake_env,
|
||||
) as env_cls:
|
||||
try:
|
||||
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
|
||||
except Exception as exc:
|
||||
self.fail(f"make_sim_env should dispatch sim_air_insert_ring_bar without error: {exc}")
|
||||
env = make_sim_env(TASK_NAME, headless=True)
|
||||
|
||||
self.assertIs(env, fake_env)
|
||||
env_cls.assert_called_once_with(
|
||||
@@ -93,21 +93,36 @@ class AirInsertTaskRegistrationTest(unittest.TestCase):
|
||||
is_render=False,
|
||||
control_freq=30,
|
||||
is_interpolate=True,
|
||||
cam_view="angle",
|
||||
cam_view="left_side",
|
||||
)
|
||||
|
||||
def test_diana_table_scene_uses_left_side_camera_instead_of_angle(self):
|
||||
xml_path = (
|
||||
pathlib.Path(__file__).resolve().parents[1]
|
||||
/ "roboimi/assets/models/manipulators/DianaMed/table_square.xml"
|
||||
)
|
||||
root = ET.parse(xml_path).getroot()
|
||||
cameras = {camera.attrib["name"]: camera.attrib for camera in root.findall(".//camera")}
|
||||
|
||||
self.assertNotIn("angle", cameras, "DianaMed scene should stop exposing the old angle camera")
|
||||
self.assertIn("left_side", cameras, "DianaMed scene should expose the left-side task camera")
|
||||
left_side_pos = np.fromstring(cameras["left_side"]["pos"], sep=" ")
|
||||
self.assertLess(float(left_side_pos[0]), 0.0)
|
||||
self.assertEqual(cameras["left_side"].get("mode"), "targetbody")
|
||||
self.assertEqual(cameras["left_side"].get("target"), "table")
|
||||
|
||||
|
||||
class AirInsertResetAndStateHelpersTest(unittest.TestCase):
|
||||
def test_set_ring_bar_task_state_writes_free_joint_qpos(self):
|
||||
def test_set_socket_peg_task_state_writes_free_joint_qpos(self):
|
||||
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||
setter = getattr(air_insert_env, "set_ring_bar_task_state", None)
|
||||
setter = getattr(air_insert_env, "set_socket_peg_task_state", None)
|
||||
self.assertIsNotNone(
|
||||
setter,
|
||||
"Expected roboimi.envs.double_air_insert_env.set_ring_bar_task_state",
|
||||
"Expected roboimi.envs.double_air_insert_env.set_socket_peg_task_state",
|
||||
)
|
||||
|
||||
ring_qpos = np.zeros(7, dtype=np.float64)
|
||||
bar_qpos = np.zeros(7, dtype=np.float64)
|
||||
socket_qpos = np.zeros(7, dtype=np.float64)
|
||||
peg_qpos = np.zeros(7, dtype=np.float64)
|
||||
|
||||
class _FakeJoint:
|
||||
def __init__(self, qpos):
|
||||
@@ -115,40 +130,40 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase):
|
||||
|
||||
class _FakeData:
|
||||
def joint(self, name):
|
||||
if name == "ring_block_joint":
|
||||
return _FakeJoint(ring_qpos)
|
||||
if name == "bar_block_joint":
|
||||
return _FakeJoint(bar_qpos)
|
||||
if name == "blue_socket_joint":
|
||||
return _FakeJoint(socket_qpos)
|
||||
if name == "red_peg_joint":
|
||||
return _FakeJoint(peg_qpos)
|
||||
raise AssertionError(f"Unexpected joint name: {name}")
|
||||
|
||||
task_state = {
|
||||
"ring_pos": np.array([-0.12, 0.90, 0.47], dtype=np.float64),
|
||||
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
"bar_pos": np.array([0.12, 0.91, 0.47], dtype=np.float64),
|
||||
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
"socket_pos": np.array([-0.12, 0.90, 0.472], dtype=np.float64),
|
||||
"socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
"peg_pos": np.array([0.12, 0.91, 0.46], dtype=np.float64),
|
||||
"peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
}
|
||||
|
||||
setter(_FakeData(), task_state)
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
ring_qpos,
|
||||
np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
socket_qpos,
|
||||
np.array([-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
)
|
||||
np.testing.assert_array_equal(
|
||||
bar_qpos,
|
||||
np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
peg_qpos,
|
||||
np.array([0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
)
|
||||
|
||||
def test_get_ring_bar_env_state_returns_stable_14d_vector(self):
|
||||
def test_get_socket_peg_env_state_returns_stable_14d_vector(self):
|
||||
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||
getter = getattr(air_insert_env, "get_ring_bar_env_state", None)
|
||||
getter = getattr(air_insert_env, "get_socket_peg_env_state", None)
|
||||
self.assertIsNotNone(
|
||||
getter,
|
||||
"Expected roboimi.envs.double_air_insert_env.get_ring_bar_env_state",
|
||||
"Expected roboimi.envs.double_air_insert_env.get_socket_peg_env_state",
|
||||
)
|
||||
|
||||
ring_qpos = np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
|
||||
bar_qpos = np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
|
||||
socket_qpos = np.array([-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
|
||||
peg_qpos = np.array([0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
|
||||
|
||||
class _FakeJoint:
|
||||
def __init__(self, qpos):
|
||||
@@ -156,10 +171,10 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase):
|
||||
|
||||
class _FakeData:
|
||||
def joint(self, name):
|
||||
if name == "ring_block_joint":
|
||||
return _FakeJoint(ring_qpos)
|
||||
if name == "bar_block_joint":
|
||||
return _FakeJoint(bar_qpos)
|
||||
if name == "blue_socket_joint":
|
||||
return _FakeJoint(socket_qpos)
|
||||
if name == "red_peg_joint":
|
||||
return _FakeJoint(peg_qpos)
|
||||
raise AssertionError(f"Unexpected joint name: {name}")
|
||||
|
||||
env_state = getter(_FakeData())
|
||||
@@ -168,38 +183,78 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase):
|
||||
np.testing.assert_array_equal(
|
||||
env_state,
|
||||
np.array(
|
||||
[-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0],
|
||||
[-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0],
|
||||
dtype=np.float64,
|
||||
),
|
||||
)
|
||||
|
||||
def test_air_insert_env_does_not_script_attach_or_assist_objects_after_reset(self):
|
||||
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||
env_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None)
|
||||
self.assertIsNotNone(env_cls)
|
||||
|
||||
source = inspect.getsource(env_cls)
|
||||
|
||||
self.assertNotIn("_update_scripted_grasped_objects", source)
|
||||
self.assertNotIn("_scripted_", source)
|
||||
self.assertNotIn("_stabilize_ring_grasp", source)
|
||||
self.assertNotIn("_ring_grasp_locked", source)
|
||||
get_reward_source = inspect.getsource(env_cls._get_reward)
|
||||
self.assertNotIn("ring_block", get_reward_source)
|
||||
self.assertNotIn("bar_block", get_reward_source)
|
||||
|
||||
def test_socket_peg_xml_defines_active_socket_and_peg_objects(self):
|
||||
asset_dir = pathlib.Path(__file__).resolve().parents[1] / "roboimi/assets/models/manipulators/DianaMed"
|
||||
xml_path = asset_dir / "socket_peg_objects.xml"
|
||||
self.assertTrue(xml_path.exists(), "socket/peg objects should live in socket_peg_objects.xml")
|
||||
self.assertFalse((asset_dir / "ring_bar_objects.xml").exists(), "old ring_bar_objects.xml should be renamed")
|
||||
|
||||
root = ET.parse(xml_path).getroot()
|
||||
body_names = {body.attrib.get("name") for body in root.findall(".//body")}
|
||||
geom_names = {geom.attrib.get("name") for geom in root.findall(".//geom")}
|
||||
joint_names = {joint.attrib.get("name") for joint in root.findall(".//joint")}
|
||||
|
||||
self.assertIn("socket", body_names)
|
||||
self.assertIn("peg", body_names)
|
||||
self.assertNotIn("ring_block", body_names)
|
||||
self.assertNotIn("bar_block", body_names)
|
||||
self.assertIn("blue_socket_joint", joint_names)
|
||||
self.assertIn("red_peg_joint", joint_names)
|
||||
for geom_name in ("socket-1", "socket-2", "socket-3", "socket-4", "pin", "red_peg"):
|
||||
self.assertIn(geom_name, geom_names)
|
||||
|
||||
def test_socket_peg_wrapper_includes_socket_peg_objects(self):
|
||||
xml_path = (
|
||||
pathlib.Path(__file__).resolve().parents[1]
|
||||
/ "roboimi/assets/models/manipulators/DianaMed/bi_diana_socket_peg_ee.xml"
|
||||
)
|
||||
self.assertTrue(xml_path.exists(), "socket/peg wrapper XML should use the new task name")
|
||||
root = ET.parse(xml_path).getroot()
|
||||
includes = [include.attrib.get("file") for include in root.findall(".//include")]
|
||||
self.assertIn("./socket_peg_objects.xml", includes)
|
||||
self.assertNotIn("./ring_bar_objects.xml", includes)
|
||||
|
||||
|
||||
class AirInsertRewardAndSuccessTest(unittest.TestCase):
|
||||
@staticmethod
|
||||
def _make_env_state(
|
||||
ring_pos=(0.0, 0.0, 0.50),
|
||||
ring_quat=(1.0, 0.0, 0.0, 0.0),
|
||||
bar_pos=(0.0, 0.0, 0.50),
|
||||
bar_quat=(0.70710678, 0.0, 0.70710678, 0.0),
|
||||
socket_pos=(0.0, 0.0, 0.472),
|
||||
socket_quat=(1.0, 0.0, 0.0, 0.0),
|
||||
peg_pos=(0.0, 0.0, 0.46),
|
||||
peg_quat=(1.0, 0.0, 0.0, 0.0),
|
||||
):
|
||||
return np.array(
|
||||
[*ring_pos, *ring_quat, *bar_pos, *bar_quat],
|
||||
dtype=np.float64,
|
||||
)
|
||||
return np.array([*socket_pos, *socket_quat, *peg_pos, *peg_quat], dtype=np.float64)
|
||||
|
||||
def test_compute_air_insert_reward_counts_left_contact_stage(self):
|
||||
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
|
||||
self.assertIsNotNone(
|
||||
reward_fn,
|
||||
"Expected roboimi.envs.double_air_insert_env.compute_air_insert_reward",
|
||||
)
|
||||
self.assertIsNotNone(reward_fn)
|
||||
|
||||
reward = reward_fn(
|
||||
contact_pairs=[
|
||||
("ring_block_north", "l_finger_left"),
|
||||
("ring_block_north", "table"),
|
||||
("bar_block", "table"),
|
||||
("socket-1", "l_finger_left"),
|
||||
("socket-1", "table"),
|
||||
("red_peg", "table"),
|
||||
],
|
||||
env_state=self._make_env_state(),
|
||||
)
|
||||
@@ -212,10 +267,10 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
|
||||
|
||||
reward = reward_fn(
|
||||
contact_pairs=[
|
||||
("ring_block_north", "l_finger_left"),
|
||||
("bar_block", "l_finger_right"),
|
||||
("ring_block_north", "table"),
|
||||
("bar_block", "table"),
|
||||
("socket-1", "l_finger_left"),
|
||||
("red_peg", "l_finger_right"),
|
||||
("socket-1", "table"),
|
||||
("red_peg", "table"),
|
||||
],
|
||||
env_state=self._make_env_state(),
|
||||
)
|
||||
@@ -228,47 +283,43 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
|
||||
|
||||
reward = reward_fn(
|
||||
contact_pairs=[
|
||||
("ring_block_north", "l_finger_left"),
|
||||
("bar_block", "l_finger_right"),
|
||||
("socket-1", "l_finger_left"),
|
||||
("red_peg", "l_finger_right"),
|
||||
],
|
||||
env_state=self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)),
|
||||
env_state=self._make_env_state(),
|
||||
)
|
||||
|
||||
self.assertEqual(reward, 4)
|
||||
|
||||
def test_bar_fully_inserted_through_ring_accepts_true_positive(self):
|
||||
def test_compute_air_insert_reward_counts_visual_fingertip_contacts_as_gripper_contacts(self):
|
||||
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
|
||||
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
|
||||
|
||||
reward = reward_fn(
|
||||
contact_pairs=[
|
||||
("socket-3", "r_fingertip_g0_vis_left"),
|
||||
("red_peg", "l_fingertip_g0_vis_right"),
|
||||
],
|
||||
env_state=self._make_env_state(),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
reward,
|
||||
4,
|
||||
"visual fingertip geoms are collidable in the Diana XML and should count as gripper-object contacts",
|
||||
)
|
||||
|
||||
def test_peg_inserted_into_socket_uses_pin_contact(self):
|
||||
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||
success_fn = getattr(air_insert_env, "peg_inserted_into_socket", None)
|
||||
self.assertIsNotNone(
|
||||
success_fn,
|
||||
"Expected roboimi.envs.double_air_insert_env.bar_fully_inserted_through_ring",
|
||||
"Expected roboimi.envs.double_air_insert_env.peg_inserted_into_socket",
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
success_fn(
|
||||
self._make_env_state(),
|
||||
)
|
||||
)
|
||||
|
||||
def test_bar_fully_inserted_through_ring_rejects_centerline_only_false_positive(self):
|
||||
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
|
||||
|
||||
self.assertFalse(
|
||||
success_fn(
|
||||
self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)),
|
||||
)
|
||||
)
|
||||
|
||||
def test_bar_fully_inserted_through_ring_rejects_insufficient_depth(self):
|
||||
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
|
||||
|
||||
self.assertFalse(
|
||||
success_fn(
|
||||
self._make_env_state(bar_pos=(0.0, 0.0, 0.56)),
|
||||
)
|
||||
)
|
||||
self.assertTrue(success_fn([("red_peg", "pin")]))
|
||||
self.assertTrue(success_fn([("pin", "red_peg")]))
|
||||
self.assertFalse(success_fn([("red_peg", "socket-1")]))
|
||||
|
||||
def test_compute_air_insert_reward_requires_airborne_success_for_final_point(self):
|
||||
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||
@@ -276,9 +327,10 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
|
||||
|
||||
reward = reward_fn(
|
||||
contact_pairs=[
|
||||
("ring_block_north", "l_finger_left"),
|
||||
("bar_block", "l_finger_right"),
|
||||
("ring_block_north", "table"),
|
||||
("socket-1", "l_finger_left"),
|
||||
("red_peg", "l_finger_right"),
|
||||
("socket-1", "table"),
|
||||
("red_peg", "pin"),
|
||||
],
|
||||
env_state=self._make_env_state(),
|
||||
)
|
||||
@@ -291,8 +343,9 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
|
||||
|
||||
reward = reward_fn(
|
||||
contact_pairs=[
|
||||
("ring_block_north", "l_finger_left"),
|
||||
("bar_block", "l_finger_right"),
|
||||
("socket-1", "l_finger_left"),
|
||||
("red_peg", "l_finger_right"),
|
||||
("red_peg", "pin"),
|
||||
],
|
||||
env_state=self._make_env_state(),
|
||||
)
|
||||
@@ -301,41 +354,129 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
|
||||
|
||||
|
||||
class AirInsertPolicyAndSmokeTest(unittest.TestCase):
|
||||
@staticmethod
|
||||
def _canonical_task_state():
|
||||
return {
|
||||
"socket_pos": np.array([-0.12, 0.90, 0.472], dtype=np.float32),
|
||||
"socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
"peg_pos": np.array([0.12, 0.90, 0.46], dtype=np.float32),
|
||||
"peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
}
|
||||
|
||||
def test_air_insert_policy_emits_valid_16d_action(self):
|
||||
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
||||
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
||||
self.assertIsNotNone(
|
||||
policy_cls,
|
||||
"Expected roboimi.demos.diana_air_insert_policy.TestAirInsertPolicy",
|
||||
)
|
||||
self.assertIsNotNone(policy_cls)
|
||||
|
||||
task_state = act_ex_utils.sample_air_insert_ring_bar_state()
|
||||
task_state = act_ex_utils.sample_air_insert_socket_peg_state()
|
||||
policy = policy_cls(inject_noise=False)
|
||||
action = policy.predict(task_state, 0)
|
||||
|
||||
self.assertEqual(action.shape, (16,))
|
||||
np.testing.assert_array_equal(action[-2:], np.array([100, 100]))
|
||||
|
||||
def test_scripted_rollout_entrypoint_selects_ring_bar_sampler_and_policy(self):
|
||||
def test_air_insert_policy_inserts_peg_front_view_right_to_left_along_world_x(self):
|
||||
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
||||
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
||||
self.assertIsNotNone(policy_cls)
|
||||
|
||||
task_state = self._canonical_task_state()
|
||||
policy = policy_cls(inject_noise=False)
|
||||
policy.generate_trajectory(task_state)
|
||||
|
||||
start_waypoint = next(wp for wp in policy.right_trajectory if wp["t"] == policy.INSERT_START_T)
|
||||
end_waypoint = next(wp for wp in policy.right_trajectory if wp["t"] == policy.INSERT_END_T)
|
||||
|
||||
self.assertLess(
|
||||
end_waypoint["xyz"][0],
|
||||
start_waypoint["xyz"][0] - 0.10,
|
||||
"front-view right-to-left peg insertion should decrease world x substantially",
|
||||
)
|
||||
self.assertAlmostEqual(float(end_waypoint["xyz"][1]), float(start_waypoint["xyz"][1]), delta=0.02)
|
||||
expected_insert_end_x = float(task_state["socket_pos"][0] + 0.168)
|
||||
self.assertAlmostEqual(float(end_waypoint["xyz"][0]), expected_insert_end_x, delta=0.02)
|
||||
self.assertGreater(float(start_waypoint["xyz"][2]), 0.70)
|
||||
|
||||
def test_air_insert_policy_default_left_grasps_socket_and_right_grasps_peg(self):
|
||||
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
||||
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
||||
self.assertIsNotNone(policy_cls)
|
||||
|
||||
task_state = {
|
||||
"socket_pos": np.array([-0.18, 0.78, 0.472], dtype=np.float32),
|
||||
"socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
"peg_pos": np.array([0.16, 0.98, 0.46], dtype=np.float32),
|
||||
"peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
}
|
||||
|
||||
policy = policy_cls(inject_noise=False)
|
||||
policy.generate_trajectory(task_state)
|
||||
left_close = next(wp for wp in policy.left_trajectory if wp["t"] == 180)
|
||||
right_close = next(wp for wp in policy.right_trajectory if wp["t"] == 180)
|
||||
action_z_offset = getattr(policy_cls, "ACTION_OBJECT_Z_OFFSET", 0.11)
|
||||
expected_socket_pick = task_state["socket_pos"] + np.array([-0.078, 0.0, action_z_offset])
|
||||
expected_peg_pick = task_state["peg_pos"] + np.array([0.078, 0.0, action_z_offset + 0.01])
|
||||
|
||||
np.testing.assert_allclose(left_close["xyz"], expected_socket_pick, atol=1e-6)
|
||||
np.testing.assert_allclose(right_close["xyz"], expected_peg_pick, atol=1e-6)
|
||||
self.assertLess(left_close["gripper"], 0, "default policy should close the left gripper on the socket")
|
||||
self.assertLess(right_close["gripper"], 0, "default policy should close the right gripper on the peg")
|
||||
|
||||
def test_air_insert_policy_socket_hold_tracks_socket_xy_without_sweeping_laterally(self):
|
||||
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
||||
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
||||
self.assertIsNotNone(policy_cls)
|
||||
|
||||
base_state = {
|
||||
"socket_pos": np.array([-0.20, 0.72, 0.472], dtype=np.float32),
|
||||
"socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
"peg_pos": np.array([0.14, 0.76, 0.46], dtype=np.float32),
|
||||
"peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
}
|
||||
shifted_state = dict(base_state)
|
||||
shifted_state["socket_pos"] = np.array([-0.06, 0.99, 0.472], dtype=np.float32)
|
||||
|
||||
base_policy = policy_cls(inject_noise=False)
|
||||
base_policy.generate_trajectory(base_state)
|
||||
shifted_policy = policy_cls(inject_noise=False)
|
||||
shifted_policy.generate_trajectory(shifted_state)
|
||||
|
||||
base_hold = next(wp for wp in base_policy.left_trajectory if wp["t"] == 450)
|
||||
shifted_hold = next(wp for wp in shifted_policy.left_trajectory if wp["t"] == 450)
|
||||
np.testing.assert_allclose(
|
||||
base_hold["xyz"][:2],
|
||||
base_state["socket_pos"][:2] + np.array([-0.078, 0.0]),
|
||||
atol=1e-6,
|
||||
)
|
||||
np.testing.assert_allclose(
|
||||
shifted_hold["xyz"][:2],
|
||||
shifted_state["socket_pos"][:2] + np.array([-0.078, 0.0]),
|
||||
atol=1e-6,
|
||||
)
|
||||
|
||||
def test_air_insert_policy_predicts_through_full_episode_without_exhausting_waypoints(self):
|
||||
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
||||
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
||||
self.assertIsNotNone(policy_cls)
|
||||
|
||||
task_state = self._canonical_task_state()
|
||||
policy = policy_cls(inject_noise=False)
|
||||
|
||||
for step in range(SIM_TASK_CONFIGS[TASK_NAME]["episode_len"]):
|
||||
action = policy.predict(task_state, step)
|
||||
self.assertEqual(action.shape, (16,))
|
||||
|
||||
def test_scripted_rollout_entrypoint_selects_socket_peg_sampler_and_policy(self):
|
||||
rollout_module = importlib.import_module("roboimi.demos.diana_record_sim_episodes")
|
||||
sampler_fn = getattr(rollout_module, "sample_task_state", None)
|
||||
policy_factory = getattr(rollout_module, "make_policy", None)
|
||||
self.assertIsNotNone(
|
||||
sampler_fn,
|
||||
"Expected roboimi.demos.diana_record_sim_episodes.sample_task_state",
|
||||
)
|
||||
self.assertIsNotNone(
|
||||
policy_factory,
|
||||
"Expected roboimi.demos.diana_record_sim_episodes.make_policy",
|
||||
)
|
||||
self.assertIsNotNone(sampler_fn)
|
||||
self.assertIsNotNone(policy_factory)
|
||||
|
||||
task_state = sampler_fn("sim_air_insert_ring_bar")
|
||||
self.assertEqual(
|
||||
list(task_state.keys()),
|
||||
["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
|
||||
)
|
||||
task_state = sampler_fn(TASK_NAME)
|
||||
self.assertEqual(list(task_state.keys()), ["socket_pos", "socket_quat", "peg_pos", "peg_quat"])
|
||||
|
||||
policy = policy_factory("sim_air_insert_ring_bar", inject_noise=False)
|
||||
policy = policy_factory(TASK_NAME, inject_noise=False)
|
||||
self.assertEqual(policy.__class__.__name__, "TestAirInsertPolicy")
|
||||
|
||||
def test_real_headless_smoke_instantiates_resets_and_steps_new_task_once(self):
|
||||
@@ -343,8 +484,8 @@ class AirInsertPolicyAndSmokeTest(unittest.TestCase):
|
||||
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
||||
self.assertIsNotNone(policy_cls)
|
||||
|
||||
task_state = act_ex_utils.sample_air_insert_ring_bar_state()
|
||||
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
|
||||
task_state = act_ex_utils.sample_air_insert_socket_peg_state()
|
||||
env = make_sim_env(TASK_NAME, headless=True)
|
||||
policy = policy_cls(inject_noise=False)
|
||||
|
||||
try:
|
||||
@@ -363,115 +504,6 @@ class AirInsertPolicyAndSmokeTest(unittest.TestCase):
|
||||
if viewer is not None:
|
||||
viewer.close()
|
||||
|
||||
def test_scripted_policy_avoids_cross_arm_contact_on_canonical_insert_case(self):
|
||||
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
||||
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
||||
self.assertIsNotNone(policy_cls)
|
||||
|
||||
task_state = {
|
||||
"ring_pos": np.array([-0.06658807, 0.93985176, 0.47], dtype=np.float32),
|
||||
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
"bar_pos": np.array([0.12421221, 0.77605027, 0.47], dtype=np.float32),
|
||||
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
}
|
||||
|
||||
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
|
||||
policy = policy_cls(inject_noise=False)
|
||||
|
||||
def is_cross_arm_pair(a, b):
|
||||
return ("_left" in a and "_right" in b) or ("_right" in a and "_left" in b)
|
||||
|
||||
try:
|
||||
env.reset(task_state)
|
||||
for step in range(460):
|
||||
action = policy.predict(task_state, step)
|
||||
env.step(action)
|
||||
pairs = []
|
||||
for i in range(env.mj_data.ncon):
|
||||
geom1 = env.getID2Name("geom", env.mj_data.contact[i].geom1)
|
||||
geom2 = env.getID2Name("geom", env.mj_data.contact[i].geom2)
|
||||
if geom1 and geom2 and is_cross_arm_pair(geom1, geom2):
|
||||
pairs.append((geom1, geom2))
|
||||
self.assertFalse(
|
||||
pairs,
|
||||
f"cross-arm contact detected at step {step}: {pairs[:5]}",
|
||||
)
|
||||
finally:
|
||||
env.exit_flag = True
|
||||
cam_thread = getattr(env, "cam_thread", None)
|
||||
if cam_thread is not None:
|
||||
cam_thread.join(timeout=1.0)
|
||||
viewer = getattr(env, "viewer", None)
|
||||
if viewer is not None:
|
||||
viewer.close()
|
||||
|
||||
def test_scripted_policy_keeps_ring_airborne_through_hold_phase_on_canonical_case(self):
|
||||
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
||||
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
||||
self.assertIsNotNone(policy_cls)
|
||||
|
||||
task_state = {
|
||||
"ring_pos": np.array([-0.11884121, 0.800019, 0.47], dtype=np.float32),
|
||||
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
"bar_pos": np.array([0.12783867, 0.73399246, 0.47], dtype=np.float32),
|
||||
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
}
|
||||
|
||||
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
|
||||
policy = policy_cls(inject_noise=False)
|
||||
|
||||
try:
|
||||
env.reset(task_state)
|
||||
for step in range(400):
|
||||
action = policy.predict(task_state, step)
|
||||
env.step(action)
|
||||
ring_z = float(env.get_env_state()[2])
|
||||
self.assertGreater(
|
||||
ring_z,
|
||||
0.55,
|
||||
f"ring dropped before hold phase completed, final z={ring_z:.4f}",
|
||||
)
|
||||
finally:
|
||||
env.exit_flag = True
|
||||
cam_thread = getattr(env, "cam_thread", None)
|
||||
if cam_thread is not None:
|
||||
cam_thread.join(timeout=1.0)
|
||||
viewer = getattr(env, "viewer", None)
|
||||
if viewer is not None:
|
||||
viewer.close()
|
||||
|
||||
def test_scripted_policy_reaches_max_reward_on_canonical_case(self):
|
||||
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
||||
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
||||
self.assertIsNotNone(policy_cls)
|
||||
|
||||
task_state = {
|
||||
"ring_pos": np.array([-0.11884121, 0.800019, 0.47], dtype=np.float32),
|
||||
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
"bar_pos": np.array([0.12783867, 0.73399246, 0.47], dtype=np.float32),
|
||||
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
}
|
||||
|
||||
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
|
||||
policy = policy_cls(inject_noise=False)
|
||||
max_reward = float("-inf")
|
||||
|
||||
try:
|
||||
env.reset(task_state)
|
||||
for step in range(700):
|
||||
action = policy.predict(task_state, step)
|
||||
env.step(action)
|
||||
max_reward = max(max_reward, float(env.rew))
|
||||
self.assertEqual(max_reward, 5.0, f"expected canonical rollout to reach reward 5, got {max_reward}")
|
||||
finally:
|
||||
env.exit_flag = True
|
||||
cam_thread = getattr(env, "cam_thread", None)
|
||||
if cam_thread is not None:
|
||||
cam_thread.join(timeout=1.0)
|
||||
viewer = getattr(env, "viewer", None)
|
||||
if viewer is not None:
|
||||
viewer.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -114,7 +114,7 @@ class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
is_render=False,
|
||||
control_freq=30,
|
||||
is_interpolate=True,
|
||||
cam_view="angle",
|
||||
cam_view="left_side",
|
||||
)
|
||||
|
||||
def test_camera_viewer_headless_updates_images_without_gui_calls(self):
|
||||
@@ -123,11 +123,11 @@ class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
env.mj_data = object()
|
||||
env.exit_flag = False
|
||||
env.is_render = False
|
||||
env.cam = "angle"
|
||||
env.cam = "left_side"
|
||||
env.r_vis = None
|
||||
env.l_vis = None
|
||||
env.top = None
|
||||
env.angle = None
|
||||
env.left_side = None
|
||||
env.front = None
|
||||
|
||||
with mock.patch(
|
||||
@@ -144,7 +144,7 @@ class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
self.assertIsNotNone(env.r_vis)
|
||||
self.assertIsNotNone(env.l_vis)
|
||||
self.assertIsNotNone(env.top)
|
||||
self.assertIsNotNone(env.angle)
|
||||
self.assertIsNotNone(env.left_side)
|
||||
self.assertIsNotNone(env.front)
|
||||
|
||||
def test_eval_main_headless_skips_render_and_still_executes_policy(self):
|
||||
@@ -254,19 +254,19 @@ class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
self.assertAlmostEqual(summary["avg_reward"], 3.75)
|
||||
self.assertEqual(summary["num_episodes"], 2)
|
||||
|
||||
def test_run_eval_uses_air_insert_sampler_for_ring_bar_task(self):
|
||||
def test_run_eval_uses_air_insert_sampler_for_socket_peg_task(self):
|
||||
self.assertTrue(
|
||||
hasattr(eval_vla, "sample_air_insert_ring_bar_state"),
|
||||
"Expected eval_vla to expose the new ring/bar reset sampler",
|
||||
hasattr(eval_vla, "sample_air_insert_socket_peg_state"),
|
||||
"Expected eval_vla to expose the new socket/peg reset sampler",
|
||||
)
|
||||
|
||||
fake_env = _FakeEnv()
|
||||
fake_agent = _FakeAgent()
|
||||
sampled_task_state = {
|
||||
"ring_pos": np.array([-0.10, 0.80, 0.47], dtype=np.float32),
|
||||
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
"bar_pos": np.array([0.10, 0.82, 0.47], dtype=np.float32),
|
||||
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
"socket_pos": np.array([-0.10, 0.80, 0.47], dtype=np.float32),
|
||||
"socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
"peg_pos": np.array([0.10, 0.82, 0.47], dtype=np.float32),
|
||||
"peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
||||
}
|
||||
cfg = OmegaConf.create(
|
||||
{
|
||||
@@ -276,7 +276,7 @@ class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
"num_episodes": 1,
|
||||
"max_timesteps": 1,
|
||||
"device": "cpu",
|
||||
"task_name": "sim_air_insert_ring_bar",
|
||||
"task_name": "sim_air_insert_socket_peg",
|
||||
"camera_names": ["front"],
|
||||
"use_smoothing": False,
|
||||
"smooth_alpha": 0.3,
|
||||
@@ -296,12 +296,12 @@ class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
return_value=fake_env,
|
||||
) as make_env, mock.patch.object(
|
||||
eval_vla,
|
||||
"sample_air_insert_ring_bar_state",
|
||||
"sample_air_insert_socket_peg_state",
|
||||
return_value=sampled_task_state,
|
||||
) as ring_bar_sampler, mock.patch.object(
|
||||
) as socket_peg_sampler, mock.patch.object(
|
||||
eval_vla,
|
||||
"sample_transfer_pose",
|
||||
side_effect=AssertionError("sample_transfer_pose should not be used for sim_air_insert_ring_bar"),
|
||||
side_effect=AssertionError("sample_transfer_pose should not be used for sim_air_insert_socket_peg"),
|
||||
), mock.patch.object(
|
||||
eval_vla,
|
||||
"execute_policy_action",
|
||||
@@ -312,8 +312,8 @@ class EvalVLAHeadlessTest(unittest.TestCase):
|
||||
):
|
||||
eval_vla._run_eval(cfg)
|
||||
|
||||
make_env.assert_called_once_with("sim_air_insert_ring_bar", headless=True)
|
||||
ring_bar_sampler.assert_called_once_with()
|
||||
make_env.assert_called_once_with("sim_air_insert_socket_peg", headless=True)
|
||||
socket_peg_sampler.assert_called_once_with()
|
||||
execute_policy_action.assert_called_once()
|
||||
self.assertEqual(fake_env.reset_calls, [sampled_task_state])
|
||||
|
||||
|
||||
@@ -59,15 +59,15 @@ class RobotAssetPathResolutionTest(unittest.TestCase):
|
||||
self.assertEqual({Path(path) for path in _FakeKDL.init_calls}, {expected_urdf})
|
||||
self.assertTrue(all(Path(path).is_absolute() for path in _FakeKDL.init_calls))
|
||||
|
||||
def test_bidianamed_ring_bar_resolves_robot_asset_paths_independent_of_cwd(self):
|
||||
BiDianaMedRingBar = getattr(diana_med, 'BiDianaMedRingBar', None)
|
||||
def test_bidianamed_socket_peg_resolves_robot_asset_paths_independent_of_cwd(self):
|
||||
BiDianaMedSocketPeg = getattr(diana_med, 'BiDianaMedSocketPeg', None)
|
||||
self.assertIsNotNone(
|
||||
BiDianaMedRingBar,
|
||||
'Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar',
|
||||
BiDianaMedSocketPeg,
|
||||
'Expected roboimi.assets.robots.diana_med.BiDianaMedSocketPeg',
|
||||
)
|
||||
|
||||
repo_root = Path(__file__).resolve().parents[1]
|
||||
expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml'
|
||||
expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_socket_peg_ee.xml'
|
||||
expected_urdf = repo_root / 'roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf'
|
||||
xml_calls = []
|
||||
|
||||
@@ -89,7 +89,7 @@ class RobotAssetPathResolutionTest(unittest.TestCase):
|
||||
'roboimi.assets.robots.arm_base.KDL_utils',
|
||||
_FakeKDL,
|
||||
):
|
||||
BiDianaMedRingBar()
|
||||
BiDianaMedSocketPeg()
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user