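# roboimi/tests/test_air_insert_env.py
#
# Tests for the dual-arm "sim_air_insert_ring_bar" task: task registration and
# state sampling, env dispatch, ring/bar state helpers, staged reward and
# insertion success, and scripted-policy / headless simulation smoke checks.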

import importlib
import unittest
from unittest import mock
import numpy as np
from roboimi.envs.double_pos_ctrl_env import make_sim_env
from roboimi.utils import act_ex_utils
from roboimi.utils.constants import SIM_TASK_CONFIGS
class AirInsertTaskRegistrationTest(unittest.TestCase):
def test_sim_task_configs_registers_air_insert_ring_bar(self):
self.assertIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS)
def test_sample_air_insert_ring_bar_state_returns_explicit_named_mapping(self):
sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None)
self.assertIsNotNone(
sampler,
"Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()",
)
task_state = sampler()
self.assertEqual(
list(task_state.keys()),
["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
)
self.assertEqual(task_state["ring_pos"].shape, (3,))
self.assertEqual(task_state["ring_quat"].shape, (4,))
self.assertEqual(task_state["bar_pos"].shape, (3,))
self.assertEqual(task_state["bar_quat"].shape, (4,))
def test_sample_air_insert_ring_bar_state_uses_fixed_quats_and_left_right_planar_ranges(self):
sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None)
self.assertIsNotNone(
sampler,
"Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()",
)
task_state = sampler()
np.testing.assert_array_equal(task_state["ring_quat"], np.array([1.0, 0.0, 0.0, 0.0]))
np.testing.assert_array_equal(task_state["bar_quat"], np.array([1.0, 0.0, 0.0, 0.0]))
self.assertGreaterEqual(task_state["ring_pos"][0], -0.20)
self.assertLessEqual(task_state["ring_pos"][0], -0.05)
self.assertGreaterEqual(task_state["ring_pos"][1], 0.70)
self.assertLessEqual(task_state["ring_pos"][1], 1.00)
self.assertAlmostEqual(float(task_state["ring_pos"][2]), 0.47)
self.assertGreaterEqual(task_state["bar_pos"][0], 0.05)
self.assertLessEqual(task_state["bar_pos"][0], 0.20)
self.assertGreaterEqual(task_state["bar_pos"][1], 0.70)
self.assertLessEqual(task_state["bar_pos"][1], 1.00)
self.assertAlmostEqual(float(task_state["bar_pos"][2]), 0.47)
def test_make_sim_env_dispatches_air_insert_ring_bar_headless(self):
try:
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
except Exception as exc:
self.fail(f"Expected roboimi.envs.double_air_insert_env to be importable: {exc}")
air_insert_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None)
self.assertIsNotNone(
air_insert_cls,
"Expected roboimi.envs.double_air_insert_env.DualDianaMed_Air_Insert",
)
diana_med = importlib.import_module("roboimi.assets.robots.diana_med")
ring_bar_robot_cls = getattr(diana_med, "BiDianaMedRingBar", None)
self.assertIsNotNone(
ring_bar_robot_cls,
"Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar",
)
fake_env = object()
with mock.patch.object(
diana_med,
"BiDianaMedRingBar",
return_value="robot",
), mock.patch.object(
air_insert_env,
"DualDianaMed_Air_Insert",
return_value=fake_env,
) as env_cls:
try:
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
except Exception as exc:
self.fail(f"make_sim_env should dispatch sim_air_insert_ring_bar without error: {exc}")
self.assertIs(env, fake_env)
env_cls.assert_called_once_with(
robot="robot",
is_render=False,
control_freq=30,
is_interpolate=True,
cam_view="angle",
)
class AirInsertResetAndStateHelpersTest(unittest.TestCase):
def test_set_ring_bar_task_state_writes_free_joint_qpos(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
setter = getattr(air_insert_env, "set_ring_bar_task_state", None)
self.assertIsNotNone(
setter,
"Expected roboimi.envs.double_air_insert_env.set_ring_bar_task_state",
)
ring_qpos = np.zeros(7, dtype=np.float64)
bar_qpos = np.zeros(7, dtype=np.float64)
class _FakeJoint:
def __init__(self, qpos):
self.qpos = qpos
class _FakeData:
def joint(self, name):
if name == "ring_block_joint":
return _FakeJoint(ring_qpos)
if name == "bar_block_joint":
return _FakeJoint(bar_qpos)
raise AssertionError(f"Unexpected joint name: {name}")
task_state = {
"ring_pos": np.array([-0.12, 0.90, 0.47], dtype=np.float64),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
"bar_pos": np.array([0.12, 0.91, 0.47], dtype=np.float64),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
}
setter(_FakeData(), task_state)
np.testing.assert_array_equal(
ring_qpos,
np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
)
np.testing.assert_array_equal(
bar_qpos,
np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
)
def test_get_ring_bar_env_state_returns_stable_14d_vector(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
getter = getattr(air_insert_env, "get_ring_bar_env_state", None)
self.assertIsNotNone(
getter,
"Expected roboimi.envs.double_air_insert_env.get_ring_bar_env_state",
)
ring_qpos = np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
bar_qpos = np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
class _FakeJoint:
def __init__(self, qpos):
self.qpos = qpos
class _FakeData:
def joint(self, name):
if name == "ring_block_joint":
return _FakeJoint(ring_qpos)
if name == "bar_block_joint":
return _FakeJoint(bar_qpos)
raise AssertionError(f"Unexpected joint name: {name}")
env_state = getter(_FakeData())
self.assertEqual(env_state.shape, (14,))
np.testing.assert_array_equal(
env_state,
np.array(
[-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0],
dtype=np.float64,
),
)
class AirInsertRewardAndSuccessTest(unittest.TestCase):
@staticmethod
def _make_env_state(
ring_pos=(0.0, 0.0, 0.50),
ring_quat=(1.0, 0.0, 0.0, 0.0),
bar_pos=(0.0, 0.0, 0.50),
bar_quat=(0.70710678, 0.0, 0.70710678, 0.0),
):
return np.array(
[*ring_pos, *ring_quat, *bar_pos, *bar_quat],
dtype=np.float64,
)
def test_compute_air_insert_reward_counts_left_contact_stage(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
self.assertIsNotNone(
reward_fn,
"Expected roboimi.envs.double_air_insert_env.compute_air_insert_reward",
)
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("ring_block_north", "table"),
("bar_block", "table"),
],
env_state=self._make_env_state(),
)
self.assertEqual(reward, 1)
def test_compute_air_insert_reward_counts_right_contact_stage(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("bar_block", "l_finger_right"),
("ring_block_north", "table"),
("bar_block", "table"),
],
env_state=self._make_env_state(),
)
self.assertEqual(reward, 2)
def test_compute_air_insert_reward_counts_lift_stages(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("bar_block", "l_finger_right"),
],
env_state=self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)),
)
self.assertEqual(reward, 4)
def test_bar_fully_inserted_through_ring_accepts_true_positive(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
self.assertIsNotNone(
success_fn,
"Expected roboimi.envs.double_air_insert_env.bar_fully_inserted_through_ring",
)
self.assertTrue(
success_fn(
self._make_env_state(),
)
)
def test_bar_fully_inserted_through_ring_rejects_centerline_only_false_positive(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
self.assertFalse(
success_fn(
self._make_env_state(bar_pos=(0.0085, 0.0, 0.50)),
)
)
def test_bar_fully_inserted_through_ring_rejects_insufficient_depth(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
success_fn = getattr(air_insert_env, "bar_fully_inserted_through_ring", None)
self.assertFalse(
success_fn(
self._make_env_state(bar_pos=(0.0, 0.0, 0.56)),
)
)
def test_compute_air_insert_reward_requires_airborne_success_for_final_point(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("bar_block", "l_finger_right"),
("ring_block_north", "table"),
],
env_state=self._make_env_state(),
)
self.assertEqual(reward, 3)
def test_compute_air_insert_reward_returns_full_score_on_true_airborne_insert(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
reward_fn = getattr(air_insert_env, "compute_air_insert_reward", None)
reward = reward_fn(
contact_pairs=[
("ring_block_north", "l_finger_left"),
("bar_block", "l_finger_right"),
],
env_state=self._make_env_state(),
)
self.assertEqual(reward, 5)
class AirInsertPolicyAndSmokeTest(unittest.TestCase):
def test_air_insert_policy_emits_valid_16d_action(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(
policy_cls,
"Expected roboimi.demos.diana_air_insert_policy.TestAirInsertPolicy",
)
task_state = act_ex_utils.sample_air_insert_ring_bar_state()
policy = policy_cls(inject_noise=False)
action = policy.predict(task_state, 0)
self.assertEqual(action.shape, (16,))
np.testing.assert_array_equal(action[-2:], np.array([100, 100]))
def test_scripted_rollout_entrypoint_selects_ring_bar_sampler_and_policy(self):
rollout_module = importlib.import_module("roboimi.demos.diana_record_sim_episodes")
sampler_fn = getattr(rollout_module, "sample_task_state", None)
policy_factory = getattr(rollout_module, "make_policy", None)
self.assertIsNotNone(
sampler_fn,
"Expected roboimi.demos.diana_record_sim_episodes.sample_task_state",
)
self.assertIsNotNone(
policy_factory,
"Expected roboimi.demos.diana_record_sim_episodes.make_policy",
)
task_state = sampler_fn("sim_air_insert_ring_bar")
self.assertEqual(
list(task_state.keys()),
["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
)
policy = policy_factory("sim_air_insert_ring_bar", inject_noise=False)
self.assertEqual(policy.__class__.__name__, "TestAirInsertPolicy")
def test_real_headless_smoke_instantiates_resets_and_steps_new_task_once(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = act_ex_utils.sample_air_insert_ring_bar_state()
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
policy = policy_cls(inject_noise=False)
try:
env.reset(task_state)
action = policy.predict(task_state, 0)
env.step(action)
self.assertIsNotNone(env.obs)
self.assertIn("qpos", env.obs)
self.assertIn("images", env.obs)
finally:
env.exit_flag = True
cam_thread = getattr(env, "cam_thread", None)
if cam_thread is not None:
cam_thread.join(timeout=1.0)
viewer = getattr(env, "viewer", None)
if viewer is not None:
viewer.close()
def test_scripted_policy_avoids_cross_arm_contact_on_canonical_insert_case(self):
policy_module = importlib.import_module("roboimi.demos.diana_air_insert_policy")
policy_cls = getattr(policy_module, "TestAirInsertPolicy", None)
self.assertIsNotNone(policy_cls)
task_state = {
"ring_pos": np.array([-0.06658807, 0.93985176, 0.47], dtype=np.float32),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
"bar_pos": np.array([0.12421221, 0.77605027, 0.47], dtype=np.float32),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32),
}
env = make_sim_env("sim_air_insert_ring_bar", headless=True)
policy = policy_cls(inject_noise=False)
def is_cross_arm_pair(a, b):
return ("_left" in a and "_right" in b) or ("_right" in a and "_left" in b)
try:
env.reset(task_state)
for step in range(460):
action = policy.predict(task_state, step)
env.step(action)
pairs = []
for i in range(env.mj_data.ncon):
geom1 = env.getID2Name("geom", env.mj_data.contact[i].geom1)
geom2 = env.getID2Name("geom", env.mj_data.contact[i].geom2)
if geom1 and geom2 and is_cross_arm_pair(geom1, geom2):
pairs.append((geom1, geom2))
self.assertFalse(
pairs,
f"cross-arm contact detected at step {step}: {pairs[:5]}",
)
finally:
env.exit_flag = True
cam_thread = getattr(env, "cam_thread", None)
if cam_thread is not None:
cam_thread.join(timeout=1.0)
viewer = getattr(env, "viewer", None)
if viewer is not None:
viewer.close()
if __name__ == "__main__":
    # Executed directly (not via a test runner): discover and run every
    # TestCase defined in this module. "__main__" is unittest.main's default.
    unittest.main(module="__main__")