feat(vla): align transformer training stack and rollout validation
This commit is contained in:
88
tests/test_calculate_stats_cli.py
Normal file
88
tests/test_calculate_stats_cli.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import pickle
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
from roboimi.vla.scripts import calculate_stats
|
||||
|
||||
|
||||
class CalculateStatsCliTest(unittest.TestCase):
|
||||
def test_default_dataset_dir_is_absolute_and_package_relative(self):
|
||||
expected = (
|
||||
Path(calculate_stats.__file__).resolve().parents[2]
|
||||
/ "demos"
|
||||
/ "dataset"
|
||||
/ "sim_transfer"
|
||||
)
|
||||
|
||||
self.assertEqual(Path(calculate_stats.DEFAULT_DATASET_DIR), expected)
|
||||
self.assertTrue(Path(calculate_stats.DEFAULT_DATASET_DIR).is_absolute())
|
||||
|
||||
def test_main_writes_dataset_stats_pkl_to_dataset_dir(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dataset_dir = Path(tmpdir)
|
||||
episode_path = dataset_dir / "episode_0.hdf5"
|
||||
|
||||
with h5py.File(episode_path, "w") as root:
|
||||
root.create_dataset(
|
||||
"action",
|
||||
data=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
|
||||
)
|
||||
observations = root.create_group("observations")
|
||||
observations.create_dataset(
|
||||
"qpos",
|
||||
data=np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float32),
|
||||
)
|
||||
|
||||
calculate_stats.main(["--dataset_dir", str(dataset_dir)])
|
||||
|
||||
stats_path = dataset_dir / "dataset_stats.pkl"
|
||||
self.assertTrue(stats_path.exists())
|
||||
|
||||
with stats_path.open("rb") as f:
|
||||
stats = pickle.load(f)
|
||||
|
||||
self.assertEqual(
|
||||
set(stats),
|
||||
{
|
||||
"action_mean",
|
||||
"action_std",
|
||||
"action_min",
|
||||
"action_max",
|
||||
"qpos_mean",
|
||||
"qpos_std",
|
||||
"qpos_min",
|
||||
"qpos_max",
|
||||
},
|
||||
)
|
||||
np.testing.assert_allclose(stats["action_mean"], np.array([2.0, 3.0]))
|
||||
np.testing.assert_allclose(stats["qpos_mean"], np.array([6.0, 7.0]))
|
||||
|
||||
def test_main_raises_clear_error_for_empty_dataset_dir(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dataset_dir = Path(tmpdir)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"No episode_\*\.hdf5 files found"
|
||||
) as ctx:
|
||||
calculate_stats.main(["--dataset_dir", str(dataset_dir)])
|
||||
|
||||
self.assertIn(str(dataset_dir), str(ctx.exception))
|
||||
|
||||
def test_main_raises_clear_error_for_missing_dataset_dir(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dataset_dir = Path(tmpdir) / "missing"
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"No episode_\*\.hdf5 files found"
|
||||
) as ctx:
|
||||
calculate_stats.main(["--dataset_dir", str(dataset_dir)])
|
||||
|
||||
self.assertIn(str(dataset_dir), str(ctx.exception))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user