"""Tests for the ``sim_air_insert_socket_peg`` bimanual air-insertion task.

Covers task registration and state sampling, ``make_sim_env`` dispatch, the
DianaMed scene/socket/peg MJCF assets, the free-joint state helpers, the staged
reward / success functions, and the scripted ``TestAirInsertPolicy``.
"""

import importlib
import inspect
import pathlib
import unittest
from unittest import mock
import xml.etree.ElementTree as ET

import numpy as np

from roboimi.envs.double_pos_ctrl_env import make_sim_env
from roboimi.utils import act_ex_utils
from roboimi.utils.constants import SIM_TASK_CONFIGS

TASK_NAME = "sim_air_insert_socket_peg"

# Asset paths are resolved relative to the parent of this test file's directory.
_REPO_ROOT = pathlib.Path(__file__).resolve().parents[1]
_DIANA_ASSET_DIR = _REPO_ROOT / "roboimi/assets/models/manipulators/DianaMed"

_AIR_INSERT_ENV_MODULE = "roboimi.envs.double_air_insert_env"
_AIR_INSERT_POLICY_MODULE = "roboimi.demos.diana_air_insert_policy"

# The task-state sampler is random; checking several draws makes it much less
# likely that an occasional out-of-range sample slips past the range test.
_SAMPLER_DRAWS = 32


class _FakeJoint:
    """Stand-in for a joint view: exposes the ``qpos`` array it was given (no copy)."""

    def __init__(self, qpos):
        self.qpos = qpos


class _FakeSocketPegData:
    """Stand-in for sim data whose ``joint(name)`` resolves only the socket/peg joints.

    The qpos arrays are shared, not copied, so in-place writes made by the code under
    test through ``data.joint(name).qpos`` are visible to the test afterwards.
    """

    def __init__(self, socket_qpos, peg_qpos):
        self._joints = {
            "blue_socket_joint": socket_qpos,
            "red_peg_joint": peg_qpos,
        }

    def joint(self, name):
        if name not in self._joints:
            raise AssertionError(f"Unexpected joint name: {name}")
        return _FakeJoint(self._joints[name])


def _waypoint_at(trajectory, t):
    """Return the first waypoint in *trajectory* whose ``"t"`` equals *t*.

    Raises ``AssertionError`` (a test failure, not an error) when no such waypoint
    exists, instead of leaking a bare ``StopIteration`` from ``next()``.
    """
    waypoint = next((wp for wp in trajectory if wp["t"] == t), None)
    if waypoint is None:
        raise AssertionError(f"No waypoint scheduled at t={t}")
    return waypoint


class AirInsertTaskRegistrationTest(unittest.TestCase):
    """Task config registration, state sampling, env dispatch, and scene camera checks."""

    def test_sim_task_configs_registers_air_insert_socket_peg(self):
        self.assertIn(TASK_NAME, SIM_TASK_CONFIGS)
        self.assertNotIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS)
        self.assertGreaterEqual(SIM_TASK_CONFIGS[TASK_NAME]["episode_len"], 1000)
        self.assertTrue(SIM_TASK_CONFIGS[TASK_NAME]["dataset_dir"].endswith("/sim_air_insert_socket_peg"))

    def test_sample_air_insert_socket_peg_state_returns_explicit_named_mapping(self):
        sampler = getattr(act_ex_utils, "sample_air_insert_socket_peg_state", None)
        self.assertIsNotNone(
            sampler,
            "Expected roboimi.utils.act_ex_utils.sample_air_insert_socket_peg_state()",
        )
        self.assertFalse(
            hasattr(act_ex_utils, "sample_air_insert_ring_bar_state"),
            "air insert sampler should use socket/peg naming after the task rename",
        )
        task_state = sampler()
        self.assertEqual(
            list(task_state.keys()),
            ["socket_pos", "socket_quat", "peg_pos", "peg_quat"],
        )
        self.assertEqual(task_state["socket_pos"].shape, (3,))
        self.assertEqual(task_state["socket_quat"].shape, (4,))
        self.assertEqual(task_state["peg_pos"].shape, (3,))
        self.assertEqual(task_state["peg_quat"].shape, (4,))

    def test_sample_air_insert_socket_peg_state_uses_fixed_quats_and_left_right_planar_ranges(self):
        sampler = getattr(act_ex_utils, "sample_air_insert_socket_peg_state", None)
        self.assertIsNotNone(sampler)
        identity_quat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
        for draw in range(_SAMPLER_DRAWS):
            with self.subTest(draw=draw):
                task_state = sampler()
                np.testing.assert_array_equal(task_state["socket_quat"], identity_quat)
                np.testing.assert_array_equal(task_state["peg_quat"], identity_quat)

                # Socket is sampled on the negative-x side of the table.
                self.assertGreaterEqual(task_state["socket_pos"][0], -0.20)
                self.assertLessEqual(task_state["socket_pos"][0], -0.05)
                self.assertGreaterEqual(task_state["socket_pos"][1], 0.70)
                self.assertLessEqual(task_state["socket_pos"][1], 1.00)
                self.assertAlmostEqual(float(task_state["socket_pos"][2]), 0.472)

                # Peg is sampled on the positive-x side of the table.
                self.assertGreaterEqual(task_state["peg_pos"][0], 0.05)
                self.assertLessEqual(task_state["peg_pos"][0], 0.20)
                self.assertGreaterEqual(task_state["peg_pos"][1], 0.70)
                self.assertLessEqual(task_state["peg_pos"][1], 1.00)
                self.assertAlmostEqual(float(task_state["peg_pos"][2]), 0.46)

    def test_make_sim_env_dispatches_air_insert_socket_peg_headless(self):
        air_insert_env = importlib.import_module(_AIR_INSERT_ENV_MODULE)
        air_insert_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None)
        self.assertIsNotNone(air_insert_cls)

        diana_med = importlib.import_module("roboimi.assets.robots.diana_med")
        socket_peg_robot_cls = getattr(diana_med, "BiDianaMedSocketPeg", None)
        self.assertIsNotNone(
            socket_peg_robot_cls,
            "Expected roboimi.assets.robots.diana_med.BiDianaMedSocketPeg",
        )

        fake_env = object()
        # NOTE(review): these patches only take effect if make_sim_env looks the
        # classes up on their modules at call time, not via a from-import binding.
        with mock.patch.object(
            diana_med,
            "BiDianaMedSocketPeg",
            return_value="robot",
        ), mock.patch.object(
            air_insert_env,
            "DualDianaMed_Air_Insert",
            return_value=fake_env,
        ) as env_cls:
            env = make_sim_env(TASK_NAME, headless=True)

        self.assertIs(env, fake_env)
        env_cls.assert_called_once_with(
            robot="robot",
            is_render=False,
            control_freq=30,
            is_interpolate=True,
            cam_view="left_side",
        )

    def test_diana_table_scene_uses_left_side_camera_instead_of_angle(self):
        xml_path = _DIANA_ASSET_DIR / "table_square.xml"
        root = ET.parse(xml_path).getroot()
        cameras = {camera.attrib["name"]: camera.attrib for camera in root.findall(".//camera")}

        self.assertNotIn("angle", cameras, "DianaMed scene should stop exposing the old angle camera")
        self.assertIn("left_side", cameras, "DianaMed scene should expose the left-side task camera")

        # MJCF `pos` is a whitespace-separated 3-vector; parse it explicitly rather
        # than via np.fromstring, which silently stops at unparsable text.
        left_side_pos = np.array(cameras["left_side"]["pos"].split(), dtype=np.float64)
        self.assertEqual(left_side_pos.shape, (3,))
        self.assertLess(float(left_side_pos[0]), 0.0)
        self.assertEqual(cameras["left_side"].get("mode"), "targetbody")
        self.assertEqual(cameras["left_side"].get("target"), "table")


class AirInsertResetAndStateHelpersTest(unittest.TestCase):
    """Free-joint state read/write helpers, env source hygiene, and socket/peg MJCF assets."""

    def test_set_socket_peg_task_state_writes_free_joint_qpos(self):
        air_insert_env = importlib.import_module(_AIR_INSERT_ENV_MODULE)
        setter = getattr(air_insert_env, "set_socket_peg_task_state", None)
        self.assertIsNotNone(
            setter,
            "Expected roboimi.envs.double_air_insert_env.set_socket_peg_task_state",
        )

        socket_qpos = np.zeros(7, dtype=np.float64)
        peg_qpos = np.zeros(7, dtype=np.float64)
        task_state = {
            "socket_pos": np.array([-0.12, 0.90, 0.472], dtype=np.float64),
            "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
            "peg_pos": np.array([0.12, 0.91, 0.46], dtype=np.float64),
            "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
        }

        setter(_FakeSocketPegData(socket_qpos, peg_qpos), task_state)

        # The setter must write in place: position followed by quaternion.
        np.testing.assert_array_equal(
            socket_qpos,
            np.array([-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
        )
        np.testing.assert_array_equal(
            peg_qpos,
            np.array([0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
        )

    def test_get_socket_peg_env_state_returns_stable_14d_vector(self):
        air_insert_env = importlib.import_module(_AIR_INSERT_ENV_MODULE)
        getter = getattr(air_insert_env, "get_socket_peg_env_state", None)
        self.assertIsNotNone(
            getter,
            "Expected roboimi.envs.double_air_insert_env.get_socket_peg_env_state",
        )

        socket_qpos = np.array([-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
        peg_qpos = np.array([0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)

        env_state = getter(_FakeSocketPegData(socket_qpos, peg_qpos))

        # Expected layout: socket qpos (7) followed by peg qpos (7).
        self.assertEqual(env_state.shape, (14,))
        np.testing.assert_array_equal(
            env_state,
            np.array(
                [-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0],
                dtype=np.float64,
            ),
        )

    def test_air_insert_env_does_not_script_attach_or_assist_objects_after_reset(self):
        air_insert_env = importlib.import_module(_AIR_INSERT_ENV_MODULE)
        env_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None)
        self.assertIsNotNone(env_cls)

        # Source-level guard: no scripted grasp helpers left over from the ring/bar task.
        source = inspect.getsource(env_cls)
        self.assertNotIn("_update_scripted_grasped_objects", source)
        self.assertNotIn("_scripted_", source)
        self.assertNotIn("_stabilize_ring_grasp", source)
        self.assertNotIn("_ring_grasp_locked", source)

        get_reward_source = inspect.getsource(env_cls._get_reward)
        self.assertNotIn("ring_block", get_reward_source)
        self.assertNotIn("bar_block", get_reward_source)

    def test_socket_peg_xml_defines_active_socket_and_peg_objects(self):
        xml_path = _DIANA_ASSET_DIR / "socket_peg_objects.xml"
        self.assertTrue(xml_path.exists(), "socket/peg objects should live in socket_peg_objects.xml")
        self.assertFalse(
            (_DIANA_ASSET_DIR / "ring_bar_objects.xml").exists(),
            "old ring_bar_objects.xml should be renamed",
        )

        root = ET.parse(xml_path).getroot()
        body_names = {body.attrib.get("name") for body in root.findall(".//body")}
        geom_names = {geom.attrib.get("name") for geom in root.findall(".//geom")}
        joint_names = {joint.attrib.get("name") for joint in root.findall(".//joint")}

        self.assertIn("socket", body_names)
        self.assertIn("peg", body_names)
        self.assertNotIn("ring_block", body_names)
        self.assertNotIn("bar_block", body_names)
        self.assertIn("blue_socket_joint", joint_names)
        self.assertIn("red_peg_joint", joint_names)
        for geom_name in ("socket-1", "socket-2", "socket-3", "socket-4", "pin", "red_peg"):
            with self.subTest(geom=geom_name):
                self.assertIn(geom_name, geom_names)

    def test_socket_peg_wrapper_includes_socket_peg_objects(self):
        xml_path = _DIANA_ASSET_DIR / "bi_diana_socket_peg_ee.xml"
        self.assertTrue(xml_path.exists(), "socket/peg wrapper XML should use the new task name")

        root = ET.parse(xml_path).getroot()
        includes = [include.attrib.get("file") for include in root.findall(".//include")]
        self.assertIn("./socket_peg_objects.xml", includes)
        self.assertNotIn("./ring_bar_objects.xml", includes)


class AirInsertRewardAndSuccessTest(unittest.TestCase):
    """Staged reward (0-5) and pin-contact success detection for the air-insert task."""

    @staticmethod
    def _make_env_state(
        socket_pos=(0.0, 0.0, 0.472),
        socket_quat=(1.0, 0.0, 0.0, 0.0),
        peg_pos=(0.0, 0.0, 0.46),
        peg_quat=(1.0, 0.0, 0.0, 0.0),
    ):
        """Build a 14-D env state: socket pos+quat followed by peg pos+quat."""
        return np.array([*socket_pos, *socket_quat, *peg_pos, *peg_quat], dtype=np.float64)

    def _reward_fn(self):
        """Return ``compute_air_insert_reward``, failing clearly if it is missing."""
        air_insert_env = importlib.import_module(_AIR_INSERT_ENV_MODULE)
        reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
        self.assertIsNotNone(
            reward_fn,
            "Expected roboimi.envs.double_air_insert_env.compute_air_insert_reward",
        )
        return reward_fn

    def test_compute_air_insert_reward_counts_left_contact_stage(self):
        reward = self._reward_fn()(
            contact_pairs=[
                ("socket-1", "l_finger_left"),
                ("socket-1", "table"),
                ("red_peg", "table"),
            ],
            env_state=self._make_env_state(),
        )
        self.assertEqual(reward, 1)

    def test_compute_air_insert_reward_counts_right_contact_stage(self):
        reward = self._reward_fn()(
            contact_pairs=[
                ("socket-1", "l_finger_left"),
                ("red_peg", "l_finger_right"),
                ("socket-1", "table"),
                ("red_peg", "table"),
            ],
            env_state=self._make_env_state(),
        )
        self.assertEqual(reward, 2)

    def test_compute_air_insert_reward_counts_lift_stages(self):
        # Both objects gripped and neither touching the table.
        reward = self._reward_fn()(
            contact_pairs=[
                ("socket-1", "l_finger_left"),
                ("red_peg", "l_finger_right"),
            ],
            env_state=self._make_env_state(),
        )
        self.assertEqual(reward, 4)

    def test_compute_air_insert_reward_counts_visual_fingertip_contacts_as_gripper_contacts(self):
        reward = self._reward_fn()(
            contact_pairs=[
                ("socket-3", "r_fingertip_g0_vis_left"),
                ("red_peg", "l_fingertip_g0_vis_right"),
            ],
            env_state=self._make_env_state(),
        )
        self.assertEqual(
            reward,
            4,
            "visual fingertip geoms are collidable in the Diana XML and should count as gripper-object contacts",
        )

    def test_peg_inserted_into_socket_uses_pin_contact(self):
        air_insert_env = importlib.import_module(_AIR_INSERT_ENV_MODULE)
        success_fn = getattr(air_insert_env, "peg_inserted_into_socket", None)
        self.assertIsNotNone(
            success_fn,
            "Expected roboimi.envs.double_air_insert_env.peg_inserted_into_socket",
        )
        # Contact order within a pair must not matter.
        self.assertTrue(success_fn([("red_peg", "pin")]))
        self.assertTrue(success_fn([("pin", "red_peg")]))
        self.assertFalse(success_fn([("red_peg", "socket-1")]))

    def test_compute_air_insert_reward_requires_airborne_success_for_final_point(self):
        # Pin contact while the socket still touches the table must not earn full reward.
        reward = self._reward_fn()(
            contact_pairs=[
                ("socket-1", "l_finger_left"),
                ("red_peg", "l_finger_right"),
                ("socket-1", "table"),
                ("red_peg", "pin"),
            ],
            env_state=self._make_env_state(),
        )
        self.assertEqual(reward, 3)

    def test_compute_air_insert_reward_returns_full_score_on_true_airborne_insert(self):
        reward = self._reward_fn()(
            contact_pairs=[
                ("socket-1", "l_finger_left"),
                ("red_peg", "l_finger_right"),
                ("red_peg", "pin"),
            ],
            env_state=self._make_env_state(),
        )
        self.assertEqual(reward, 5)


class AirInsertPolicyAndSmokeTest(unittest.TestCase):
    """Scripted ``TestAirInsertPolicy`` trajectory checks plus a real headless env smoke test."""

    @staticmethod
    def _canonical_task_state():
        """Deterministic task state with the socket left and the peg right of centre."""
        return {
            "socket_pos": np.array([-0.12, 0.90, 0.472], dtype=np.float32),
            "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
            "peg_pos": np.array([0.12, 0.90, 0.46], dtype=np.float32),
            "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
        }

    def _policy_cls(self):
        """Return the scripted ``TestAirInsertPolicy`` class, failing clearly if it is missing."""
        policy_module = importlib.import_module(_AIR_INSERT_POLICY_MODULE)
        policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
        self.assertIsNotNone(
            policy_cls,
            "Expected roboimi.demos.diana_air_insert_policy.TestAirInsertPolicy",
        )
        return policy_cls

    def test_air_insert_policy_emits_valid_16d_action(self):
        policy_cls = self._policy_cls()

        task_state = act_ex_utils.sample_air_insert_socket_peg_state()
        policy = policy_cls(inject_noise=False)
        action = policy.predict(task_state, 0)

        self.assertEqual(action.shape, (16,))
        np.testing.assert_array_equal(action[-2:], np.array([100, 100]))

    def test_air_insert_policy_inserts_peg_front_view_right_to_left_along_world_x(self):
        policy_cls = self._policy_cls()

        task_state = self._canonical_task_state()
        policy = policy_cls(inject_noise=False)
        policy.generate_trajectory(task_state)

        start_waypoint = _waypoint_at(policy.right_trajectory, policy.INSERT_START_T)
        end_waypoint = _waypoint_at(policy.right_trajectory, policy.INSERT_END_T)

        self.assertLess(
            end_waypoint["xyz"][0],
            start_waypoint["xyz"][0] - 0.10,
            "front-view right-to-left peg insertion should decrease world x substantially",
        )
        self.assertAlmostEqual(float(end_waypoint["xyz"][1]), float(start_waypoint["xyz"][1]), delta=0.02)
        expected_insert_end_x = float(task_state["socket_pos"][0] + 0.168)
        self.assertAlmostEqual(float(end_waypoint["xyz"][0]), expected_insert_end_x, delta=0.02)
        self.assertGreater(float(start_waypoint["xyz"][2]), 0.70)

    def test_air_insert_policy_default_left_grasps_socket_and_right_grasps_peg(self):
        policy_cls = self._policy_cls()

        task_state = {
            "socket_pos": np.array([-0.18, 0.78, 0.472], dtype=np.float32),
            "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
            "peg_pos": np.array([0.16, 0.98, 0.46], dtype=np.float32),
            "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
        }
        policy = policy_cls(inject_noise=False)
        policy.generate_trajectory(task_state)

        left_close = _waypoint_at(policy.left_trajectory, 180)
        right_close = _waypoint_at(policy.right_trajectory, 180)

        action_z_offset = getattr(policy_cls, "ACTION_OBJECT_Z_OFFSET", 0.11)
        expected_socket_pick = task_state["socket_pos"] + np.array([-0.078, 0.0, action_z_offset])
        expected_peg_pick = task_state["peg_pos"] + np.array([0.078, 0.0, action_z_offset + 0.01])

        np.testing.assert_allclose(left_close["xyz"], expected_socket_pick, atol=1e-6)
        np.testing.assert_allclose(right_close["xyz"], expected_peg_pick, atol=1e-6)
        self.assertLess(left_close["gripper"], 0, "default policy should close the left gripper on the socket")
        self.assertLess(right_close["gripper"], 0, "default policy should close the right gripper on the peg")

    def test_air_insert_policy_socket_hold_tracks_socket_xy_without_sweeping_laterally(self):
        policy_cls = self._policy_cls()

        base_state = {
            "socket_pos": np.array([-0.20, 0.72, 0.472], dtype=np.float32),
            "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
            "peg_pos": np.array([0.14, 0.76, 0.46], dtype=np.float32),
            "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
        }
        shifted_state = dict(base_state)
        shifted_state["socket_pos"] = np.array([-0.06, 0.99, 0.472], dtype=np.float32)

        base_policy = policy_cls(inject_noise=False)
        base_policy.generate_trajectory(base_state)
        shifted_policy = policy_cls(inject_noise=False)
        shifted_policy.generate_trajectory(shifted_state)

        base_hold = _waypoint_at(base_policy.left_trajectory, 450)
        shifted_hold = _waypoint_at(shifted_policy.left_trajectory, 450)

        # The hold pose must follow the socket's own xy, not a fixed lateral sweep.
        np.testing.assert_allclose(
            base_hold["xyz"][:2],
            base_state["socket_pos"][:2] + np.array([-0.078, 0.0]),
            atol=1e-6,
        )
        np.testing.assert_allclose(
            shifted_hold["xyz"][:2],
            shifted_state["socket_pos"][:2] + np.array([-0.078, 0.0]),
            atol=1e-6,
        )

    def test_air_insert_policy_predicts_through_full_episode_without_exhausting_waypoints(self):
        policy_cls = self._policy_cls()

        task_state = self._canonical_task_state()
        policy = policy_cls(inject_noise=False)
        for step in range(SIM_TASK_CONFIGS[TASK_NAME]["episode_len"]):
            action = policy.predict(task_state, step)
            self.assertEqual(action.shape, (16,))

    def test_scripted_rollout_entrypoint_selects_socket_peg_sampler_and_policy(self):
        rollout_module = importlib.import_module("roboimi.demos.diana_record_sim_episodes")
        sampler_fn = getattr(rollout_module, "sample_task_state", None)
        policy_factory = getattr(rollout_module, "make_policy", None)
        self.assertIsNotNone(sampler_fn)
        self.assertIsNotNone(policy_factory)

        task_state = sampler_fn(TASK_NAME)
        self.assertEqual(list(task_state.keys()), ["socket_pos", "socket_quat", "peg_pos", "peg_quat"])

        policy = policy_factory(TASK_NAME, inject_noise=False)
        self.assertEqual(policy.__class__.__name__, "TestAirInsertPolicy")

    def test_real_headless_smoke_instantiates_resets_and_steps_new_task_once(self):
        policy_cls = self._policy_cls()

        task_state = act_ex_utils.sample_air_insert_socket_peg_state()
        env = make_sim_env(TASK_NAME, headless=True)
        policy = policy_cls(inject_noise=False)
        try:
            env.reset(task_state)
            action = policy.predict(task_state, 0)
            env.step(action)

            self.assertIsNotNone(env.obs)
            self.assertIn("qpos", env.obs)
            self.assertIn("images", env.obs)
        finally:
            # Best-effort teardown: signal exit, then join the camera thread and
            # close the viewer if the env created them.
            env.exit_flag = True
            cam_thread = getattr(env, "cam_thread", None)
            if cam_thread is not None:
                cam_thread.join(timeout=1.0)
            viewer = getattr(env, "viewer", None)
            if viewer is not None:
                viewer.close()


if __name__ == "__main__":
    unittest.main()