# Source-viewer metadata (paste residue): 511 lines, 22 KiB, Python.
import importlib
|
|
import inspect
|
|
import pathlib
|
|
import unittest
|
|
from unittest import mock
|
|
import xml.etree.ElementTree as ET
|
|
|
|
import numpy as np
|
|
|
|
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
|
from roboimi.utils import act_ex_utils
|
|
from roboimi.utils.constants import SIM_TASK_CONFIGS
|
|
|
|
|
|
# Key under which the air-insert socket/peg task is registered in SIM_TASK_CONFIGS.
TASK_NAME: str = "sim_air_insert_socket_peg"
|
|
|
|
|
|
class AirInsertTaskRegistrationTest(unittest.TestCase):
    """Registration, state sampling, env dispatch and scene-camera checks for the socket/peg task."""

    # The task-state sampler is random, so the range assertions are checked
    # over several independent draws instead of a single (possibly lucky) one.
    NUM_SAMPLER_DRAWS = 32

    def _get_sampler(self):
        """Return ``act_ex_utils.sample_air_insert_socket_peg_state``, failing the test if it is missing."""
        sampler = getattr(act_ex_utils, "sample_air_insert_socket_peg_state", None)
        self.assertIsNotNone(
            sampler,
            "Expected roboimi.utils.act_ex_utils.sample_air_insert_socket_peg_state()",
        )
        return sampler

    def test_sim_task_configs_registers_air_insert_socket_peg(self):
        """The task is registered under its new name (not the old ring/bar name) with the expected config."""
        self.assertIn(TASK_NAME, SIM_TASK_CONFIGS)
        self.assertNotIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS)

        config = SIM_TASK_CONFIGS[TASK_NAME]
        self.assertEqual(config["episode_len"], 750)
        self.assertEqual(config["camera_names"], ["l_vis", "r_vis", "front"])
        self.assertTrue(config["dataset_dir"].endswith("/sim_air_insert_socket_peg"))

    def test_sample_air_insert_socket_peg_state_returns_explicit_named_mapping(self):
        """The sampler returns socket/peg poses under fixed, ordered keys with 3-D positions and 4-D quaternions."""
        sampler = self._get_sampler()
        self.assertFalse(
            hasattr(act_ex_utils, "sample_air_insert_ring_bar_state"),
            "air insert sampler should use socket/peg naming after the task rename",
        )

        task_state = sampler()

        self.assertEqual(
            list(task_state.keys()),
            ["socket_pos", "socket_quat", "peg_pos", "peg_quat"],
        )
        expected_shapes = {
            "socket_pos": (3,),
            "socket_quat": (4,),
            "peg_pos": (3,),
            "peg_quat": (4,),
        }
        for key, shape in expected_shapes.items():
            self.assertEqual(task_state[key].shape, shape, key)

    def test_sample_air_insert_socket_peg_state_uses_fixed_quats_and_left_right_planar_ranges(self):
        """Every sampled state uses identity quaternions, fixed heights, and disjoint left/right x ranges."""
        sampler = self._get_sampler()
        identity_quat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)

        for draw in range(self.NUM_SAMPLER_DRAWS):
            with self.subTest(draw=draw):
                task_state = sampler()

                np.testing.assert_array_equal(task_state["socket_quat"], identity_quat)
                np.testing.assert_array_equal(task_state["peg_quat"], identity_quat)

                # Socket is sampled on the negative-x (left) side.
                socket_pos = task_state["socket_pos"]
                self.assertGreaterEqual(socket_pos[0], -0.20)
                self.assertLessEqual(socket_pos[0], -0.05)
                self.assertGreaterEqual(socket_pos[1], 0.70)
                self.assertLessEqual(socket_pos[1], 1.00)
                self.assertAlmostEqual(float(socket_pos[2]), 0.472)

                # Peg is sampled on the positive-x (right) side.
                peg_pos = task_state["peg_pos"]
                self.assertGreaterEqual(peg_pos[0], 0.05)
                self.assertLessEqual(peg_pos[0], 0.20)
                self.assertGreaterEqual(peg_pos[1], 0.70)
                self.assertLessEqual(peg_pos[1], 1.00)
                self.assertAlmostEqual(float(peg_pos[2]), 0.46)

    def test_make_sim_env_dispatches_air_insert_socket_peg_headless(self):
        """make_sim_env builds DualDianaMed_Air_Insert headless with the socket/peg robot and left_side camera."""
        air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
        air_insert_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None)
        self.assertIsNotNone(air_insert_cls)

        diana_med = importlib.import_module("roboimi.assets.robots.diana_med")
        socket_peg_robot_cls = getattr(diana_med, "BiDianaMedSocketPeg", None)
        self.assertIsNotNone(
            socket_peg_robot_cls,
            "Expected roboimi.assets.robots.diana_med.BiDianaMedSocketPeg",
        )

        fake_env = object()
        # Patch both classes on their defining modules so make_sim_env's
        # lookups resolve to the mocks and no real simulator is built.
        with mock.patch.object(
            diana_med,
            "BiDianaMedSocketPeg",
            return_value="robot",
        ), mock.patch.object(
            air_insert_env,
            "DualDianaMed_Air_Insert",
            return_value=fake_env,
        ) as env_cls:
            env = make_sim_env(TASK_NAME, headless=True)

        self.assertIs(env, fake_env)
        env_cls.assert_called_once_with(
            robot="robot",
            is_render=False,
            control_freq=30,
            is_interpolate=True,
            cam_view="left_side",
        )

    def test_diana_table_scene_uses_left_side_camera_instead_of_angle(self):
        """table_square.xml drops the old ``angle`` camera and adds a ``left_side`` camera targeting the table."""
        xml_path = (
            pathlib.Path(__file__).resolve().parents[1]
            / "roboimi/assets/models/manipulators/DianaMed/table_square.xml"
        )
        root = ET.parse(xml_path).getroot()
        # Unnamed cameras are legal in the scene XML; skip them instead of raising KeyError.
        cameras = {
            camera.attrib["name"]: camera.attrib
            for camera in root.findall(".//camera")
            if "name" in camera.attrib
        }

        self.assertNotIn("angle", cameras, "DianaMed scene should stop exposing the old angle camera")
        self.assertIn("left_side", cameras, "DianaMed scene should expose the left-side task camera")

        left_side = cameras["left_side"]
        self.assertIn("pos", left_side, "left_side camera should declare an explicit pos")
        # Parse explicitly: np.fromstring silently tolerates malformed text.
        left_side_pos = np.array([float(value) for value in left_side["pos"].split()], dtype=np.float64)
        self.assertEqual(left_side_pos.shape, (3,))
        self.assertLess(float(left_side_pos[0]), 0.0)
        self.assertEqual(left_side.get("mode"), "targetbody")
        self.assertEqual(left_side.get("target"), "table")
|
|
|
|
|
|
class AirInsertResetAndStateHelpersTest(unittest.TestCase):
    """Checks for the socket/peg state setter/getter, env source hygiene, and object XML assets."""

    @staticmethod
    def _diana_asset_dir():
        """Return the DianaMed model directory, resolved relative to this test file."""
        return pathlib.Path(__file__).resolve().parents[1] / "roboimi/assets/models/manipulators/DianaMed"

    @staticmethod
    def _make_fake_data(socket_qpos, peg_qpos):
        """Build a minimal stand-in for the simulator data object exposing ``joint(name).qpos``.

        NOTE(review): presumably mirrors ``mujoco.MjData.joint`` — confirm.
        The fake joints wrap ``socket_qpos``/``peg_qpos`` by reference, so only
        in-place writes through ``joint(name).qpos`` are visible to the caller.
        Requesting any joint other than the socket/peg joints fails the test.
        """

        class _FakeJoint:
            def __init__(self, qpos):
                self.qpos = qpos

        qpos_by_joint = {
            "blue_socket_joint": socket_qpos,
            "red_peg_joint": peg_qpos,
        }

        class _FakeData:
            def joint(self, name):
                if name not in qpos_by_joint:
                    raise AssertionError(f"Unexpected joint name: {name}")
                return _FakeJoint(qpos_by_joint[name])

        return _FakeData()

    def test_set_socket_peg_task_state_writes_free_joint_qpos(self):
        """The setter writes pos+quat of socket and peg into their 7-D joint qpos arrays in place."""
        air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
        setter = getattr(air_insert_env, "set_socket_peg_task_state", None)
        self.assertIsNotNone(
            setter,
            "Expected roboimi.envs.double_air_insert_env.set_socket_peg_task_state",
        )

        socket_qpos = np.zeros(7, dtype=np.float64)
        peg_qpos = np.zeros(7, dtype=np.float64)
        task_state = {
            "socket_pos": np.array([-0.12, 0.90, 0.472], dtype=np.float64),
            "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
            "peg_pos": np.array([0.12, 0.91, 0.46], dtype=np.float64),
            "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
        }

        setter(self._make_fake_data(socket_qpos, peg_qpos), task_state)

        np.testing.assert_array_equal(
            socket_qpos,
            np.array([-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
        )
        np.testing.assert_array_equal(
            peg_qpos,
            np.array([0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
        )

    def test_get_socket_peg_env_state_returns_stable_14d_vector(self):
        """The getter concatenates socket qpos then peg qpos into one 14-D vector."""
        air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
        getter = getattr(air_insert_env, "get_socket_peg_env_state", None)
        self.assertIsNotNone(
            getter,
            "Expected roboimi.envs.double_air_insert_env.get_socket_peg_env_state",
        )

        socket_qpos = np.array([-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
        peg_qpos = np.array([0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)

        env_state = getter(self._make_fake_data(socket_qpos, peg_qpos))

        self.assertEqual(env_state.shape, (14,))
        np.testing.assert_array_equal(
            env_state,
            np.array(
                [-0.12, 0.90, 0.472, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.46, 1.0, 0.0, 0.0, 0.0],
                dtype=np.float64,
            ),
        )

    def test_air_insert_env_does_not_script_attach_or_assist_objects_after_reset(self):
        """The env class source contains no scripted grasp/attach helpers or old ring/bar reward references."""
        air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
        env_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None)
        self.assertIsNotNone(env_cls)

        source = inspect.getsource(env_cls)

        self.assertNotIn("_update_scripted_grasped_objects", source)
        self.assertNotIn("_scripted_", source)
        self.assertNotIn("_stabilize_ring_grasp", source)
        self.assertNotIn("_ring_grasp_locked", source)
        get_reward_source = inspect.getsource(env_cls._get_reward)
        self.assertNotIn("ring_block", get_reward_source)
        self.assertNotIn("bar_block", get_reward_source)

    def test_socket_peg_xml_defines_active_socket_and_peg_objects(self):
        """socket_peg_objects.xml replaces ring_bar_objects.xml and defines the socket/peg bodies, joints and geoms."""
        asset_dir = self._diana_asset_dir()
        xml_path = asset_dir / "socket_peg_objects.xml"
        self.assertTrue(xml_path.exists(), "socket/peg objects should live in socket_peg_objects.xml")
        self.assertFalse((asset_dir / "ring_bar_objects.xml").exists(), "old ring_bar_objects.xml should be renamed")

        root = ET.parse(xml_path).getroot()
        body_names = {body.attrib.get("name") for body in root.findall(".//body")}
        geom_names = {geom.attrib.get("name") for geom in root.findall(".//geom")}
        joint_names = {joint.attrib.get("name") for joint in root.findall(".//joint")}

        self.assertIn("socket", body_names)
        self.assertIn("peg", body_names)
        self.assertNotIn("ring_block", body_names)
        self.assertNotIn("bar_block", body_names)
        self.assertIn("blue_socket_joint", joint_names)
        self.assertIn("red_peg_joint", joint_names)
        for geom_name in ("socket-1", "socket-2", "socket-3", "socket-4", "pin", "red_peg"):
            self.assertIn(geom_name, geom_names)

    def test_socket_peg_wrapper_includes_socket_peg_objects(self):
        """bi_diana_socket_peg_ee.xml includes the new socket/peg objects file and not the old ring/bar one."""
        xml_path = self._diana_asset_dir() / "bi_diana_socket_peg_ee.xml"
        self.assertTrue(xml_path.exists(), "socket/peg wrapper XML should use the new task name")
        root = ET.parse(xml_path).getroot()
        includes = [include.attrib.get("file") for include in root.findall(".//include")]
        self.assertIn("./socket_peg_objects.xml", includes)
        self.assertNotIn("./ring_bar_objects.xml", includes)
|
|
|
|
|
|
class AirInsertRewardAndSuccessTest(unittest.TestCase):
    """Checks for the staged air-insert reward and the pin-contact success predicate."""

    @staticmethod
    def _make_env_state(
        socket_pos=(0.0, 0.0, 0.472),
        socket_quat=(1.0, 0.0, 0.0, 0.0),
        peg_pos=(0.0, 0.0, 0.46),
        peg_quat=(1.0, 0.0, 0.0, 0.0),
    ):
        """Pack socket and peg poses into a flat float64 vector.

        Layout: socket pos (3), socket quat (4), peg pos (3), peg quat (4) -> 14 values.
        """
        return np.array([*socket_pos, *socket_quat, *peg_pos, *peg_quat], dtype=np.float64)

    def _reward_fn(self):
        """Return ``compute_air_insert_reward``, failing with a clear message if it is missing.

        Without this check a missing function would surface as an opaque
        ``'NoneType' object is not callable`` error.
        """
        air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
        reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
        self.assertIsNotNone(
            reward_fn,
            "Expected roboimi.envs.double_air_insert_env.compute_air_insert_reward",
        )
        return reward_fn

    def test_compute_air_insert_reward_counts_left_contact_stage(self):
        """Socket touches a finger while both objects touch the table: reward 1."""
        reward = self._reward_fn()(
            contact_pairs=[
                ("socket-1", "l_finger_left"),
                ("socket-1", "table"),
                ("red_peg", "table"),
            ],
            env_state=self._make_env_state(),
        )

        self.assertEqual(reward, 1)

    def test_compute_air_insert_reward_counts_right_contact_stage(self):
        """Socket and peg both touch fingers while both objects touch the table: reward 2."""
        reward = self._reward_fn()(
            contact_pairs=[
                ("socket-1", "l_finger_left"),
                ("red_peg", "l_finger_right"),
                ("socket-1", "table"),
                ("red_peg", "table"),
            ],
            env_state=self._make_env_state(),
        )

        self.assertEqual(reward, 2)

    def test_compute_air_insert_reward_counts_lift_stages(self):
        """Socket and peg both touch fingers with no table contact: reward 4."""
        reward = self._reward_fn()(
            contact_pairs=[
                ("socket-1", "l_finger_left"),
                ("red_peg", "l_finger_right"),
            ],
            env_state=self._make_env_state(),
        )

        self.assertEqual(reward, 4)

    def test_compute_air_insert_reward_counts_visual_fingertip_contacts_as_gripper_contacts(self):
        """Contacts with ``*_fingertip_g0_vis_*`` geoms count as gripper contacts."""
        reward = self._reward_fn()(
            contact_pairs=[
                ("socket-3", "r_fingertip_g0_vis_left"),
                ("red_peg", "l_fingertip_g0_vis_right"),
            ],
            env_state=self._make_env_state(),
        )

        self.assertEqual(
            reward,
            4,
            "visual fingertip geoms are collidable in the Diana XML and should count as gripper-object contacts",
        )

    def test_peg_inserted_into_socket_uses_pin_contact(self):
        """Success is a red_peg/pin contact in either order; contact with a socket wall alone is not success."""
        air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
        success_fn = getattr(air_insert_env, "peg_inserted_into_socket", None)
        self.assertIsNotNone(
            success_fn,
            "Expected roboimi.envs.double_air_insert_env.peg_inserted_into_socket",
        )

        self.assertTrue(success_fn([("red_peg", "pin")]))
        self.assertTrue(success_fn([("pin", "red_peg")]))
        self.assertFalse(success_fn([("red_peg", "socket-1")]))

    def test_compute_air_insert_reward_requires_airborne_success_for_final_point(self):
        """Peg-pin contact while the socket still touches the table does not earn the final point: reward 3."""
        reward = self._reward_fn()(
            contact_pairs=[
                ("socket-1", "l_finger_left"),
                ("red_peg", "l_finger_right"),
                ("socket-1", "table"),
                ("red_peg", "pin"),
            ],
            env_state=self._make_env_state(),
        )

        self.assertEqual(reward, 3)

    def test_compute_air_insert_reward_returns_full_score_on_true_airborne_insert(self):
        """Both objects gripped, no table contact, and peg touching pin: full reward 5."""
        reward = self._reward_fn()(
            contact_pairs=[
                ("socket-1", "l_finger_left"),
                ("red_peg", "l_finger_right"),
                ("red_peg", "pin"),
            ],
            env_state=self._make_env_state(),
        )

        self.assertEqual(reward, 5)
|
|
|
|
|
|
class AirInsertPolicyAndSmokeTest(unittest.TestCase):
    """Scripted air-insert policy checks plus a real headless env smoke test."""

    @staticmethod
    def _canonical_task_state():
        """Return a fixed task state with the socket at negative x and the peg at positive x."""
        return {
            "socket_pos": np.array([-0.12, 0.90, 0.472], dtype=np.float32),
            "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
            "peg_pos": np.array([0.12, 0.90, 0.46], dtype=np.float32),
            "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
        }

    def _policy_cls(self):
        """Return ``TestAirInsertPolicy`` from the demo module, failing the test if it is missing."""
        policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
        policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
        self.assertIsNotNone(
            policy_cls,
            "Expected roboimi.demos.diana_air_insert_policy.TestAirInsertPolicy",
        )
        return policy_cls

    def _waypoint_at(self, trajectory, t):
        """Return the first waypoint whose ``"t"`` equals ``t``.

        Fails the test with a readable message instead of letting a bare
        ``next()`` raise ``StopIteration`` when no such waypoint exists.
        """
        waypoint = next((wp for wp in trajectory if wp["t"] == t), None)
        self.assertIsNotNone(waypoint, f"trajectory has no waypoint at t={t}")
        return waypoint

    @staticmethod
    def _shutdown_env(env):
        """Best-effort teardown of a real env.

        Sets ``exit_flag`` (presumably polled by the camera thread — confirm),
        joins ``cam_thread`` for up to 1 s, and closes ``viewer`` when present.
        """
        env.exit_flag = True
        cam_thread = getattr(env, "cam_thread", None)
        if cam_thread is not None:
            cam_thread.join(timeout=1.0)
        viewer = getattr(env, "viewer", None)
        if viewer is not None:
            viewer.close()

    def test_air_insert_policy_emits_valid_16d_action(self):
        """The first predicted action is 16-D and its last two entries equal 100."""
        policy_cls = self._policy_cls()

        task_state = act_ex_utils.sample_air_insert_socket_peg_state()
        policy = policy_cls(inject_noise=False)
        action = policy.predict(task_state, 0)

        self.assertEqual(action.shape, (16,))
        # NOTE(review): the last two entries are presumably the gripper commands.
        np.testing.assert_array_equal(action[-2:], np.array([100, 100]))

    def test_air_insert_policy_inserts_peg_front_view_right_to_left_along_world_x(self):
        """During insertion the right arm moves the peg toward -x, keeps y, ends near the socket, and stays high."""
        policy_cls = self._policy_cls()

        task_state = self._canonical_task_state()
        policy = policy_cls(inject_noise=False)
        policy.generate_trajectory(task_state)

        start_waypoint = self._waypoint_at(policy.right_trajectory, policy.INSERT_START_T)
        end_waypoint = self._waypoint_at(policy.right_trajectory, policy.INSERT_END_T)

        self.assertLess(
            end_waypoint["xyz"][0],
            start_waypoint["xyz"][0] - 0.10,
            "front-view right-to-left peg insertion should decrease world x substantially",
        )
        self.assertAlmostEqual(float(end_waypoint["xyz"][1]), float(start_waypoint["xyz"][1]), delta=0.02)
        expected_insert_end_x = float(task_state["socket_pos"][0] + 0.168)
        self.assertAlmostEqual(float(end_waypoint["xyz"][0]), expected_insert_end_x, delta=0.02)
        self.assertGreater(float(start_waypoint["xyz"][2]), 0.70)

    def test_air_insert_policy_default_left_grasps_socket_and_right_grasps_peg(self):
        """At t=180 the left arm closes at the socket grasp point and the right arm at the peg grasp point."""
        policy_cls = self._policy_cls()

        task_state = {
            "socket_pos": np.array([-0.18, 0.78, 0.472], dtype=np.float32),
            "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
            "peg_pos": np.array([0.16, 0.98, 0.46], dtype=np.float32),
            "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
        }

        policy = policy_cls(inject_noise=False)
        policy.generate_trajectory(task_state)
        left_close = self._waypoint_at(policy.left_trajectory, 180)
        right_close = self._waypoint_at(policy.right_trajectory, 180)
        action_z_offset = getattr(policy_cls, "ACTION_OBJECT_Z_OFFSET", 0.11)
        expected_socket_pick = task_state["socket_pos"] + np.array([-0.078, 0.0, action_z_offset])
        expected_peg_pick = task_state["peg_pos"] + np.array([0.078, 0.0, action_z_offset + 0.01])

        np.testing.assert_allclose(left_close["xyz"], expected_socket_pick, atol=1e-6)
        np.testing.assert_allclose(right_close["xyz"], expected_peg_pick, atol=1e-6)
        self.assertLess(left_close["gripper"], 0, "default policy should close the left gripper on the socket")
        self.assertLess(right_close["gripper"], 0, "default policy should close the right gripper on the peg")

    def test_air_insert_policy_socket_hold_tracks_socket_xy_without_sweeping_laterally(self):
        """At t=450 the left-arm hold xy follows the sampled socket xy (plus a fixed -x offset)."""
        policy_cls = self._policy_cls()

        base_state = {
            "socket_pos": np.array([-0.20, 0.72, 0.472], dtype=np.float32),
            "socket_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
            "peg_pos": np.array([0.14, 0.76, 0.46], dtype=np.float32),
            "peg_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
        }
        shifted_state = dict(base_state)
        shifted_state["socket_pos"] = np.array([-0.06, 0.99, 0.472], dtype=np.float32)

        base_policy = policy_cls(inject_noise=False)
        base_policy.generate_trajectory(base_state)
        shifted_policy = policy_cls(inject_noise=False)
        shifted_policy.generate_trajectory(shifted_state)

        base_hold = self._waypoint_at(base_policy.left_trajectory, 450)
        shifted_hold = self._waypoint_at(shifted_policy.left_trajectory, 450)
        np.testing.assert_allclose(
            base_hold["xyz"][:2],
            base_state["socket_pos"][:2] + np.array([-0.078, 0.0]),
            atol=1e-6,
        )
        np.testing.assert_allclose(
            shifted_hold["xyz"][:2],
            shifted_state["socket_pos"][:2] + np.array([-0.078, 0.0]),
            atol=1e-6,
        )

    def test_air_insert_policy_predicts_through_full_episode_without_exhausting_waypoints(self):
        """predict() returns a 16-D action for every step of the configured episode length."""
        policy_cls = self._policy_cls()

        task_state = self._canonical_task_state()
        policy = policy_cls(inject_noise=False)

        for step in range(SIM_TASK_CONFIGS[TASK_NAME]["episode_len"]):
            action = policy.predict(task_state, step)
            self.assertEqual(action.shape, (16,))

    def test_scripted_rollout_entrypoint_selects_socket_peg_sampler_and_policy(self):
        """The rollout entrypoint maps the task name to the socket/peg sampler and TestAirInsertPolicy."""
        rollout_module = importlib.import_module("roboimi.demos.diana_record_sim_episodes")
        sampler_fn = getattr(rollout_module, "sample_task_state", None)
        policy_factory = getattr(rollout_module, "make_policy", None)
        self.assertIsNotNone(sampler_fn)
        self.assertIsNotNone(policy_factory)

        task_state = sampler_fn(TASK_NAME)
        self.assertEqual(list(task_state.keys()), ["socket_pos", "socket_quat", "peg_pos", "peg_quat"])

        policy = policy_factory(TASK_NAME, inject_noise=False)
        self.assertEqual(policy.__class__.__name__, "TestAirInsertPolicy")

    def test_real_headless_smoke_instantiates_resets_and_steps_new_task_once(self):
        """A real headless env can be reset with a sampled state and stepped once with a policy action."""
        policy_cls = self._policy_cls()

        task_state = act_ex_utils.sample_air_insert_socket_peg_state()
        env = make_sim_env(TASK_NAME, headless=True)
        # Register teardown immediately: previously policy construction ran
        # after the env existed but outside the try/finally, so a failure there
        # leaked the env's camera thread and viewer.
        self.addCleanup(self._shutdown_env, env)
        policy = policy_cls(inject_noise=False)

        env.reset(task_state)
        action = policy.predict(task_state, 0)
        env.step(action)
        self.assertIsNotNone(env.obs)
        self.assertIn("qpos", env.obs)
        self.assertIn("images", env.obs)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed only when this file is run as a script, not when it is imported.
    unittest.main()
|