Files
roboimi/tests/test_simple_robot_dataset_image_loading.py

103 lines
3.7 KiB
Python

import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock
import h5py
import numpy as np
from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset
class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
    """Tests for how ``SimpleRobotDataset.__getitem__`` loads camera images.

    Each test writes a tiny synthetic HDF5 episode (4 timesteps, a single
    ``front`` camera stored as 8x8x3 ``uint8`` frames) into a temporary
    directory, builds the dataset over it and indexes timestep 1.
    """

    def _write_episode(self, dataset_dir: Path) -> None:
        """Write a 4-step synthetic episode to ``dataset_dir / "episode_0.hdf5"``.

        Datasets written (in this order, shapes/dtypes as constructed below):
            action                     (4, 2)        float32
            observations/qpos          (4, 4)        float32
            task                       (1,)          bytes
            observations/images/front  (4, 8, 8, 3)  uint8
        """
        num_steps = 4
        # Insertion order is preserved, so datasets are created in the listed order.
        payload = {
            "action": np.arange(8, dtype=np.float32).reshape(num_steps, 2),
            "observations/qpos": np.arange(16, dtype=np.float32).reshape(num_steps, 4),
            "task": np.array([b"sim_transfer"]),
            "observations/images/front": np.arange(
                num_steps * 8 * 8 * 3, dtype=np.uint8
            ).reshape(num_steps, 8, 8, 3),
        }
        with h5py.File(dataset_dir / "episode_0.hdf5", "w") as h5_file:
            for key, array in payload.items():
                h5_file.create_dataset(key, data=array)

    @staticmethod
    def _build_dataset(dataset_dir: Path, **overrides) -> "SimpleRobotDataset":
        """Build the dataset with the horizons/cameras shared by every test.

        Args:
            dataset_dir: Directory containing the synthetic episode.
            **overrides: Extra keyword arguments forwarded to ``SimpleRobotDataset``.
        """
        return SimpleRobotDataset(
            dataset_dir,
            obs_horizon=2,
            pred_horizon=3,
            camera_names=["front"],
            **overrides,
        )

    def test_getitem_only_resizes_observation_horizon_images(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            episode_dir = Path(tmpdir)
            self._write_episode(episode_dir)
            dataset = self._build_dataset(episode_dir)

            recorded = []

            def recording_resize(image, size, interpolation=None):
                # Record every call and hand the frame back untouched so the
                # output keeps its original 8x8 spatial size.
                recorded.append(
                    {
                        "shape": tuple(image.shape),
                        "size": size,
                        "interpolation": interpolation,
                    }
                )
                return image

            stub_cv2 = types.SimpleNamespace(INTER_LINEAR=1, resize=recording_resize)
            # NOTE(review): swapping ``cv2`` in ``sys.modules`` only intercepts the
            # call if the dataset imports cv2 lazily at call time — confirm.
            with mock.patch.dict(sys.modules, {"cv2": stub_cv2}):
                sample = dataset[1]

            # Expect exactly one resize per observation-horizon frame (obs_horizon=2).
            self.assertEqual(len(recorded), 2)
            # Expect the frames back channel-first: (obs_horizon, C, H, W).
            self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))

    def test_getitem_skips_resize_when_image_resize_shape_is_none(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            episode_dir = Path(tmpdir)
            self._write_episode(episode_dir)
            dataset = self._build_dataset(episode_dir, image_resize_shape=None)

            # Any call to resize fails loudly; the test also checks afterwards
            # that no call happened.
            forbidden_resize = mock.Mock(
                side_effect=AssertionError(
                    "resize should be skipped when image_resize_shape=None"
                )
            )
            stub_cv2 = types.SimpleNamespace(INTER_LINEAR=1, resize=forbidden_resize)
            with mock.patch.dict(sys.modules, {"cv2": stub_cv2}):
                sample = dataset[1]

            forbidden_resize.assert_not_called()
            self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))

    def test_getitem_can_emit_lewm_history_and_future_observations(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            episode_dir = Path(tmpdir)
            self._write_episode(episode_dir)
            dataset = self._build_dataset(
                episode_dir,
                image_resize_shape=None,
                lewm_history_horizon=3,
                lewm_query_offsets=[1, 2],
            )

            sample = dataset[1]

            # Expect lewm_history_horizon (3) history entries and one future
            # entry per query offset (2 offsets). Checked in order; the first
            # mismatch fails the test.
            expected_shapes = {
                "lewm.observation.state": (3, 4),
                "lewm.observation.front": (3, 3, 8, 8),
                "lewm.future.state": (2, 4),
                "lewm.future.front": (2, 3, 8, 8),
            }
            for key, shape in expected_shapes.items():
                self.assertEqual(tuple(sample[key].shape), shape)