411 lines
16 KiB
Python
411 lines
16 KiB
Python
import importlib
|
|
import unittest
|
|
from unittest import mock
|
|
|
|
import numpy as np
|
|
|
|
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
|
from roboimi.utils import act_ex_utils
|
|
from roboimi.utils.constants import SIM_TASK_CONFIGS
|
|
|
|
|
|
class AirInsertTaskRegistrationTest(unittest.TestCase):
|
|
def test_sim_task_configs_registers_air_insert_ring_bar(self):
|
|
self.assertIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS)
|
|
|
|
def test_sample_air_insert_ring_bar_state_returns_explicit_named_mapping(self):
|
|
sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None)
|
|
self.assertIsNotNone(
|
|
sampler,
|
|
"Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()",
|
|
)
|
|
|
|
task_state = sampler()
|
|
|
|
self.assertEqual(
|
|
list(task_state.keys()),
|
|
["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
|
|
)
|
|
self.assertEqual(task_state["ring_pos"].shape, (3,))
|
|
self.assertEqual(task_state["ring_quat"].shape, (4,))
|
|
self.assertEqual(task_state["bar_pos"].shape, (3,))
|
|
self.assertEqual(task_state["bar_quat"].shape, (4,))
|
|
|
|
def test_sample_air_insert_ring_bar_state_uses_fixed_quats_and_left_right_planar_ranges(self):
|
|
sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None)
|
|
self.assertIsNotNone(
|
|
sampler,
|
|
"Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()",
|
|
)
|
|
|
|
task_state = sampler()
|
|
|
|
np.testing.assert_array_equal(task_state["ring_quat"], np.array([1.0, 0.0, 0.0, 0.0]))
|
|
np.testing.assert_array_equal(task_state["bar_quat"], np.array([1.0, 0.0, 0.0, 0.0]))
|
|
self.assertGreaterEqual(task_state["ring_pos"][0], -0.20)
|
|
self.assertLessEqual(task_state["ring_pos"][0], -0.05)
|
|
self.assertGreaterEqual(task_state["ring_pos"][1], 0.70)
|
|
self.assertLessEqual(task_state["ring_pos"][1], 1.00)
|
|
self.assertAlmostEqual(float(task_state["ring_pos"][2]), 0.47)
|
|
self.assertGreaterEqual(task_state["bar_pos"][0], 0.05)
|
|
self.assertLessEqual(task_state["bar_pos"][0], 0.20)
|
|
self.assertGreaterEqual(task_state["bar_pos"][1], 0.70)
|
|
self.assertLessEqual(task_state["bar_pos"][1], 1.00)
|
|
self.assertAlmostEqual(float(task_state["bar_pos"][2]), 0.47)
|
|
|
|
def test_make_sim_env_dispatches_air_insert_ring_bar_headless(self):
|
|
try:
|
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
|
except Exception as exc:
|
|
self.fail(f"Expected roboimi.envs.double_air_insert_env to be importable: {exc}")
|
|
|
|
air_insert_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None)
|
|
self.assertIsNotNone(
|
|
air_insert_cls,
|
|
"Expected roboimi.envs.double_air_insert_env.DualDianaMed_Air_Insert",
|
|
)
|
|
|
|
diana_med = importlib.import_module("roboimi.assets.robots.diana_med")
|
|
ring_bar_robot_cls = getattr(diana_med, "BiDianaMedRingBar", None)
|
|
self.assertIsNotNone(
|
|
ring_bar_robot_cls,
|
|
"Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar",
|
|
)
|
|
|
|
fake_env = object()
|
|
with mock.patch.object(
|
|
diana_med,
|
|
"BiDianaMedRingBar",
|
|
return_value="robot",
|
|
), mock.patch.object(
|
|
air_insert_env,
|
|
"DualDianaMed_Air_Insert",
|
|
return_value=fake_env,
|
|
) as env_cls:
|
|
try:
|
|
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
|
|
except Exception as exc:
|
|
self.fail(f"make_sim_env should dispatch sim_air_insert_ring_bar without error: {exc}")
|
|
|
|
self.assertIs(env, fake_env)
|
|
env_cls.assert_called_once_with(
|
|
robot="robot",
|
|
is_render=False,
|
|
control_freq=30,
|
|
is_interpolate=True,
|
|
cam_view="angle",
|
|
)
|
|
|
|
|
|
class AirInsertResetAndStateHelpersTest(unittest.TestCase):
|
|
def test_set_ring_bar_task_state_writes_free_joint_qpos(self):
|
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
|
setter = getattr(air_insert_env, "set_ring_bar_task_state", None)
|
|
self.assertIsNotNone(
|
|
setter,
|
|
"Expected roboimi.envs.double_air_insert_env.set_ring_bar_task_state",
|
|
)
|
|
|
|
ring_qpos = np.zeros(7, dtype=np.float64)
|
|
bar_qpos = np.zeros(7, dtype=np.float64)
|
|
|
|
class _FakeJoint:
|
|
def __init__(self, qpos):
|
|
self.qpos = qpos
|
|
|
|
class _FakeData:
|
|
def joint(self, name):
|
|
if name == "ring_block_joint":
|
|
return _FakeJoint(ring_qpos)
|
|
if name == "bar_block_joint":
|
|
return _FakeJoint(bar_qpos)
|
|
raise AssertionError(f"Unexpected joint name: {name}")
|
|
|
|
task_state = {
|
|
"ring_pos": np.array([-0.12, 0.90, 0.47], dtype=np.float64),
|
|
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
|
"bar_pos": np.array([0.12, 0.91, 0.47], dtype=np.float64),
|
|
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
|
}
|
|
|
|
setter(_FakeData(), task_state)
|
|
|
|
np.testing.assert_array_equal(
|
|
ring_qpos,
|
|
np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
|
)
|
|
np.testing.assert_array_equal(
|
|
bar_qpos,
|
|
np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
|
)
|
|
|
|
def test_get_ring_bar_env_state_returns_stable_14d_vector(self):
|
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
|
getter = getattr(air_insert_env, "get_ring_bar_env_state", None)
|
|
self.assertIsNotNone(
|
|
getter,
|
|
"Expected roboimi.envs.double_air_insert_env.get_ring_bar_env_state",
|
|
)
|
|
|
|
ring_qpos = np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
|
|
bar_qpos = np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
|
|
|
|
class _FakeJoint:
|
|
def __init__(self, qpos):
|
|
self.qpos = qpos
|
|
|
|
class _FakeData:
|
|
def joint(self, name):
|
|
if name == "ring_block_joint":
|
|
return _FakeJoint(ring_qpos)
|
|
if name == "bar_block_joint":
|
|
return _FakeJoint(bar_qpos)
|
|
raise AssertionError(f"Unexpected joint name: {name}")
|
|
|
|
env_state = getter(_FakeData())
|
|
|
|
self.assertEqual(env_state.shape, (14,))
|
|
np.testing.assert_array_equal(
|
|
env_state,
|
|
np.array(
|
|
[-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0],
|
|
dtype=np.float64,
|
|
),
|
|
)
|
|
|
|
|
|
class AirInsertRewardAndSuccessTest(unittest.TestCase):
|
|
@staticmethod
|
|
def _make_env_state(
|
|
ring_pos=(0.0, 0.0, 0.50),
|
|
ring_quat=(1.0, 0.0, 0.0, 0.0),
|
|
bar_pos=(0.0, 0.0, 0.50),
|
|
bar_quat=(0.70710678, 0.0, 0.70710678, 0.0),
|
|
):
|
|
return np.array(
|
|
[*ring_pos, *ring_quat, *bar_pos, *bar_quat],
|
|
dtype=np.float64,
|
|
)
|
|
|
|
def test_compute_air_insert_reward_counts_left_contact_stage(self):
|
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
|
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
|
|
self.assertIsNotNone(
|
|
reward_fn,
|
|
"Expected roboimi.envs.double_air_insert_env.compute_air_insert_reward",
|
|
)
|
|
|
|
reward = reward_fn(
|
|
contact_pairs=[
|
|
("ring_block_north", "l_finger_left"),
|
|
("ring_block_north", "table"),
|
|
("bar_block", "table"),
|
|
],
|
|
env_state=self._make_env_state(),
|
|
)
|
|
|
|
self.assertEqual(reward, 1)
|
|
|
|
def test_compute_air_insert_reward_counts_right_contact_stage(self):
|
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
|
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
|
|
|
|
reward = reward_fn(
|
|
contact_pairs=[
|
|
("ring_block_north", "l_finger_left"),
|
|
("bar_block", "l_finger_right"),
|
|
("ring_block_north", "table"),
|
|
("bar_block", "table"),
|
|
],
|
|
env_state=self._make_env_state(),
|
|
)
|
|
|
|
self.assertEqual(reward, 2)
|
|
|
|
def test_compute_air_insert_reward_counts_lift_stages(self):
|
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
|
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
|
|
|
|
reward = reward_fn(
|
|
contact_pairs=[
|
|
("ring_block_north", "l_finger_left"),
|
|
("bar_block", "l_finger_right"),
|
|
],
|
|
env_state=self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)),
|
|
)
|
|
|
|
self.assertEqual(reward, 4)
|
|
|
|
def test_bar_fully_inserted_through_ring_accepts_true_positive(self):
|
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
|
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
|
|
self.assertIsNotNone(
|
|
success_fn,
|
|
"Expected roboimi.envs.double_air_insert_env.bar_fully_inserted_through_ring",
|
|
)
|
|
|
|
self.assertTrue(
|
|
success_fn(
|
|
self._make_env_state(),
|
|
)
|
|
)
|
|
|
|
def test_bar_fully_inserted_through_ring_rejects_centerline_only_false_positive(self):
|
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
|
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
|
|
|
|
self.assertFalse(
|
|
success_fn(
|
|
self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)),
|
|
)
|
|
)
|
|
|
|
def test_bar_fully_inserted_through_ring_rejects_insufficient_depth(self):
|
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
|
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
|
|
|
|
self.assertFalse(
|
|
success_fn(
|
|
self._make_env_state(bar_pos=(0.0, 0.0, 0.56)),
|
|
)
|
|
)
|
|
|
|
def test_compute_air_insert_reward_requires_airborne_success_for_final_point(self):
|
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
|
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
|
|
|
|
reward = reward_fn(
|
|
contact_pairs=[
|
|
("ring_block_north", "l_finger_left"),
|
|
("bar_block", "l_finger_right"),
|
|
("ring_block_north", "table"),
|
|
],
|
|
env_state=self._make_env_state(),
|
|
)
|
|
|
|
self.assertEqual(reward, 3)
|
|
|
|
def test_compute_air_insert_reward_returns_full_score_on_true_airborne_insert(self):
|
|
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
|
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
|
|
|
|
reward = reward_fn(
|
|
contact_pairs=[
|
|
("ring_block_north", "l_finger_left"),
|
|
("bar_block", "l_finger_right"),
|
|
],
|
|
env_state=self._make_env_state(),
|
|
)
|
|
|
|
self.assertEqual(reward, 5)
|
|
|
|
|
|
class AirInsertPolicyAndSmokeTest(unittest.TestCase):
|
|
def test_air_insert_policy_emits_valid_16d_action(self):
|
|
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
|
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
|
self.assertIsNotNone(
|
|
policy_cls,
|
|
"Expected roboimi.demos.diana_air_insert_policy.TestAirInsertPolicy",
|
|
)
|
|
|
|
task_state = act_ex_utils.sample_air_insert_ring_bar_state()
|
|
policy = policy_cls(inject_noise=False)
|
|
action = policy.predict(task_state, 0)
|
|
|
|
self.assertEqual(action.shape, (16,))
|
|
np.testing.assert_array_equal(action[-2:], np.array([100, 100]))
|
|
|
|
def test_scripted_rollout_entrypoint_selects_ring_bar_sampler_and_policy(self):
|
|
rollout_module = importlib.import_module("roboimi.demos.diana_record_sim_episodes")
|
|
sampler_fn = getattr(rollout_module, "sample_task_state", None)
|
|
policy_factory = getattr(rollout_module, "make_policy", None)
|
|
self.assertIsNotNone(
|
|
sampler_fn,
|
|
"Expected roboimi.demos.diana_record_sim_episodes.sample_task_state",
|
|
)
|
|
self.assertIsNotNone(
|
|
policy_factory,
|
|
"Expected roboimi.demos.diana_record_sim_episodes.make_policy",
|
|
)
|
|
|
|
task_state = sampler_fn("sim_air_insert_ring_bar")
|
|
self.assertEqual(
|
|
list(task_state.keys()),
|
|
["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
|
|
)
|
|
|
|
policy = policy_factory("sim_air_insert_ring_bar", inject_noise=False)
|
|
self.assertEqual(policy.__class__.__name__, "TestAirInsertPolicy")
|
|
|
|
def test_real_headless_smoke_instantiates_resets_and_steps_new_task_once(self):
|
|
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
|
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
|
self.assertIsNotNone(policy_cls)
|
|
|
|
task_state = act_ex_utils.sample_air_insert_ring_bar_state()
|
|
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
|
|
policy = policy_cls(inject_noise=False)
|
|
|
|
try:
|
|
env.reset(task_state)
|
|
action = policy.predict(task_state, 0)
|
|
env.step(action)
|
|
self.assertIsNotNone(env.obs)
|
|
self.assertIn("qpos", env.obs)
|
|
self.assertIn("images", env.obs)
|
|
finally:
|
|
env.exit_flag = True
|
|
cam_thread = getattr(env, "cam_thread", None)
|
|
if cam_thread is not None:
|
|
cam_thread.join(timeout=1.0)
|
|
viewer = getattr(env, "viewer", None)
|
|
if viewer is not None:
|
|
viewer.close()
|
|
|
|
def test_scripted_policy_avoids_cross_arm_contact_on_canonical_insert_case(self):
|
|
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
|
|
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
|
|
self.assertIsNotNone(policy_cls)
|
|
|
|
task_state = {
|
|
"ring_pos": np.array([-0.06658807, 0.93985176, 0.47], dtype=np.float32),
|
|
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
|
"bar_pos": np.array([0.12421221, 0.77605027, 0.47], dtype=np.float32),
|
|
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
|
|
}
|
|
|
|
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
|
|
policy = policy_cls(inject_noise=False)
|
|
|
|
def is_cross_arm_pair(a, b):
|
|
return ("_left" in a and "_right" in b) or ("_right" in a and "_left" in b)
|
|
|
|
try:
|
|
env.reset(task_state)
|
|
for step in range(460):
|
|
action = policy.predict(task_state, step)
|
|
env.step(action)
|
|
pairs = []
|
|
for i in range(env.mj_data.ncon):
|
|
geom1 = env.getID2Name("geom", env.mj_data.contact[i].geom1)
|
|
geom2 = env.getID2Name("geom", env.mj_data.contact[i].geom2)
|
|
if geom1 and geom2 and is_cross_arm_pair(geom1, geom2):
|
|
pairs.append((geom1, geom2))
|
|
self.assertFalse(
|
|
pairs,
|
|
f"cross-arm contact detected at step {step}: {pairs[:5]}",
|
|
)
|
|
finally:
|
|
env.exit_flag = True
|
|
cam_thread = getattr(env, "cam_thread", None)
|
|
if cam_thread is not None:
|
|
cam_thread.join(timeout=1.0)
|
|
viewer = getattr(env, "viewer", None)
|
|
if viewer is not None:
|
|
viewer.close()
|
|
|
|
|
|
if __name__ == "__main__":
    # Run every TestCase defined in this module; "__main__" is unittest's default target.
    unittest.main(module="__main__")
|