import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest import mock

import h5py
import numpy as np

from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset


class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
    """Tests for ``SimpleRobotDataset`` sample loading on tiny synthetic episodes.

    Every test writes one or more four-step HDF5 episodes into a temporary
    directory and inspects what indexing the dataset returns.
    """

    def _write_episode(self, dataset_dir: Path, episode_idx: int = 0, *, base_value: float = 0.0) -> None:
        """Write ``episode_<episode_idx>.hdf5`` with four synthetic timesteps.

        The file contains:

        * ``action``: float32, shape ``(4, 2)``
        * ``observations/qpos``: float32, shape ``(4, 4)``
        * ``task``: a single byte string, ``b"sim_transfer"``
        * ``observations/images/front``: uint8, shape ``(4, 8, 8, 3)``

        Args:
            dataset_dir: Directory the episode file is created in.
            episode_idx: Index embedded in the file name.
            base_value: Offset added to the generated values so that episodes
                written with different offsets can be told apart.
        """
        episode_path = dataset_dir / f"episode_{episode_idx}.hdf5"
        with h5py.File(episode_path, "w") as root:
            root.create_dataset(
                "action",
                data=(np.arange(8, dtype=np.float32).reshape(4, 2) + base_value),
            )
            root.create_dataset(
                "observations/qpos",
                data=(np.arange(16, dtype=np.float32).reshape(4, 4) + base_value),
            )
            root.create_dataset("task", data=np.array([b"sim_transfer"]))
            root.create_dataset(
                "observations/images/front",
                data=((np.arange(4 * 8 * 8 * 3, dtype=np.uint8) + int(base_value)) % 255).reshape(4, 8, 8, 3),
            )

    def _make_dataset(self, dataset_dir: Path, **overrides) -> SimpleRobotDataset:
        """Build a dataset with the horizons and camera shared by every test.

        Args:
            dataset_dir: Directory containing the episode files.
            **overrides: Extra keyword arguments forwarded unchanged to
                ``SimpleRobotDataset``. Arguments not supplied here keep the
                dataset's own defaults.

        Returns:
            A dataset constructed with ``obs_horizon=2``, ``pred_horizon=3``
            and ``camera_names=["front"]``.
        """
        return SimpleRobotDataset(
            dataset_dir,
            obs_horizon=2,
            pred_horizon=3,
            camera_names=["front"],
            **overrides,
        )

    def test_getitem_only_resizes_observation_horizon_images(self):
        """Fetching one sample should call ``cv2.resize`` exactly twice."""
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            self._write_episode(dataset_dir)
            # image_resize_shape is deliberately not passed, so the dataset's
            # default resize setting applies.
            dataset = self._make_dataset(dataset_dir)

            resize_calls = []

            def fake_resize(image, size, interpolation=None):
                # Record the call and hand the image back untouched, so the
                # output keeps the 8x8 size of the stored frames.
                resize_calls.append(
                    {
                        "shape": tuple(image.shape),
                        "size": size,
                        "interpolation": interpolation,
                    }
                )
                return image

            # Substitute a stand-in module for ``cv2`` in sys.modules for the
            # duration of the lookup.
            fake_cv2 = types.SimpleNamespace(INTER_LINEAR=1, resize=fake_resize)
            with mock.patch.dict(sys.modules, {"cv2": fake_cv2}):
                sample = dataset[1]

            self.assertEqual(len(resize_calls), 2)
            self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))

    def test_getitem_skips_resize_when_image_resize_shape_is_none(self):
        """With ``image_resize_shape=None`` the dataset must never call ``cv2.resize``."""
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            self._write_episode(dataset_dir)
            dataset = self._make_dataset(dataset_dir, image_resize_shape=None)

            # Any call to resize raises immediately, failing the test at the
            # offending call rather than only at the assertion below.
            fake_cv2 = types.SimpleNamespace(
                INTER_LINEAR=1,
                resize=mock.Mock(side_effect=AssertionError("resize should be skipped when image_resize_shape=None")),
            )
            with mock.patch.dict(sys.modules, {"cv2": fake_cv2}):
                sample = dataset[1]

            fake_cv2.resize.assert_not_called()
            self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))

    def test_getitem_can_emit_lewm_history_and_future_observations(self):
        """``lewm_*`` options should add history and future entries to the sample."""
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            self._write_episode(dataset_dir)
            dataset = self._make_dataset(
                dataset_dir,
                image_resize_shape=None,
                lewm_history_horizon=3,
                lewm_query_offsets=[1, 2],
            )

            sample = dataset[1]

            # Leading dimension 3 matches lewm_history_horizon=3.
            self.assertEqual(tuple(sample["lewm.observation.state"].shape), (3, 4))
            self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
            # Leading dimension 2 matches the two entries of lewm_query_offsets.
            self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
            self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))

    def test_dataset_can_limit_loading_to_specific_episode_indices(self):
        """``episode_indices=[1]`` should load only ``episode_1.hdf5``."""
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            # Different offsets make the two episodes distinguishable by value.
            self._write_episode(dataset_dir, episode_idx=0, base_value=0.0)
            self._write_episode(dataset_dir, episode_idx=1, base_value=100.0)
            dataset = self._make_dataset(
                dataset_dir,
                image_resize_shape=None,
                episode_indices=[1],
            )

            sample = dataset[0]

            self.assertEqual(len(dataset.hdf5_files), 1)
            self.assertEqual(dataset.available_episode_indices, [1])
            self.assertEqual(len(dataset), 4)
            # assert_allclose reports the mismatching values on failure; the
            # tolerances are the np.allclose defaults.
            np.testing.assert_allclose(
                sample["observation.state"][0].numpy(),
                np.array([100.0, 101.0, 102.0, 103.0]),
                rtol=1e-5,
                atol=1e-8,
            )