128 lines
4.7 KiB
Python
128 lines
4.7 KiB
Python
import sys
|
|
import tempfile
|
|
import types
|
|
import unittest
|
|
from pathlib import Path
|
|
from unittest import mock
|
|
|
|
import h5py
|
|
import numpy as np
|
|
|
|
from roboimi.vla.data.simpe_robot_dataset import SimpleRobotDataset
|
|
|
|
|
|
class SimpleRobotDatasetImageLoadingTest(unittest.TestCase):
    """Check ``SimpleRobotDataset`` item loading against tiny synthetic HDF5 episodes.

    Every test writes one or more 4-step episodes into a temporary directory,
    builds a dataset over that directory and inspects a single sample.
    """

    def _write_episode(self, dataset_dir: Path, episode_idx: int = 0, *, base_value: float = 0.0) -> None:
        """Write ``episode_<episode_idx>.hdf5`` into ``dataset_dir``.

        The file contains four datasets:

        * ``action``: float32, shape (4, 2), ``arange(8) + base_value``
        * ``observations/qpos``: float32, shape (4, 4), ``arange(16) + base_value``
        * ``task``: one-element byte-string array (``b"sim_transfer"``)
        * ``observations/images/front``: uint8, shape (4, 8, 8, 3)

        Args:
            dataset_dir: Directory that receives the episode file.
            episode_idx: Index used in the file name.
            base_value: Offset added to the generated values so that episodes
                written with different offsets can be told apart.
        """
        episode_path = dataset_dir / f"episode_{episode_idx}.hdf5"
        with h5py.File(episode_path, "w") as root:
            root.create_dataset(
                "action",
                data=(np.arange(8, dtype=np.float32).reshape(4, 2) + base_value),
            )
            root.create_dataset(
                "observations/qpos",
                data=(np.arange(16, dtype=np.float32).reshape(4, 4) + base_value),
            )
            root.create_dataset("task", data=np.array([b"sim_transfer"]))
            root.create_dataset(
                "observations/images/front",
                data=((np.arange(4 * 8 * 8 * 3, dtype=np.uint8) + int(base_value)) % 255).reshape(4, 8, 8, 3),
            )

    def _make_dataset(self, dataset_dir: Path, **overrides) -> "SimpleRobotDataset":
        """Build a dataset with the settings shared by every test in this class.

        Args:
            dataset_dir: Directory holding the episode files.
            **overrides: Extra keyword arguments forwarded unchanged to
                ``SimpleRobotDataset`` (e.g. ``image_resize_shape``).

        Returns:
            A ``SimpleRobotDataset`` with ``obs_horizon=2``, ``pred_horizon=3``
            and the single camera ``"front"``.
        """
        return SimpleRobotDataset(
            dataset_dir,
            obs_horizon=2,
            pred_horizon=3,
            camera_names=["front"],
            **overrides,
        )

    def test_getitem_only_resizes_observation_horizon_images(self):
        """Fetching one sample calls ``cv2.resize`` exactly twice (``obs_horizon=2``).

        ``cv2`` is replaced in ``sys.modules`` by a stub that records each call
        and returns the image unchanged.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            self._write_episode(dataset_dir)
            # No image_resize_shape override: the dataset's own default is used.
            dataset = self._make_dataset(dataset_dir)

            resize_calls = []

            def fake_resize(image, size, interpolation=None):
                resize_calls.append(
                    {
                        "shape": tuple(image.shape),
                        "size": size,
                        "interpolation": interpolation,
                    }
                )
                return image

            fake_cv2 = types.SimpleNamespace(INTER_LINEAR=1, resize=fake_resize)

            # NOTE(review): this only intercepts the dataset's cv2 usage if the
            # dataset resolves ``cv2`` at __getitem__ time -- confirm.
            with mock.patch.dict(sys.modules, {"cv2": fake_cv2}):
                sample = dataset[1]

            # Pass the recorded calls as the message so a failure shows what
            # was actually resized.
            self.assertEqual(len(resize_calls), 2, resize_calls)
            self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))

    def test_getitem_skips_resize_when_image_resize_shape_is_none(self):
        """With ``image_resize_shape=None`` the ``cv2.resize`` stub is never called."""
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            self._write_episode(dataset_dir)
            dataset = self._make_dataset(dataset_dir, image_resize_shape=None)

            # The stub raises if it is ever invoked, so an unexpected resize
            # fails inside __getitem__ as well as at assert_not_called below.
            fake_cv2 = types.SimpleNamespace(
                INTER_LINEAR=1,
                resize=mock.Mock(side_effect=AssertionError("resize should be skipped when image_resize_shape=None")),
            )

            with mock.patch.dict(sys.modules, {"cv2": fake_cv2}):
                sample = dataset[1]

            fake_cv2.resize.assert_not_called()
            self.assertEqual(tuple(sample["observation.front"].shape), (2, 3, 8, 8))

    def test_getitem_can_emit_lewm_history_and_future_observations(self):
        """``lewm_*`` options add history and future entries with the expected shapes.

        With ``lewm_history_horizon=3`` and two query offsets, the sample must
        contain three history steps and two future steps for both the state
        and the ``front`` camera.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            self._write_episode(dataset_dir)
            dataset = self._make_dataset(
                dataset_dir,
                image_resize_shape=None,
                lewm_history_horizon=3,
                lewm_query_offsets=[1, 2],
            )

            sample = dataset[1]

            self.assertEqual(tuple(sample["lewm.observation.state"].shape), (3, 4))
            self.assertEqual(tuple(sample["lewm.observation.front"].shape), (3, 3, 8, 8))
            self.assertEqual(tuple(sample["lewm.future.state"].shape), (2, 4))
            self.assertEqual(tuple(sample["lewm.future.front"].shape), (2, 3, 8, 8))

    def test_dataset_can_limit_loading_to_specific_episode_indices(self):
        """``episode_indices=[1]`` restricts the dataset to the second episode only.

        Episode 1 is written with ``base_value=100.0``, so the first state row
        of the first sample identifies which file was loaded.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            dataset_dir = Path(tmpdir)
            self._write_episode(dataset_dir, episode_idx=0, base_value=0.0)
            self._write_episode(dataset_dir, episode_idx=1, base_value=100.0)

            dataset = self._make_dataset(
                dataset_dir,
                image_resize_shape=None,
                episode_indices=[1],
            )

            sample = dataset[0]

            self.assertEqual(len(dataset.hdf5_files), 1)
            self.assertEqual(dataset.available_episode_indices, [1])
            self.assertEqual(len(dataset), 4)
            # assert_allclose reports the mismatching values on failure; the
            # tolerances are np.allclose's defaults.
            np.testing.assert_allclose(
                sample["observation.state"][0].numpy(),
                np.array([100.0, 101.0, 102.0, 103.0]),
                rtol=1e-5,
                atol=1e-8,
            )
|