feat(vla): align transformer training stack and rollout validation

This commit is contained in:
Logic
2026-03-31 15:39:20 +08:00
parent 424c265823
commit d84bc6876e
25 changed files with 4043 additions and 706 deletions

View File

@@ -0,0 +1,88 @@
import pickle
import tempfile
import unittest
from pathlib import Path
import h5py
import numpy as np
from roboimi.vla.scripts import calculate_stats
class CalculateStatsCliTest(unittest.TestCase):
def test_default_dataset_dir_is_absolute_and_package_relative(self):
expected = (
Path(calculate_stats.__file__).resolve().parents[2]
/ "demos"
/ "dataset"
/ "sim_transfer"
)
self.assertEqual(Path(calculate_stats.DEFAULT_DATASET_DIR), expected)
self.assertTrue(Path(calculate_stats.DEFAULT_DATASET_DIR).is_absolute())
def test_main_writes_dataset_stats_pkl_to_dataset_dir(self):
with tempfile.TemporaryDirectory() as tmpdir:
dataset_dir = Path(tmpdir)
episode_path = dataset_dir / "episode_0.hdf5"
with h5py.File(episode_path, "w") as root:
root.create_dataset(
"action",
data=np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
)
observations = root.create_group("observations")
observations.create_dataset(
"qpos",
data=np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float32),
)
calculate_stats.main(["--dataset_dir", str(dataset_dir)])
stats_path = dataset_dir / "dataset_stats.pkl"
self.assertTrue(stats_path.exists())
with stats_path.open("rb") as f:
stats = pickle.load(f)
self.assertEqual(
set(stats),
{
"action_mean",
"action_std",
"action_min",
"action_max",
"qpos_mean",
"qpos_std",
"qpos_min",
"qpos_max",
},
)
np.testing.assert_allclose(stats["action_mean"], np.array([2.0, 3.0]))
np.testing.assert_allclose(stats["qpos_mean"], np.array([6.0, 7.0]))
def test_main_raises_clear_error_for_empty_dataset_dir(self):
with tempfile.TemporaryDirectory() as tmpdir:
dataset_dir = Path(tmpdir)
with self.assertRaisesRegex(
ValueError, r"No episode_\*\.hdf5 files found"
) as ctx:
calculate_stats.main(["--dataset_dir", str(dataset_dir)])
self.assertIn(str(dataset_dir), str(ctx.exception))
def test_main_raises_clear_error_for_missing_dataset_dir(self):
with tempfile.TemporaryDirectory() as tmpdir:
dataset_dir = Path(tmpdir) / "missing"
with self.assertRaisesRegex(
ValueError, r"No episode_\*\.hdf5 files found"
) as ctx:
calculate_stats.main(["--dataset_dir", str(dataset_dir)])
self.assertIn(str(dataset_dir), str(ctx.exception))
if __name__ == "__main__":
unittest.main()