feat(sim): switch air insert task to socket peg

This commit is contained in:
Logic
2026-05-02 17:34:43 +08:00
parent 4c3646a3d5
commit 5c5cb299e9
16 changed files with 594 additions and 630 deletions

View File

@@ -1,6 +1,9 @@
import importlib
import inspect
import pathlib
import unittest
from unittest import mock
import xml.etree.ElementTree as ET
import numpy as np
@@ -9,83 +12,80 @@ from roboimi.utils import act_ex_utils
from roboimi.utils.constants import SIM_TASK_CONFIGS
class AirInsertTaskRegistrationTest(unittest.TestCase):
def test_sim_task_configs_registers_air_insert_ring_bar(self):
self.assertIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS)
TASK_NAME = "sim_air_insert_socket_peg"
def test_sample_air_insert_ring_bar_state_returns_explicit_named_mapping(self):
sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None)
class AirInsertTaskRegistrationTest(unittest.TestCase):
def test_sim_task_configs_registers_air_insert_socket_peg(self):
self.assertIn(TASK_NAME, SIM_TASK_CONFIGS)
self.assertNotIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS)
self.assertGreaterEqual(SIM_TASK_CONFIGS[TASK_NAME]["episode_len"], 1000)
self.assertTrue(SIM_TASK_CONFIGS[TASK_NAME]["dataset_dir"].endswith("/sim_air_insert_socket_peg"))
def test_sample_air_insert_socket_peg_state_returns_explicit_named_mapping(self):
sampler = getattr(act_ex_utils, "sample_air_insert_socket_peg_state", None)
self.assertIsNotNone(
sampler,
"Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()",
"Expected roboimi.utils.act_ex_utils.sample_air_insert_socket_peg_state()",
)
self.assertFalse(
hasattr(act_ex_utils, "sample_air_insert_ring_bar_state"),
"air insert sampler should use socket/peg naming after the task rename",
)
task_state = sampler()
self.assertEqual(
list(task_state.keys()),
["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
["socket_pos", "socket_quat", "peg_pos", "peg_quat"],
)
self.assertEqual(task_state["ring_pos"].shape, (3,))
self.assertEqual(task_state["ring_quat"].shape, (4,))
self.assertEqual(task_state["bar_pos"].shape, (3,))
self.assertEqual(task_state["bar_quat"].shape, (4,))
self.assertEqual(task_state["socket_pos"].shape, (3,))
self.assertEqual(task_state["socket_quat"].shape, (4,))
self.assertEqual(task_state["peg_pos"].shape, (3,))
self.assertEqual(task_state["peg_quat"].shape, (4,))
def test_sample_air_insert_ring_bar_state_uses_fixed_quats_and_left_right_planar_ranges(self):
sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None)
self.assertIsNotNone(
sampler,
"Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()",
)
def test_sample_air_insert_socket_peg_state_uses_fixed_quats_and_left_right_planar_ranges(self):
sampler = getattr(act_ex_utils, "sample_air_insert_socket_peg_state", None)
self.assertIsNotNone(sampler)
task_state = sampler()
np.testing.assert_array_equal(task_state["ring_quat"], np.array([1.0, 0.0, 0.0, 0.0]))
np.testing.assert_array_equal(task_state["bar_quat"], np.array([1.0, 0.0, 0.0, 0.0]))
self.assertGreaterEqual(task_state["ring_pos"][0], -0.20)
self.assertLessEqual(task_state["ring_pos"][0], -0.05)
self.assertGreaterEqual(task_state["ring_pos"][1], 0.70)
self.assertLessEqual(task_state["ring_pos"][1], 1.00)
self.assertAlmostEqual(float(task_state["ring_pos"][2]), 0.47)
self.assertGreaterEqual(task_state["bar_pos"][0], 0.05)
self.assertLessEqual(task_state["bar_pos"][0], 0.20)
self.assertGreaterEqual(task_state["bar_pos"][1], 0.70)
self.assertLessEqual(task_state["bar_pos"][1], 1.00)
self.assertAlmostEqual(float(task_state["bar_pos"][2]), 0.47)
def test_make_sim_env_dispatches_air_insert_ring_bar_headless(self):
try:
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
except Exception as exc:
self.fail(f"Expected roboimi.envs.double_air_insert_env to be importable: {exc}")
np.testing.assert_array_equal(task_state["socket_quat"], np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32))
np.testing.assert_array_equal(task_state["peg_quat"], np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32))
self.assertGreaterEqual(task_state["socket_pos"][0], -0.20)
self.assertLessEqual(task_state["socket_pos"][0], -0.05)
self.assertGreaterEqual(task_state["socket_pos"][1], 0.70)
self.assertLessEqual(task_state["socket_pos"][1], 1.00)
self.assertAlmostEqual(float(task_state["socket_pos"][2]), 0.472)
self.assertGreaterEqual(task_state["peg_pos"][0], 0.05)
self.assertLessEqual(task_state["peg_pos"][0], 0.20)
self.assertGreaterEqual(task_state["peg_pos"][1], 0.70)
self.assertLessEqual(task_state["peg_pos"][1], 1.00)
self.assertAlmostEqual(float(task_state["peg_pos"][2]), 0.46)
def test_make_sim_env_dispatches_air_insert_socket_peg_headless(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
air_insert_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None)
self.assertIsNotNone(
air_insert_cls,
"Expected roboimi.envs.double_air_insert_env.DualDianaMed_Air_Insert",
)
self.assertIsNotNone(air_insert_cls)
diana_med = importlib.import_module("roboimi.assets.robots.diana_med")
ring_bar_robot_cls = getattr(diana_med, "BiDianaMedRingBar", None)
socket_peg_robot_cls = getattr(diana_med, "BiDianaMedSocketPeg", None)
self.assertIsNotNone(
ring_bar_robot_cls,
"Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar",
socket_peg_robot_cls,
"Expected roboimi.assets.robots.diana_med.BiDianaMedSocketPeg",
)
fake_env = object()
with mock.patch.object(
diana_med,
"BiDianaMedRingBar",
"BiDianaMedSocketPeg",
return_value="robot",
), mock.patch.object(
air_insert_env,
"DualDianaMed_Air_Insert",
return_value=fake_env,
) as env_cls:
try:
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
except Exception as exc:
self.fail(f"make_sim_env should dispatch sim_air_insert_ring_bar without error: {exc}")
env = make_sim_env(TASK_NAME, headless=True)
self.assertIs(env, fake_env)
env_cls.assert_called_once_with(
@@ -93,21 +93,36 @@ class AirInsertTaskRegistrationTest(unittest.TestCase):
is_render=False,
control_freq=30,
is_interpolate=True,
cam_view="angle",
cam_view="left_side",
)
def test_diana_table_scene_uses_left_side_camera_instead_of_angle(self):
xml_path = (
pathlib.Path(__file__).resolve().parents[1]
/ "roboimi/assets/models/manipulators/DianaMed/table_square.xml"
)
root = ET.parse(xml_path).getroot()
cameras = {camera.attrib["name"]: camera.attrib for camera in root.findall(".//camera")}
self.assertNotIn("angle", cameras, "DianaMed scene should stop exposing the old angle camera")
self.assertIn("left_side", cameras, "DianaMed scene should expose the left-side task camera")
left_side_pos = np.fromstring(cameras["left_side"]["pos"], sep=" ")
self.assertLess(float(left_side_pos[0]), 0.0)
self.assertEqual(cameras["left_side"].get("mode"), "targetbody")
self.assertEqual(cameras["left_side"].get("target"), "table")
class AirInsertResetAndStateHelpersTest(unittest.TestCase):
def test_set_ring_bar_task_state_writes_free_joint_qpos(self):
def test_set_socket_peg_task_state_writes_free_joint_qpos(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
setter = getattr(air_insert_env, "set_ring_bar_task_state", None)
setter = getattr(air_insert_env, "set_socket_peg_task_state", None)
self.assertIsNotNone(
setter,
"Expected roboimi.envs.double_air_insert_env.set_ring_bar_task_state",
"Expected roboimi.envs.double_air_insert_env.set_socket_peg_task_state",
)
ring_qpos = np.zeros(7, dtype=np.float64)
bar_qpos = np.zeros(7, dtype=np.float64)
socket_qpos = np.zeros(7, dtype=np.float64)
peg_qpos = np.zeros(7, dtype=np.float64)
class _FakeJoint:
def __init__(self, qpos):
@@ -115,40 +130,40 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase):
class _FakeData:
def joint(self, name):
if name == "ring_block_joint":
return _FakeJoint(ring_qpos)
if name == "bar_block_joint":
return _FakeJoint(bar_qpos)
if name == "blue_socket_joint":
return _FakeJoint(socket_qpos)
if name == "red_peg_joint":
return _FakeJoint(peg_qpos)
raise AssertionError(f"Unexpected joint name: {name}")
task_state = {
"ring_pos": np.array([-0.12, 0.90, 0.47], dtype=np.float64),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
"bar_pos": np.array([0.12, 0.91, 0.47], dtype=np.float64),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
"socket_pos": np.array([-0.12, 0.90, 0.472], dtype=np.float64),
"socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
"peg_pos": np.array([0.12, 0.91, 0.46], dtype=np.float64),
"peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
}
setter(_FakeData(), task_state)
np.testing.assert_array_equal(
ring_qpos,
np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
socket_qpos,
np.array([-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
)
np.testing.assert_array_equal(
bar_qpos,
np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
peg_qpos,
np.array([0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
)
def test_get_ring_bar_env_state_returns_stable_14d_vector(self):
def test_get_socket_peg_env_state_returns_stable_14d_vector(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
getter = getattr(air_insert_env, "get_ring_bar_env_state", None)
getter = getattr(air_insert_env, "get_socket_peg_env_state", None)
self.assertIsNotNone(
getter,
"Expected roboimi.envs.double_air_insert_env.get_ring_bar_env_state",
"Expected roboimi.envs.double_air_insert_env.get_socket_peg_env_state",
)
ring_qpos = np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
bar_qpos = np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
socket_qpos = np.array([-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
peg_qpos = np.array([0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
class _FakeJoint:
def __init__(self, qpos):
@@ -156,10 +171,10 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase):
class _FakeData:
def joint(self, name):
if name == "ring_block_joint":
return _FakeJoint(ring_qpos)
if name == "bar_block_joint":
return _FakeJoint(bar_qpos)
if name == "blue_socket_joint":
return _FakeJoint(socket_qpos)
if name == "red_peg_joint":
return _FakeJoint(peg_qpos)
raise AssertionError(f"Unexpected joint name: {name}")
env_state = getter(_FakeData())
@@ -168,38 +183,78 @@ class AirInsertResetAndStateHelpersTest(unittest.TestCase):
np.testing.assert_array_equal(
env_state,
np.array(
[-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0],
[-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0],
dtype=np.float64,
),
)
def test_air_insert_env_does_not_script_attach_or_assist_objects_after_reset(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
env_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None)
self.assertIsNotNone(env_cls)
source = inspect.getsource(env_cls)
self.assertNotIn("_update_scripted_grasped_objects", source)
self.assertNotIn("_scripted_", source)
self.assertNotIn("_stabilize_ring_grasp", source)
self.assertNotIn("_ring_grasp_locked", source)
get_reward_source = inspect.getsource(env_cls._get_reward)
self.assertNotIn("ring_block", get_reward_source)
self.assertNotIn("bar_block", get_reward_source)
def test_socket_peg_xml_defines_active_socket_and_peg_objects(self):
asset_dir = pathlib.Path(__file__).resolve().parents[1] / "roboimi/assets/models/manipulators/DianaMed"
xml_path = asset_dir / "socket_peg_objects.xml"
self.assertTrue(xml_path.exists(), "socket/peg objects should live in socket_peg_objects.xml")
self.assertFalse((asset_dir / "ring_bar_objects.xml").exists(), "old ring_bar_objects.xml should be renamed")
root = ET.parse(xml_path).getroot()
body_names = {body.attrib.get("name") for body in root.findall(".//body")}
geom_names = {geom.attrib.get("name") for geom in root.findall(".//geom")}
joint_names = {joint.attrib.get("name") for joint in root.findall(".//joint")}
self.assertIn("socket", body_names)
self.assertIn("peg", body_names)
self.assertNotIn("ring_block", body_names)
self.assertNotIn("bar_block", body_names)
self.assertIn("blue_socket_joint", joint_names)
self.assertIn("red_peg_joint", joint_names)
for geom_name in ("socket-1", "socket-2", "socket-3", "socket-4", "pin", "red_peg"):
self.assertIn(geom_name, geom_names)
def test_socket_peg_wrapper_includes_socket_peg_objects(self):
xml_path = (
pathlib.Path(__file__).resolve().parents[1]
/ "roboimi/assets/models/manipulators/DianaMed/bi_diana_socket_peg_ee.xml"
)
self.assertTrue(xml_path.exists(), "socket/peg wrapper XML should use the new task name")
root = ET.parse(xml_path).getroot()
includes = [include.attrib.get("file") for include in root.findall(".//include")]
self.assertIn("./socket_peg_objects.xml", includes)
self.assertNotIn("./ring_bar_objects.xml", includes)
class AirInsertRewardAndSuccessTest(unittest.TestCase):
@staticmethod
def _make_env_state(
ring_pos=(0.0, 0.0, 0.50),
ring_quat=(1.0, 0.0, 0.0, 0.0),
bar_pos=(0.0, 0.0, 0.50),
bar_quat=(0.70710678, 0.0, 0.70710678, 0.0),
socket_pos=(0.0, 0.0, 0.472),
socket_quat=(1.0, 0.0, 0.0, 0.0),
peg_pos=(0.0, 0.0, 0.46),
peg_quat=(1.0, 0.0, 0.0, 0.0),
):
return np.array(
[*ring_pos, *ring_quat, *bar_pos, *bar_quat],
dtype=np.float64,
)
return np.array([*socket_pos, *socket_quat, *peg_pos, *peg_quat], dtype=np.float64)
def test_compute_air_insert_reward_counts_left_contact_stage(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
self.assertIsNotNone(
reward_fn,
"Expected roboimi.envs.double_air_insert_env.compute_air_insert_reward",
)
self.assertIsNotNone(reward_fn)
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("ring_block_north", "table"),
("bar_block", "table"),
("socket-1", "l_finger_left"),
("socket-1", "table"),
("red_peg", "table"),
],
env_state=self._make_env_state(),
)
@@ -212,10 +267,10 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("bar_block", "l_finger_right"),
("ring_block_north", "table"),
("bar_block", "table"),
("socket-1", "l_finger_left"),
("red_peg", "l_finger_right"),
("socket-1", "table"),
("red_peg", "table"),
],
env_state=self._make_env_state(),
)
@@ -228,47 +283,43 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("bar_block", "l_finger_right"),
("socket-1", "l_finger_left"),
("red_peg", "l_finger_right"),
],
env_state=self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)),
env_state=self._make_env_state(),
)
self.assertEqual(reward, 4)
def test_bar_fully_inserted_through_ring_accepts_true_positive(self):
def test_compute_air_insert_reward_counts_visual_fingertip_contacts_as_gripper_contacts(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
reward = reward_fn(
contact_pairs=[
("socket-3", "r_fingertip_g0_vis_left"),
("red_peg", "l_fingertip_g0_vis_right"),
],
env_state=self._make_env_state(),
)
self.assertEqual(
reward,
4,
"visual fingertip geoms are collidable in the Diana XML and should count as gripper-object contacts",
)
def test_peg_inserted_into_socket_uses_pin_contact(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
success_fn = getattr(air_insert_env, "peg_inserted_into_socket", None)
self.assertIsNotNone(
success_fn,
"Expected roboimi.envs.double_air_insert_env.bar_fully_inserted_through_ring",
"Expected roboimi.envs.double_air_insert_env.peg_inserted_into_socket",
)
self.assertTrue(
success_fn(
self._make_env_state(),
)
)
def test_bar_fully_inserted_through_ring_rejects_centerline_only_false_positive(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
self.assertFalse(
success_fn(
self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)),
)
)
def test_bar_fully_inserted_through_ring_rejects_insufficient_depth(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
self.assertFalse(
success_fn(
self._make_env_state(bar_pos=(0.0, 0.0, 0.56)),
)
)
self.assertTrue(success_fn([("red_peg", "pin")]))
self.assertTrue(success_fn([("pin", "red_peg")]))
self.assertFalse(success_fn([("red_peg", "socket-1")]))
def test_compute_air_insert_reward_requires_airborne_success_for_final_point(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
@@ -276,9 +327,10 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("bar_block", "l_finger_right"),
("ring_block_north", "table"),
("socket-1", "l_finger_left"),
("red_peg", "l_finger_right"),
("socket-1", "table"),
("red_peg", "pin"),
],
env_state=self._make_env_state(),
)
@@ -291,8 +343,9 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("bar_block", "l_finger_right"),
("socket-1", "l_finger_left"),
("red_peg", "l_finger_right"),
("red_peg", "pin"),
],
env_state=self._make_env_state(),
)
@@ -301,41 +354,129 @@ class AirInsertRewardAndSuccessTest(unittest.TestCase):
class AirInsertPolicyAndSmokeTest(unittest.TestCase):
@staticmethod
def _canonical_task_state():
return {
"socket_pos": np.array([-0.12, 0.90, 0.472], dtype=np.float32),
"socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"peg_pos": np.array([0.12, 0.90, 0.46], dtype=np.float32),
"peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
def test_air_insert_policy_emits_valid_16d_action(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(
policy_cls,
"Expected roboimi.demos.diana_air_insert_policy.TestAirInsertPolicy",
)
self.assertIsNotNone(policy_cls)
task_state = act_ex_utils.sample_air_insert_ring_bar_state()
task_state = act_ex_utils.sample_air_insert_socket_peg_state()
policy = policy_cls(inject_noise=False)
action = policy.predict(task_state, 0)
self.assertEqual(action.shape, (16,))
np.testing.assert_array_equal(action[-2:], np.array([100, 100]))
def test_scripted_rollout_entrypoint_selects_ring_bar_sampler_and_policy(self):
def test_air_insert_policy_inserts_peg_front_view_right_to_left_along_world_x(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = self._canonical_task_state()
policy = policy_cls(inject_noise=False)
policy.generate_trajectory(task_state)
start_waypoint = next(wp for wp in policy.right_trajectory if wp["t"] == policy.INSERT_START_T)
end_waypoint = next(wp for wp in policy.right_trajectory if wp["t"] == policy.INSERT_END_T)
self.assertLess(
end_waypoint["xyz"][0],
start_waypoint["xyz"][0] - 0.10,
"front-view right-to-left peg insertion should decrease world x substantially",
)
self.assertAlmostEqual(float(end_waypoint["xyz"][1]), float(start_waypoint["xyz"][1]), delta=0.02)
expected_insert_end_x = float(task_state["socket_pos"][0] + 0.168)
self.assertAlmostEqual(float(end_waypoint["xyz"][0]), expected_insert_end_x, delta=0.02)
self.assertGreater(float(start_waypoint["xyz"][2]), 0.70)
def test_air_insert_policy_default_left_grasps_socket_and_right_grasps_peg(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = {
"socket_pos": np.array([-0.18, 0.78, 0.472], dtype=np.float32),
"socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"peg_pos": np.array([0.16, 0.98, 0.46], dtype=np.float32),
"peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
policy = policy_cls(inject_noise=False)
policy.generate_trajectory(task_state)
left_close = next(wp for wp in policy.left_trajectory if wp["t"] == 180)
right_close = next(wp for wp in policy.right_trajectory if wp["t"] == 180)
action_z_offset = getattr(policy_cls, "ACTION_OBJECT_Z_OFFSET", 0.11)
expected_socket_pick = task_state["socket_pos"] + np.array([-0.078, 0.0, action_z_offset])
expected_peg_pick = task_state["peg_pos"] + np.array([0.078, 0.0, action_z_offset + 0.01])
np.testing.assert_allclose(left_close["xyz"], expected_socket_pick, atol=1e-6)
np.testing.assert_allclose(right_close["xyz"], expected_peg_pick, atol=1e-6)
self.assertLess(left_close["gripper"], 0, "default policy should close the left gripper on the socket")
self.assertLess(right_close["gripper"], 0, "default policy should close the right gripper on the peg")
def test_air_insert_policy_socket_hold_tracks_socket_xy_without_sweeping_laterally(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
base_state = {
"socket_pos": np.array([-0.20, 0.72, 0.472], dtype=np.float32),
"socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"peg_pos": np.array([0.14, 0.76, 0.46], dtype=np.float32),
"peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
shifted_state = dict(base_state)
shifted_state["socket_pos"] = np.array([-0.06, 0.99, 0.472], dtype=np.float32)
base_policy = policy_cls(inject_noise=False)
base_policy.generate_trajectory(base_state)
shifted_policy = policy_cls(inject_noise=False)
shifted_policy.generate_trajectory(shifted_state)
base_hold = next(wp for wp in base_policy.left_trajectory if wp["t"] == 450)
shifted_hold = next(wp for wp in shifted_policy.left_trajectory if wp["t"] == 450)
np.testing.assert_allclose(
base_hold["xyz"][:2],
base_state["socket_pos"][:2] + np.array([-0.078, 0.0]),
atol=1e-6,
)
np.testing.assert_allclose(
shifted_hold["xyz"][:2],
shifted_state["socket_pos"][:2] + np.array([-0.078, 0.0]),
atol=1e-6,
)
def test_air_insert_policy_predicts_through_full_episode_without_exhausting_waypoints(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = self._canonical_task_state()
policy = policy_cls(inject_noise=False)
for step in range(SIM_TASK_CONFIGS[TASK_NAME]["episode_len"]):
action = policy.predict(task_state, step)
self.assertEqual(action.shape, (16,))
def test_scripted_rollout_entrypoint_selects_socket_peg_sampler_and_policy(self):
rollout_module = importlib.import_module("roboimi.demos.diana_record_sim_episodes")
sampler_fn = getattr(rollout_module, "sample_task_state", None)
policy_factory = getattr(rollout_module, "make_policy", None)
self.assertIsNotNone(
sampler_fn,
"Expected roboimi.demos.diana_record_sim_episodes.sample_task_state",
)
self.assertIsNotNone(
policy_factory,
"Expected roboimi.demos.diana_record_sim_episodes.make_policy",
)
self.assertIsNotNone(sampler_fn)
self.assertIsNotNone(policy_factory)
task_state = sampler_fn("sim_air_insert_ring_bar")
self.assertEqual(
list(task_state.keys()),
["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
)
task_state = sampler_fn(TASK_NAME)
self.assertEqual(list(task_state.keys()), ["socket_pos", "socket_quat", "peg_pos", "peg_quat"])
policy = policy_factory("sim_air_insert_ring_bar", inject_noise=False)
policy = policy_factory(TASK_NAME, inject_noise=False)
self.assertEqual(policy.__class__.__name__, "TestAirInsertPolicy")
def test_real_headless_smoke_instantiates_resets_and_steps_new_task_once(self):
@@ -343,8 +484,8 @@ class AirInsertPolicyAndSmokeTest(unittest.TestCase):
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = act_ex_utils.sample_air_insert_ring_bar_state()
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
task_state = act_ex_utils.sample_air_insert_socket_peg_state()
env = make_sim_env(TASK_NAME, headless=True)
policy = policy_cls(inject_noise=False)
try:
@@ -363,115 +504,6 @@ class AirInsertPolicyAndSmokeTest(unittest.TestCase):
if viewer is not None:
viewer.close()
def test_scripted_policy_avoids_cross_arm_contact_on_canonical_insert_case(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = {
"ring_pos": np.array([-0.06658807, 0.93985176, 0.47], dtype=np.float32),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"bar_pos": np.array([0.12421221, 0.77605027, 0.47], dtype=np.float32),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
policy = policy_cls(inject_noise=False)
def is_cross_arm_pair(a, b):
return ("_left" in a and "_right" in b) or ("_right" in a and "_left" in b)
try:
env.reset(task_state)
for step in range(460):
action = policy.predict(task_state, step)
env.step(action)
pairs = []
for i in range(env.mj_data.ncon):
geom1 = env.getID2Name("geom", env.mj_data.contact[i].geom1)
geom2 = env.getID2Name("geom", env.mj_data.contact[i].geom2)
if geom1 and geom2 and is_cross_arm_pair(geom1, geom2):
pairs.append((geom1, geom2))
self.assertFalse(
pairs,
f"cross-arm contact detected at step {step}: {pairs[:5]}",
)
finally:
env.exit_flag = True
cam_thread = getattr(env, "cam_thread", None)
if cam_thread is not None:
cam_thread.join(timeout=1.0)
viewer = getattr(env, "viewer", None)
if viewer is not None:
viewer.close()
def test_scripted_policy_keeps_ring_airborne_through_hold_phase_on_canonical_case(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = {
"ring_pos": np.array([-0.11884121, 0.800019, 0.47], dtype=np.float32),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"bar_pos": np.array([0.12783867, 0.73399246, 0.47], dtype=np.float32),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
policy = policy_cls(inject_noise=False)
try:
env.reset(task_state)
for step in range(400):
action = policy.predict(task_state, step)
env.step(action)
ring_z = float(env.get_env_state()[2])
self.assertGreater(
ring_z,
0.55,
f"ring dropped before hold phase completed, final z={ring_z:.4f}",
)
finally:
env.exit_flag = True
cam_thread = getattr(env, "cam_thread", None)
if cam_thread is not None:
cam_thread.join(timeout=1.0)
viewer = getattr(env, "viewer", None)
if viewer is not None:
viewer.close()
def test_scripted_policy_reaches_max_reward_on_canonical_case(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = {
"ring_pos": np.array([-0.11884121, 0.800019, 0.47], dtype=np.float32),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"bar_pos": np.array([0.12783867, 0.73399246, 0.47], dtype=np.float32),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
policy = policy_cls(inject_noise=False)
max_reward = float("-inf")
try:
env.reset(task_state)
for step in range(700):
action = policy.predict(task_state, step)
env.step(action)
max_reward = max(max_reward, float(env.rew))
self.assertEqual(max_reward, 5.0, f"expected canonical rollout to reach reward 5, got {max_reward}")
finally:
env.exit_flag = True
cam_thread = getattr(env, "cam_thread", None)
if cam_thread is not None:
cam_thread.join(timeout=1.0)
viewer = getattr(env, "viewer", None)
if viewer is not None:
viewer.close()
if __name__ == "__main__":
unittest.main()

View File

@@ -114,7 +114,7 @@ class EvalVLAHeadlessTest(unittest.TestCase):
is_render=False,
control_freq=30,
is_interpolate=True,
cam_view="angle",
cam_view="left_side",
)
def test_camera_viewer_headless_updates_images_without_gui_calls(self):
@@ -123,11 +123,11 @@ class EvalVLAHeadlessTest(unittest.TestCase):
env.mj_data = object()
env.exit_flag = False
env.is_render = False
env.cam = "angle"
env.cam = "left_side"
env.r_vis = None
env.l_vis = None
env.top = None
env.angle = None
env.left_side = None
env.front = None
with mock.patch(
@@ -144,7 +144,7 @@ class EvalVLAHeadlessTest(unittest.TestCase):
self.assertIsNotNone(env.r_vis)
self.assertIsNotNone(env.l_vis)
self.assertIsNotNone(env.top)
self.assertIsNotNone(env.angle)
self.assertIsNotNone(env.left_side)
self.assertIsNotNone(env.front)
def test_eval_main_headless_skips_render_and_still_executes_policy(self):
@@ -254,19 +254,19 @@ class EvalVLAHeadlessTest(unittest.TestCase):
self.assertAlmostEqual(summary["avg_reward"], 3.75)
self.assertEqual(summary["num_episodes"], 2)
def test_run_eval_uses_air_insert_sampler_for_ring_bar_task(self):
def test_run_eval_uses_air_insert_sampler_for_socket_peg_task(self):
self.assertTrue(
hasattr(eval_vla, "sample_air_insert_ring_bar_state"),
"Expected eval_vla to expose the new ring/bar reset sampler",
hasattr(eval_vla, "sample_air_insert_socket_peg_state"),
"Expected eval_vla to expose the new socket/peg reset sampler",
)
fake_env = _FakeEnv()
fake_agent = _FakeAgent()
sampled_task_state = {
"ring_pos": np.array([-0.10, 0.80, 0.47], dtype=np.float32),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"bar_pos": np.array([0.10, 0.82, 0.47], dtype=np.float32),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"socket_pos": np.array([-0.10, 0.80, 0.47], dtype=np.float32),
"socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"peg_pos": np.array([0.10, 0.82, 0.47], dtype=np.float32),
"peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
cfg = OmegaConf.create(
{
@@ -276,7 +276,7 @@ class EvalVLAHeadlessTest(unittest.TestCase):
"num_episodes": 1,
"max_timesteps": 1,
"device": "cpu",
"task_name": "sim_air_insert_ring_bar",
"task_name": "sim_air_insert_socket_peg",
"camera_names": ["front"],
"use_smoothing": False,
"smooth_alpha": 0.3,
@@ -296,12 +296,12 @@ class EvalVLAHeadlessTest(unittest.TestCase):
return_value=fake_env,
) as make_env, mock.patch.object(
eval_vla,
"sample_air_insert_ring_bar_state",
"sample_air_insert_socket_peg_state",
return_value=sampled_task_state,
) as ring_bar_sampler, mock.patch.object(
) as socket_peg_sampler, mock.patch.object(
eval_vla,
"sample_transfer_pose",
side_effect=AssertionError("sample_transfer_pose should not be used for sim_air_insert_ring_bar"),
side_effect=AssertionError("sample_transfer_pose should not be used for sim_air_insert_socket_peg"),
), mock.patch.object(
eval_vla,
"execute_policy_action",
@@ -312,8 +312,8 @@ class EvalVLAHeadlessTest(unittest.TestCase):
):
eval_vla._run_eval(cfg)
make_env.assert_called_once_with("sim_air_insert_ring_bar", headless=True)
ring_bar_sampler.assert_called_once_with()
make_env.assert_called_once_with("sim_air_insert_socket_peg", headless=True)
socket_peg_sampler.assert_called_once_with()
execute_policy_action.assert_called_once()
self.assertEqual(fake_env.reset_calls, [sampled_task_state])

View File

@@ -59,15 +59,15 @@ class RobotAssetPathResolutionTest(unittest.TestCase):
self.assertEqual({Path(path) for path in _FakeKDL.init_calls}, {expected_urdf})
self.assertTrue(all(Path(path).is_absolute() for path in _FakeKDL.init_calls))
def test_bidianamed_ring_bar_resolves_robot_asset_paths_independent_of_cwd(self):
BiDianaMedRingBar = getattr(diana_med, 'BiDianaMedRingBar', None)
def test_bidianamed_socket_peg_resolves_robot_asset_paths_independent_of_cwd(self):
BiDianaMedSocketPeg = getattr(diana_med, 'BiDianaMedSocketPeg', None)
self.assertIsNotNone(
BiDianaMedRingBar,
'Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar',
BiDianaMedSocketPeg,
'Expected roboimi.assets.robots.diana_med.BiDianaMedSocketPeg',
)
repo_root = Path(__file__).resolve().parents[1]
expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_ring_bar_ee.xml'
expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_socket_peg_ee.xml'
expected_urdf = repo_root / 'roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf'
xml_calls = []
@@ -89,7 +89,7 @@ class RobotAssetPathResolutionTest(unittest.TestCase):
'roboimi.assets.robots.arm_base.KDL_utils',
_FakeKDL,
):
BiDianaMedRingBar()
BiDianaMedSocketPeg()
finally:
os.chdir(previous_cwd)