feat(scene): add ring and bar insertion scene assets
This commit is contained in:
@@ -97,5 +97,82 @@ class AirInsertTaskRegistrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class AirInsertResetAndStateHelpersTest(unittest.TestCase):
|
||||
def test_set_ring_bar_task_state_writes_free_joint_qpos(self):
|
||||
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||
setter = getattr(air_insert_env, "set_ring_bar_task_state", None)
|
||||
self.assertIsNotNone(
|
||||
setter,
|
||||
"Expected roboimi.envs.double_air_insert_env.set_ring_bar_task_state",
|
||||
)
|
||||
|
||||
ring_qpos = np.zeros(7, dtype=np.float64)
|
||||
bar_qpos = np.zeros(7, dtype=np.float64)
|
||||
|
||||
class _FakeJoint:
|
||||
def __init__(self, qpos):
|
||||
self.qpos = qpos
|
||||
|
||||
class _FakeData:
|
||||
def joint(self, name):
|
||||
if name == "ring_block_joint":
|
||||
return _FakeJoint(ring_qpos)
|
||||
if name == "bar_block_joint":
|
||||
return _FakeJoint(bar_qpos)
|
||||
raise AssertionError(f"Unexpected joint name: {name}")
|
||||
|
||||
task_state = {
|
||||
"ring_pos": np.array([-0.12, 0.90, 0.47], dtype=np.float64),
|
||||
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
"bar_pos": np.array([0.12, 0.91, 0.47], dtype=np.float64),
|
||||
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
}
|
||||
|
||||
setter(_FakeData(), task_state)
|
||||
|
||||
np.testing.assert_array_equal(
|
||||
ring_qpos,
|
||||
np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
)
|
||||
np.testing.assert_array_equal(
|
||||
bar_qpos,
|
||||
np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
|
||||
)
|
||||
|
||||
def test_get_ring_bar_env_state_returns_stable_14d_vector(self):
|
||||
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
|
||||
getter = getattr(air_insert_env, "get_ring_bar_env_state", None)
|
||||
self.assertIsNotNone(
|
||||
getter,
|
||||
"Expected roboimi.envs.double_air_insert_env.get_ring_bar_env_state",
|
||||
)
|
||||
|
||||
ring_qpos = np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
|
||||
bar_qpos = np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
|
||||
|
||||
class _FakeJoint:
|
||||
def __init__(self, qpos):
|
||||
self.qpos = qpos
|
||||
|
||||
class _FakeData:
|
||||
def joint(self, name):
|
||||
if name == "ring_block_joint":
|
||||
return _FakeJoint(ring_qpos)
|
||||
if name == "bar_block_joint":
|
||||
return _FakeJoint(bar_qpos)
|
||||
raise AssertionError(f"Unexpected joint name: {name}")
|
||||
|
||||
env_state = getter(_FakeData())
|
||||
|
||||
self.assertEqual(env_state.shape, (14,))
|
||||
np.testing.assert_array_equal(
|
||||
env_state,
|
||||
np.array(
|
||||
[-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0],
|
||||
dtype=np.float64,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user