feat(scene): add ring and bar insertion scene assets

This commit is contained in:
Logic
2026-04-23 17:32:43 +08:00
parent 06ac6c6d18
commit f1ede7690f
4 changed files with 175 additions and 6 deletions

View File

@@ -97,5 +97,82 @@ class AirInsertTaskRegistrationTest(unittest.TestCase):
)
class AirInsertResetAndStateHelpersTest(unittest.TestCase):
def test_set_ring_bar_task_state_writes_free_joint_qpos(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
setter = getattr(air_insert_env, "set_ring_bar_task_state", None)
self.assertIsNotNone(
setter,
"Expected roboimi.envs.double_air_insert_env.set_ring_bar_task_state",
)
ring_qpos = np.zeros(7, dtype=np.float64)
bar_qpos = np.zeros(7, dtype=np.float64)
class _FakeJoint:
def __init__(self, qpos):
self.qpos = qpos
class _FakeData:
def joint(self, name):
if name == "ring_block_joint":
return _FakeJoint(ring_qpos)
if name == "bar_block_joint":
return _FakeJoint(bar_qpos)
raise AssertionError(f"Unexpected joint name: {name}")
task_state = {
"ring_pos": np.array([-0.12, 0.90, 0.47], dtype=np.float64),
"ring_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
"bar_pos": np.array([0.12, 0.91, 0.47], dtype=np.float64),
"bar_quat": np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float64),
}
setter(_FakeData(), task_state)
np.testing.assert_array_equal(
ring_qpos,
np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
)
np.testing.assert_array_equal(
bar_qpos,
np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64),
)
def test_get_ring_bar_env_state_returns_stable_14d_vector(self):
air_insert_env = importlib.import_module("roboimi.envs.double_air_insert_env")
getter = getattr(air_insert_env, "get_ring_bar_env_state", None)
self.assertIsNotNone(
getter,
"Expected roboimi.envs.double_air_insert_env.get_ring_bar_env_state",
)
ring_qpos = np.array([-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
bar_qpos = np.array([0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0], dtype=np.float64)
class _FakeJoint:
def __init__(self, qpos):
self.qpos = qpos
class _FakeData:
def joint(self, name):
if name == "ring_block_joint":
return _FakeJoint(ring_qpos)
if name == "bar_block_joint":
return _FakeJoint(bar_qpos)
raise AssertionError(f"Unexpected joint name: {name}")
env_state = getter(_FakeData())
self.assertEqual(env_state.shape, (14,))
np.testing.assert_array_equal(
env_state,
np.array(
[-0.12, 0.90, 0.47, 1.0, 0.0, 0.0, 0.0, 0.12, 0.91, 0.47, 1.0, 0.0, 0.0, 0.0],
dtype=np.float64,
),
)
if __name__ == "__main__":
unittest.main()