89 lines
2.9 KiB
Python
89 lines
2.9 KiB
Python
import pickle
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
import h5py
|
|
import numpy as np
|
|
|
|
from roboimi.vla.scripts import calculate_stats
|
|
|
|
|
|
class CalculateStatsCliTest(unittest.TestCase):
    """CLI-level tests for ``roboimi.vla.scripts.calculate_stats``.

    Covers the module's ``DEFAULT_DATASET_DIR``, the ``dataset_stats.pkl``
    file written by ``main``, and the ``ValueError`` raised when no
    ``episode_*.hdf5`` files can be found.
    """

    def _assert_no_episodes_error(self, dataset_dir):
        """Run ``main`` on *dataset_dir* and check the "no episodes" error.

        The call must raise ``ValueError`` whose message matches
        ``No episode_*.hdf5 files found`` and contains ``str(dataset_dir)``.
        """
        with self.assertRaisesRegex(
            ValueError, r"No episode_\*\.hdf5 files found"
        ) as raised:
            calculate_stats.main(["--dataset_dir", str(dataset_dir)])
        self.assertIn(str(dataset_dir), str(raised.exception))

    def test_default_dataset_dir_is_absolute_and_package_relative(self):
        """DEFAULT_DATASET_DIR equals <parents[2]>/demos/dataset/sim_transfer and is absolute."""
        module_file = Path(calculate_stats.__file__).resolve()
        # For a module at roboimi/vla/scripts/calculate_stats.py, parents[2]
        # is the roboimi/ directory.
        package_root = module_file.parents[2]
        expected = package_root.joinpath("demos", "dataset", "sim_transfer")
        default_dir = Path(calculate_stats.DEFAULT_DATASET_DIR)

        self.assertEqual(default_dir, expected)
        self.assertTrue(default_dir.is_absolute())

    def test_main_writes_dataset_stats_pkl_to_dataset_dir(self):
        """main() writes dataset_stats.pkl with action/qpos stats into the dataset dir."""
        action = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
        qpos = np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float32)

        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)

            # A single tiny episode: two rows of two values each.
            with h5py.File(dataset_dir / "episode_0.hdf5", "w") as root:
                root.create_dataset("action", data=action)
                root.create_group("observations").create_dataset("qpos", data=qpos)

            calculate_stats.main(["--dataset_dir", str(dataset_dir)])

            stats_path = dataset_dir / "dataset_stats.pkl"
            self.assertTrue(stats_path.exists())
            with stats_path.open("rb") as fh:
                stats = pickle.load(fh)

        expected_keys = {
            f"{prefix}_{stat}"
            for prefix in ("action", "qpos")
            for stat in ("mean", "std", "min", "max")
        }
        self.assertEqual(set(stats), expected_keys)
        # Column-wise means of the arrays written above.
        np.testing.assert_allclose(stats["action_mean"], np.array([2.0, 3.0]))
        np.testing.assert_allclose(stats["qpos_mean"], np.array([6.0, 7.0]))

    def test_main_raises_clear_error_for_empty_dataset_dir(self):
        """An existing directory with no episode files yields a descriptive ValueError."""
        with tempfile.TemporaryDirectory() as tmpdir:
            self._assert_no_episodes_error(Path(tmpdir))

    def test_main_raises_clear_error_for_missing_dataset_dir(self):
        """A non-existent directory yields the same descriptive ValueError."""
        with tempfile.TemporaryDirectory() as tmpdir:
            self._assert_no_episodes_error(Path(tmpdir) / "missing")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the tests in this module when the file is executed directly.
    unittest.main(module="__main__")
|