80 lines
3.1 KiB
Python
80 lines
3.1 KiB
Python
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
import h5py
|
|
import numpy as np
|
|
|
|
from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter
|
|
|
|
|
|
class StreamingEpisodeWriterTest(unittest.TestCase):
    """Round-trip tests for ``StreamingEpisodeWriter``.

    Each test writes into a throw-away directory and then inspects the
    resulting files directly: the HDF5 contents via ``h5py``, and the final
    and ``.tmp`` paths via ``Path.exists``.
    """

    # Camera set shared by every test in this class. Tests pass a fresh copy
    # to the writer so no test can observe another test's mutations.
    CAMERA_NAMES = ["angle", "r_vis", "top", "front"]
    # Target size handed to the writer. It is square, so the (H, W) vs (W, H)
    # convention of ``image_size`` does not matter for the assertions below.
    IMAGE_SIZE = (256, 256)
    # Shape of the synthetic input frames. It deliberately differs from
    # IMAGE_SIZE so the tests exercise the writer's resizing path.
    SOURCE_FRAME_SHAPE = (480, 640, 3)

    @staticmethod
    def _tmp_path(episode_path):
        """Return the ``<episode_path>.tmp`` sibling path checked for absence."""
        return Path(str(episode_path) + ".tmp")

    def _uniform_frames(self, base_value):
        """Build one constant-valued uint8 frame per camera.

        Camera ``idx`` (in ``CAMERA_NAMES`` order) is filled with
        ``base_value + idx``, so every camera stays distinguishable after the
        round trip. Constant frames keep the expected pixel value unchanged
        by any weighted-average resize interpolation.
        """
        return {
            cam: np.full(
                self.SOURCE_FRAME_SHAPE, fill_value=base_value + idx, dtype=np.uint8
            )
            for idx, cam in enumerate(self.CAMERA_NAMES)
        }

    def test_commit_persists_raw_action_and_resized_images(self):
        """Commit keeps raw action/qpos values and stores resized frames."""
        raw_actions = [
            np.arange(16, dtype=np.float32),
            np.arange(16, dtype=np.float32) + 100.0,
        ]
        qposes = [
            np.arange(16, dtype=np.float32) + 200.0,
            np.arange(16, dtype=np.float32) + 300.0,
        ]
        # Per-timestep base fill value for the synthetic camera frames.
        fill_bases = [1, 11]

        with tempfile.TemporaryDirectory() as tmpdir:
            episode_path = Path(tmpdir) / "episode_0.hdf5"
            writer = StreamingEpisodeWriter(
                dataset_path=episode_path,
                max_timesteps=len(raw_actions),
                camera_names=list(self.CAMERA_NAMES),
                image_size=self.IMAGE_SIZE,
            )
            for qpos, action, base in zip(qposes, raw_actions, fill_bases):
                writer.append(
                    qpos=qpos,
                    action=action,
                    images=self._uniform_frames(base),
                )
            writer.commit()

            # After commit the final file must exist and no .tmp file may remain.
            self.assertTrue(episode_path.exists())
            self.assertFalse(self._tmp_path(episode_path).exists())

            with h5py.File(episode_path, "r") as root:
                self.assertEqual(root["action"].shape, (2, 16))
                self.assertEqual(root["observations/qpos"].shape, (2, 16))
                for t, (action, qpos) in enumerate(zip(raw_actions, qposes)):
                    np.testing.assert_allclose(root["action"][t], action)
                    np.testing.assert_allclose(root["observations/qpos"][t], qpos)

                for idx, cam_name in enumerate(self.CAMERA_NAMES):
                    # subTest reports every failing camera, not just the first.
                    with self.subTest(camera=cam_name):
                        dataset = root[f"observations/images/{cam_name}"]
                        self.assertEqual(dataset.shape, (2, 256, 256, 3))
                        self.assertEqual(dataset.dtype, np.uint8)
                        for t, base in enumerate(fill_bases):
                            # Scalar comparison checks every pixel and, on
                            # failure, shows the mismatching values.
                            np.testing.assert_array_equal(dataset[t], base + idx)

    def test_discard_removes_temporary_file(self):
        """Discard leaves neither the final file nor its ``.tmp`` sibling."""
        with tempfile.TemporaryDirectory() as tmpdir:
            episode_path = Path(tmpdir) / "episode_0.hdf5"
            writer = StreamingEpisodeWriter(
                dataset_path=episode_path,
                max_timesteps=1,
                camera_names=list(self.CAMERA_NAMES),
                image_size=self.IMAGE_SIZE,
            )
            writer.discard()

            self.assertFalse(episode_path.exists())
            self.assertFalse(self._tmp_path(episode_path).exists())
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this module directly (`python <file>`) as well as via a
    # test runner; exit=True (the default) forwards the result as the exit code.
    unittest.main(exit=True)
|