Files
roboimi/tests/test_streaming_episode_writer.py

80 lines
3.1 KiB
Python

import tempfile
import unittest
from pathlib import Path
import h5py
import numpy as np
from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter
class StreamingEpisodeWriterTest(unittest.TestCase):
    """Tests for ``StreamingEpisodeWriter`` commit and discard behaviour."""

    # Camera set shared by every test so the cases cannot drift apart.
    CAMERA_NAMES = ("angle", "r_vis", "top", "front")
    # Shape of each raw frame handed to ``append``.
    RAW_FRAME_SHAPE = (480, 640, 3)
    # Size the writer is asked to store images at. NOTE(review): it is square,
    # so the writer's (H, W) vs (W, H) convention is not exercised here.
    IMAGE_SIZE = (256, 256)
    # Length of the qpos / action vectors used throughout.
    VECTOR_DIM = 16

    @classmethod
    def _make_frame_images(cls, base_value):
        """Return one constant-valued uint8 frame per camera.

        Camera ``i`` is filled with ``base_value + i`` so each camera's data
        stays distinguishable after being written and read back. A constant
        image should stay constant under resizing, so stored pixels can be
        checked exactly without depending on the interpolation method.
        """
        return {
            cam: np.full(cls.RAW_FRAME_SHAPE, fill_value=base_value + idx, dtype=np.uint8)
            for idx, cam in enumerate(cls.CAMERA_NAMES)
        }

    @staticmethod
    def _tmp_path(episode_path):
        """Return ``<episode_path>.tmp``, the temporary file checked by these tests."""
        return episode_path.with_name(episode_path.name + ".tmp")

    def _make_writer(self, episode_path, max_timesteps):
        """Build a writer for ``episode_path`` with the shared camera/image settings."""
        return StreamingEpisodeWriter(
            dataset_path=episode_path,
            max_timesteps=max_timesteps,
            # Pass a fresh list so the writer can never mutate the class constant.
            camera_names=list(self.CAMERA_NAMES),
            image_size=self.IMAGE_SIZE,
        )

    def test_commit_persists_raw_action_and_resized_images(self):
        """Committed episodes hold the exact actions/qpos and resized images."""
        raw_action_0 = np.arange(self.VECTOR_DIM, dtype=np.float32)
        raw_action_1 = np.arange(self.VECTOR_DIM, dtype=np.float32) + 100.0
        qpos_0 = np.arange(self.VECTOR_DIM, dtype=np.float32) + 200.0
        qpos_1 = np.arange(self.VECTOR_DIM, dtype=np.float32) + 300.0

        with tempfile.TemporaryDirectory() as tmpdir:
            episode_path = Path(tmpdir) / "episode_0.hdf5"
            writer = self._make_writer(episode_path, max_timesteps=2)
            writer.append(qpos=qpos_0, action=raw_action_0, images=self._make_frame_images(1))
            writer.append(qpos=qpos_1, action=raw_action_1, images=self._make_frame_images(11))
            writer.commit()

            # Commit must leave only the final file behind.
            self.assertTrue(episode_path.exists())
            self.assertFalse(self._tmp_path(episode_path).exists())

            with h5py.File(episode_path, "r") as root:
                self.assertEqual(root["action"].shape, (2, self.VECTOR_DIM))
                self.assertEqual(root["observations/qpos"].shape, (2, self.VECTOR_DIM))
                np.testing.assert_allclose(root["action"][0], raw_action_0)
                np.testing.assert_allclose(root["action"][1], raw_action_1)
                np.testing.assert_allclose(root["observations/qpos"][0], qpos_0)
                np.testing.assert_allclose(root["observations/qpos"][1], qpos_1)

                for idx, cam_name in enumerate(self.CAMERA_NAMES):
                    # subTest reports which camera failed instead of stopping at the first.
                    with self.subTest(camera=cam_name):
                        dataset = root[f"observations/images/{cam_name}"]
                        self.assertEqual(dataset.shape, (2, *self.IMAGE_SIZE, 3))
                        self.assertEqual(dataset.dtype, np.uint8)
                        # Scalar comparison broadcasts; on failure numpy lists the mismatches.
                        np.testing.assert_array_equal(dataset[0], idx + 1)
                        np.testing.assert_array_equal(dataset[1], idx + 11)

    def test_discard_removes_temporary_file(self):
        """Discarding a writer that never received data leaves no files."""
        with tempfile.TemporaryDirectory() as tmpdir:
            episode_path = Path(tmpdir) / "episode_0.hdf5"
            writer = self._make_writer(episode_path, max_timesteps=1)
            writer.discard()

            self.assertFalse(episode_path.exists())
            self.assertFalse(self._tmp_path(episode_path).exists())

    def test_discard_after_append_removes_temporary_file(self):
        """Discarding a partially written episode leaves no files."""
        with tempfile.TemporaryDirectory() as tmpdir:
            episode_path = Path(tmpdir) / "episode_0.hdf5"
            # max_timesteps > appended frames: models an episode aborted midway.
            writer = self._make_writer(episode_path, max_timesteps=2)
            writer.append(
                qpos=np.zeros(self.VECTOR_DIM, dtype=np.float32),
                action=np.zeros(self.VECTOR_DIM, dtype=np.float32),
                images=self._make_frame_images(1),
            )
            writer.discard()

            self.assertFalse(episode_path.exists())
            self.assertFalse(self._tmp_path(episode_path).exists())
if __name__ == "__main__":
    # Direct execution: discover and run the TestCase classes defined in this module.
    unittest.main(module="__main__")