import tempfile
import unittest
from pathlib import Path

import h5py
import numpy as np

from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter


class StreamingEpisodeWriterTest(unittest.TestCase):
    """Exercise the commit and discard paths of ``StreamingEpisodeWriter``."""

    def test_commit_persists_raw_action_and_resized_images(self):
        """Append two steps, commit, and check what lands in the HDF5 file.

        Each step feeds 480x640 RGB frames filled with a per-camera constant;
        the test expects them to be stored as 256x256 frames with the same
        constant, next to the unmodified action and qpos vectors.
        """
        camera_names = ["angle", "r_vis", "top", "front"]

        # Per-step fixtures: 16-element float32 ramps shifted by a distinct
        # offset so every vector is distinguishable after the round trip.
        ramp = np.arange(16, dtype=np.float32)
        actions = [ramp, ramp + 100.0]
        joint_positions = [ramp + 200.0, ramp + 300.0]
        # Camera ``i`` at a given step is filled with ``i + offset``.
        fill_offsets = (1, 11)

        with tempfile.TemporaryDirectory() as tmpdir:
            episode_path = Path(tmpdir) / "episode_0.hdf5"
            tmp_path = episode_path.with_name(episode_path.name + ".tmp")

            writer = StreamingEpisodeWriter(
                dataset_path=episode_path,
                max_timesteps=2,
                camera_names=camera_names,
                image_size=(256, 256),
            )

            for qpos, action, offset in zip(joint_positions, actions, fill_offsets):
                frames = {}
                for cam_idx, cam in enumerate(camera_names):
                    frames[cam] = np.full(
                        (480, 640, 3),
                        fill_value=cam_idx + offset,
                        dtype=np.uint8,
                    )
                writer.append(qpos=qpos, action=action, images=frames)

            writer.commit()

            # After commit the final file must exist and the ".tmp" sibling
            # must be gone.
            self.assertTrue(episode_path.exists())
            self.assertFalse(tmp_path.exists())

            with h5py.File(episode_path, "r") as root:
                self.assertEqual(root["action"].shape, (2, 16))
                self.assertEqual(root["observations/qpos"].shape, (2, 16))

                for step, expected in enumerate(actions):
                    np.testing.assert_allclose(root["action"][step], expected)
                for step, expected in enumerate(joint_positions):
                    np.testing.assert_allclose(
                        root["observations/qpos"][step], expected
                    )

                for cam_idx, cam_name in enumerate(camera_names):
                    dataset = root[f"observations/images/{cam_name}"]
                    self.assertEqual(dataset.shape, (2, 256, 256, 3))
                    self.assertEqual(dataset.dtype, np.uint8)
                    for step, offset in enumerate(fill_offsets):
                        self.assertTrue(
                            np.all(dataset[step] == cam_idx + offset)
                        )

    def test_discard_removes_temporary_file(self):
        """Discard a fresh writer and check neither output file remains."""
        with tempfile.TemporaryDirectory() as tmpdir:
            episode_path = Path(tmpdir) / "episode_0.hdf5"
            tmp_path = episode_path.with_name(episode_path.name + ".tmp")

            writer = StreamingEpisodeWriter(
                dataset_path=episode_path,
                max_timesteps=1,
                camera_names=["angle", "r_vis", "top", "front"],
                image_size=(256, 256),
            )
            writer.discard()

            self.assertFalse(episode_path.exists())
            self.assertFalse(tmp_path.exists())


if __name__ == "__main__":
    unittest.main()