feat(vla): align transformer training stack and rollout validation
This commit is contained in:
63
tests/test_robot_asset_paths.py
Normal file
63
tests/test_robot_asset_paths.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
from roboimi.assets.robots.diana_med import BiDianaMed
|
||||
|
||||
|
||||
class _FakeKDL:
|
||||
init_calls = []
|
||||
reset_calls = []
|
||||
|
||||
def __init__(self, urdf_path):
|
||||
self.__class__.init_calls.append(urdf_path)
|
||||
|
||||
def resetChain(self, base, end):
|
||||
self.__class__.reset_calls.append((base, end))
|
||||
|
||||
|
||||
class RobotAssetPathResolutionTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
_FakeKDL.init_calls = []
|
||||
_FakeKDL.reset_calls = []
|
||||
|
||||
def test_bidianamed_resolves_robot_asset_paths_independent_of_cwd(self):
|
||||
repo_root = Path(__file__).resolve().parents[1]
|
||||
expected_xml = repo_root / 'roboimi/assets/models/manipulators/DianaMed/bi_diana_transfer_ee.xml'
|
||||
expected_urdf = repo_root / 'roboimi/assets/models/manipulators/DianaMed/DualDianaMed.urdf'
|
||||
xml_calls = []
|
||||
|
||||
def fake_from_xml_path(*, filename, assets=None):
|
||||
xml_calls.append((filename, assets))
|
||||
return object()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
previous_cwd = os.getcwd()
|
||||
try:
|
||||
os.chdir(tempdir)
|
||||
with mock.patch(
|
||||
'roboimi.assets.robots.arm_base.mujoco.MjModel.from_xml_path',
|
||||
side_effect=fake_from_xml_path,
|
||||
), mock.patch(
|
||||
'roboimi.assets.robots.arm_base.mujoco.MjData',
|
||||
return_value=object(),
|
||||
), mock.patch(
|
||||
'roboimi.assets.robots.arm_base.KDL_utils',
|
||||
_FakeKDL,
|
||||
):
|
||||
BiDianaMed()
|
||||
finally:
|
||||
os.chdir(previous_cwd)
|
||||
|
||||
self.assertEqual(len(xml_calls), 1)
|
||||
self.assertEqual(Path(xml_calls[0][0]), expected_xml)
|
||||
self.assertTrue(Path(xml_calls[0][0]).is_absolute())
|
||||
self.assertGreaterEqual(len(_FakeKDL.init_calls), 2)
|
||||
self.assertEqual({Path(path) for path in _FakeKDL.init_calls}, {expected_urdf})
|
||||
self.assertTrue(all(Path(path).is_absolute() for path in _FakeKDL.init_calls))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user