"""Tests for the ``calculate_stats`` dataset-statistics CLI."""

import pickle
import tempfile
import unittest
from pathlib import Path

import h5py
import numpy as np

from roboimi.vla.scripts import calculate_stats


def _write_episode(path, action, qpos):
    """Write a minimal episode HDF5 file at ``path``.

    The file holds a top-level ``action`` dataset and an ``observations/qpos``
    dataset, both stored as float32.

    Args:
        path: Destination file path (e.g. ``<dir>/episode_0.hdf5``).
        action: Array-like of per-step action rows.
        qpos: Array-like of per-step qpos rows.
    """
    with h5py.File(path, "w") as root:
        root.create_dataset("action", data=np.asarray(action, dtype=np.float32))
        observations = root.create_group("observations")
        observations.create_dataset(
            "qpos", data=np.asarray(qpos, dtype=np.float32)
        )


class CalculateStatsCliTest(unittest.TestCase):
    """End-to-end tests for ``calculate_stats.main`` and its defaults."""

    def _assert_no_episodes_error(self, dataset_dir):
        """Assert ``main`` raises the "no episodes" ValueError for ``dataset_dir``.

        Also checks that the message names the offending directory, so a user
        can tell which path was searched.
        """
        with self.assertRaisesRegex(
            ValueError, r"No episode_\*\.hdf5 files found"
        ) as ctx:
            calculate_stats.main(["--dataset_dir", str(dataset_dir)])
        self.assertIn(str(dataset_dir), str(ctx.exception))

    def test_default_dataset_dir_is_absolute_and_package_relative(self):
        # The default must be anchored to the package location, not the CWD,
        # so the script works regardless of where it is launched from.
        expected = (
            Path(calculate_stats.__file__).resolve().parents[2]
            / "demos"
            / "dataset"
            / "sim_transfer"
        )
        default_dir = Path(calculate_stats.DEFAULT_DATASET_DIR)
        self.assertEqual(default_dir, expected)
        self.assertTrue(default_dir.is_absolute())

    def test_main_writes_dataset_stats_pkl_to_dataset_dir(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            _write_episode(
                dataset_dir / "episode_0.hdf5",
                action=[[1.0, 2.0], [3.0, 4.0]],
                qpos=[[5.0, 6.0], [7.0, 8.0]],
            )

            calculate_stats.main(["--dataset_dir", str(dataset_dir)])

            stats_path = dataset_dir / "dataset_stats.pkl"
            self.assertTrue(stats_path.exists())
            with stats_path.open("rb") as f:
                stats = pickle.load(f)

        self.assertEqual(
            set(stats),
            {
                "action_mean",
                "action_std",
                "action_min",
                "action_max",
                "qpos_mean",
                "qpos_std",
                "qpos_min",
                "qpos_max",
            },
        )
        # Per-dimension statistics over the two timesteps.
        np.testing.assert_allclose(stats["action_mean"], np.array([2.0, 3.0]))
        np.testing.assert_allclose(stats["qpos_mean"], np.array([6.0, 7.0]))
        np.testing.assert_allclose(stats["action_min"], np.array([1.0, 2.0]))
        np.testing.assert_allclose(stats["action_max"], np.array([3.0, 4.0]))
        np.testing.assert_allclose(stats["qpos_min"], np.array([5.0, 6.0]))
        np.testing.assert_allclose(stats["qpos_max"], np.array([7.0, 8.0]))
        # The std estimator (ddof, any clipping) is not visible from this test,
        # so only pin that it is reported per dimension.
        for key in ("action_std", "qpos_std"):
            self.assertEqual(np.shape(stats[key]), (2,))

    def test_main_raises_clear_error_for_empty_dataset_dir(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            self._assert_no_episodes_error(Path(tmpdir))

    def test_main_raises_clear_error_for_missing_dataset_dir(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            self._assert_no_episodes_error(Path(tmpdir) / "missing")


if __name__ == "__main__":
    unittest.main()