# Source listing metadata (original file): 179 lines, 6.8 KiB, Python
import importlib
|
|
import unittest
|
|
from unittest import mock
|
|
|
|
import numpy as np
|
|
|
|
from roboimi.envs.double_pos_ctrl_env import make_sim_env
|
|
from roboimi.utils import act_ex_utils
|
|
from roboimi.utils.constants import SIM_TASK_CONFIGS
|
|
|
|
|
|
class AirInsertTaskRegistrationTest(unittest.TestCase):
    """Registration and dispatch checks for the ``sim_air_insert_ring_bar`` task.

    The tests verify that the task key exists in ``SIM_TASK_CONFIGS``, that
    ``act_ex_utils.sample_air_insert_ring_bar_state`` returns the expected
    named poses, and that ``make_sim_env`` constructs the air-insert
    environment with the expected keyword arguments.
    """

    # The range checks below suggest the sampler draws planar positions at
    # random (not visible from this file), so a single draw could satisfy the
    # bounds by luck. Repeating the check makes an out-of-range sampler fail
    # reliably while keeping the test fast.
    _RANGE_CHECK_DRAWS = 32

    def _require_sampler(self):
        """Return the ring/bar state sampler, failing the test if it is missing."""
        sampler = getattr(act_ex_utils, "sample_air_insert_ring_bar_state", None)
        self.assertIsNotNone(
            sampler,
            "Expected roboimi.utils.act_ex_utils.sample_air_insert_ring_bar_state()",
        )
        return sampler

    def test_sim_task_configs_registers_air_insert_ring_bar(self):
        """The task key must be present in ``SIM_TASK_CONFIGS``."""
        self.assertIn("sim_air_insert_ring_bar", SIM_TASK_CONFIGS)

    def test_sample_air_insert_ring_bar_state_returns_explicit_named_mapping(self):
        """The sampler returns exactly four named entries, in a fixed order."""
        task_state = self._require_sampler()()

        # Key order is part of the contract being checked, hence the list compare.
        self.assertEqual(
            list(task_state.keys()),
            ["ring_pos", "ring_quat", "bar_pos", "bar_quat"],
        )
        self.assertEqual(task_state["ring_pos"].shape, (3,))
        self.assertEqual(task_state["ring_quat"].shape, (4,))
        self.assertEqual(task_state["bar_pos"].shape, (3,))
        self.assertEqual(task_state["bar_quat"].shape, (4,))

    def test_sample_air_insert_ring_bar_state_uses_fixed_quats_and_left_right_planar_ranges(self):
        """Quaternions are fixed and positions stay inside the left/right ranges.

        Expected per draw: both quaternions equal ``[1, 0, 0, 0]``; ring x in
        ``[-0.20, -0.05]``, bar x in ``[0.05, 0.20]``; y in ``[0.70, 1.00]``
        and z approximately ``0.47`` for both objects.
        """
        sampler = self._require_sampler()
        identity_quat = np.array([1.0, 0.0, 0.0, 0.0])
        # (position key, x lower bound, x upper bound): ring on the negative-x
        # side, bar on the positive-x side.
        x_ranges = (("ring_pos", -0.20, -0.05), ("bar_pos", 0.05, 0.20))

        for draw in range(self._RANGE_CHECK_DRAWS):
            task_state = sampler()
            context = f"draw {draw}"

            np.testing.assert_array_equal(
                task_state["ring_quat"], identity_quat, err_msg=context
            )
            np.testing.assert_array_equal(
                task_state["bar_quat"], identity_quat, err_msg=context
            )

            for key, x_low, x_high in x_ranges:
                position = task_state[key]
                label = f"{key} ({context})"
                self.assertGreaterEqual(position[0], x_low, label)
                self.assertLessEqual(position[0], x_high, label)
                self.assertGreaterEqual(position[1], 0.70, label)
                self.assertLessEqual(position[1], 1.00, label)
                self.assertAlmostEqual(float(position[2]), 0.47, msg=label)

    def test_make_sim_env_dispatches_air_insert_ring_bar_headless(self):
        """``make_sim_env`` builds ``DualDianaMed_Air_Insert`` with the expected kwargs.

        Both the robot class and the environment class are patched, so the
        test only checks how ``make_sim_env`` wires them together.
        """
        # Convert an import error into a test failure with a targeted message.
        try:
            air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
        except Exception as exc:
            self.fail(f"Expected roboimi.envs.double_air_insert_env to be importable: {exc}")

        air_insert_cls = getattr(air_insert_env, "DualDianaMed_Air_Insert", None)
        self.assertIsNotNone(
            air_insert_cls,
            "Expected roboimi.envs.double_air_insert_env.DualDianaMed_Air_Insert",
        )

        diana_med = importlib.import_module("roboimi.assets.robots.diana_med")
        ring_bar_robot_cls = getattr(diana_med, "BiDianaMedRingBar", None)
        self.assertIsNotNone(
            ring_bar_robot_cls,
            "Expected roboimi.assets.robots.diana_med.BiDianaMedRingBar",
        )

        # Sentinel used to prove make_sim_env returns what the env class built.
        fake_env = object()
        with mock.patch.object(
            diana_med,
            "BiDianaMedRingBar",
            return_value="robot",
        ), mock.patch.object(
            air_insert_env,
            "DualDianaMed_Air_Insert",
            return_value=fake_env,
        ) as env_cls:
            try:
                env = make_sim_env("sim_air_insert_ring_bar", headless=True)
            except Exception as exc:
                self.fail(f"make_sim_env should dispatch sim_air_insert_ring_bar without error: {exc}")

        self.assertIs(env, fake_env)
        # headless=True is expected to reach the env class as is_render=False.
        env_cls.assert_called_once_with(
            robot="robot",
            is_render=False,
            control_freq=30,
            is_interpolate=True,
            cam_view="angle",
        )
|
|
|
|
|
|
class AirInsertResetAndStateHelpersTest(unittest.TestCase):
    """Checks the ring/bar state helpers in ``roboimi.envs.double_air_insert_env``.

    Both helpers are exercised against a small stand-in data object that
    exposes ``joint(name).qpos`` for the two block joints, so no simulator
    is required.
    """

    _ENV_MODULE = "roboimi.envs.double_air_insert_env"

    @staticmethod
    def _fake_data_for(ring_qpos, bar_qpos):
        """Build a stand-in data object backed by the two given qpos arrays.

        ``joint(name)`` hands back a fresh wrapper around the *same* array on
        every call, so in-place writes made by the code under test remain
        visible through ``ring_qpos`` / ``bar_qpos``.
        """
        known_joints = (
            ("ring_block_joint", ring_qpos),
            ("bar_block_joint", bar_qpos),
        )

        class _FakeJoint:
            def __init__(self, qpos):
                self.qpos = qpos

        class _FakeData:
            def joint(self, name):
                for joint_name, qpos in known_joints:
                    if name == joint_name:
                        return _FakeJoint(qpos)
                raise AssertionError(f"Unexpected joint name: {name}")

        return _FakeData()

    def _require_helper(self, attr_name):
        """Import the env module and return ``attr_name`` from it, or fail the test."""
        module = importlib.import_module(self._ENV_MODULE)
        helper = getattr(module, attr_name, None)
        self.assertIsNotNone(helper, f"Expected {self._ENV_MODULE}.{attr_name}")
        return helper

    def test_set_ring_bar_task_state_writes_free_joint_qpos(self):
        """The setter writes position then quaternion into each joint's 7-element qpos."""
        setter = self._require_helper("set_ring_bar_task_state")

        ring_qpos = np.zeros(7, dtype=np.float64)
        bar_qpos = np.zeros(7, dtype=np.float64)
        task_state = {
            "ring_pos": np.array([-0.12, 0.90, 0.47], dtype=np.float64),
            "ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
            "bar_pos": np.array([0.12, 0.91, 0.47], dtype=np.float64),
            "bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
        }
        expected_ring = np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
        expected_bar = np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)

        setter(self._fake_data_for(ring_qpos, bar_qpos), task_state)

        np.testing.assert_array_equal(ring_qpos, expected_ring)
        np.testing.assert_array_equal(bar_qpos, expected_bar)

    def test_get_ring_bar_env_state_returns_stable_14d_vector(self):
        """The getter returns ring qpos followed by bar qpos as one 14-element vector."""
        getter = self._require_helper("get_ring_bar_env_state")

        ring_values = [-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0]
        bar_values = [0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0]
        ring_qpos = np.array(ring_values, dtype=np.float64)
        bar_qpos = np.array(bar_values, dtype=np.float64)
        # Built from the plain lists so it cannot alias the arrays given to the getter.
        expected = np.array(ring_values + bar_values, dtype=np.float64)

        env_state = getter(self._fake_data_for(ring_qpos, bar_qpos))

        self.assertEqual(env_state.shape, (14,))
        np.testing.assert_array_equal(env_state, expected)
|
|
|
|
|
|
# Allow running this module directly as a script; unittest.main() collects
# and runs the TestCase classes defined above.
if __name__ == "__main__":
    unittest.main()
|