feat(vla): align transformer training stack and rollout validation
This commit is contained in:
58
tests/test_simple_robot_dataset_image_loading.py
Normal file
58
tests/test_simple_robot_dataset_image_loading.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import sys
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset
|
||||
|
||||
|
||||
class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
|
||||
def _write_episode(self, dataset_dir: Path) -> None:
|
||||
episode_path = dataset_dir / "episode_0.hdf5"
|
||||
with h5py.File(episode_path, "w") as root:
|
||||
root.create_dataset("action", data=np.arange(8, dtype=np.float32).reshape(4, 2))
|
||||
root.create_dataset(
|
||||
"observations/qpos",
|
||||
data=np.arange(16, dtype=np.float32).reshape(4, 4),
|
||||
)
|
||||
root.create_dataset("task", data=np.array([b"sim_transfer"]))
|
||||
root.create_dataset(
|
||||
"observations/images/front",
|
||||
data=np.arange(4 * 8 * 8 * 3, dtype=np.uint8).reshape(4, 8, 8, 3),
|
||||
)
|
||||
|
||||
def test_getitem_only_resizes_observation_horizon_images(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dataset_dir = Path(tmpdir)
|
||||
self._write_episode(dataset_dir)
|
||||
dataset = SimpleRobotDataset(
|
||||
dataset_dir,
|
||||
obs_horizon=2,
|
||||
pred_horizon=3,
|
||||
camera_names=["front"],
|
||||
)
|
||||
|
||||
resize_calls = []
|
||||
|
||||
def fake_resize(image, size, interpolation=None):
|
||||
resize_calls.append(
|
||||
{
|
||||
"shape": tuple(image.shape),
|
||||
"size": size,
|
||||
"interpolation": interpolation,
|
||||
}
|
||||
)
|
||||
return image
|
||||
|
||||
fake_cv2 = types.SimpleNamespace(INTER_LINEAR=1, resize=fake_resize)
|
||||
|
||||
with mock.patch.dict(sys.modules, {"cv2": fake_cv2}):
|
||||
sample = dataset[1]
|
||||
|
||||
self.assertEqual(len(resize_calls), 2)
|
||||
self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))
|
||||
Reference in New Issue
Block a user