Files
roboimi/roboimi/utils/streaming_episode_writer.py

114 lines
4.1 KiB
Python

from __future__ import annotations
import os
from pathlib import Path
import cv2
import h5py
import numpy as np
class StreamingEpisodeWriter:
    """Write episode data frame by frame; commit on success, discard the temp file on failure.

    All writes go to ``<dataset_path>.tmp``. :meth:`commit` closes the file and
    moves it to ``dataset_path`` with :func:`os.replace`. :meth:`discard` closes
    it and deletes the temporary file. When used as a context manager, an
    episode that has not been committed by the end of the ``with`` block is
    discarded.

    File layout (every dataset is preallocated to ``max_timesteps`` rows):

    * ``observations/images/<cam>``: ``uint8``, ``(max_timesteps, H, W, 3)``,
      chunked one frame per chunk
    * ``observations/qpos``: ``float32``, ``(max_timesteps, 16)``
    * ``action``: ``float32``, ``(max_timesteps, 16)``
    """

    def __init__(
        self,
        dataset_path: str | os.PathLike[str],
        max_timesteps: int,
        camera_names: list[str],
        image_size: tuple[int, int] = (256, 256),
    ) -> None:
        """Create the temporary HDF5 file and preallocate every dataset.

        Args:
            dataset_path: Final location of the episode file. Its parent
                directory is created if missing.
            max_timesteps: Number of rows preallocated per dataset; must be >= 1.
            camera_names: Cameras that every :meth:`append` must provide; must
                be unique.
            image_size: Stored ``(height, width)`` of every camera frame.

        Raises:
            ValueError: If ``max_timesteps < 1`` or ``camera_names`` contains
                duplicates (h5py errors while building the layout also propagate,
                after the temp file has been removed).
        """
        self.dataset_path = Path(dataset_path)
        self.tmp_path = Path(f"{self.dataset_path}.tmp")
        self.max_timesteps = int(max_timesteps)
        self.camera_names = list(camera_names)
        self.image_height = int(image_size[0])
        self.image_width = int(image_size[1])
        self.frame_index = 0
        self._committed = False
        self._closed = False

        # Reject inputs that h5py would otherwise only reject half-way through
        # building the layout (zero-sized chunks, duplicate dataset names).
        if self.max_timesteps < 1:
            raise ValueError(f"max_timesteps must be >= 1, got {self.max_timesteps}")
        if len(set(self.camera_names)) != len(self.camera_names):
            raise ValueError(f"camera_names must be unique, got {self.camera_names}")

        self.dataset_path.parent.mkdir(parents=True, exist_ok=True)
        # A leftover temp file from an earlier aborted run is stale; start fresh.
        self.tmp_path.unlink(missing_ok=True)
        # rdcc_nbytes: 2 MiB raw-data chunk cache for this file handle.
        self._file = h5py.File(self.tmp_path, "w", rdcc_nbytes=1024**2 * 2)
        try:
            self._create_layout()
        except BaseException:
            # Do not leak the open handle or leave a half-built temp file behind.
            self._file.close()
            self._closed = True
            self.tmp_path.unlink(missing_ok=True)
            raise

    def _create_layout(self) -> None:
        """Write the file-level attributes and preallocate all datasets."""
        attrs = self._file.attrs
        attrs["sim"] = True
        attrs["action_repr"] = "ee_pose_xyz_quat_gripper"
        attrs["image_height"] = self.image_height
        attrs["image_width"] = self.image_width
        attrs["camera_names"] = np.asarray(self.camera_names, dtype="S")

        observations = self._file.create_group("observations")
        images = observations.create_group("images")
        for cam_name in self.camera_names:
            images.create_dataset(
                cam_name,
                (self.max_timesteps, self.image_height, self.image_width, 3),
                dtype="uint8",
                # One frame per chunk, so each append touches exactly one chunk.
                chunks=(1, self.image_height, self.image_width, 3),
            )

        # Chunks may not exceed the dataset extent, hence the min().
        row_chunk = min(128, self.max_timesteps)
        observations.create_dataset(
            "qpos",
            (self.max_timesteps, 16),
            dtype="float32",
            chunks=(row_chunk, 16),
        )
        self._file.create_dataset(
            "action",
            (self.max_timesteps, 16),
            dtype="float32",
            chunks=(row_chunk, 16),
        )

    def append(self, qpos: np.ndarray, action: np.ndarray, images: dict[str, np.ndarray]) -> None:
        """Write one timestep at row ``frame_index`` and advance the index.

        All inputs are validated (and images resized) before anything is
        written, so a rejected call leaves the file unchanged.

        Args:
            qpos: Array-like convertible to ``float32`` with shape ``(16,)``.
            action: Array-like convertible to ``float32`` with shape ``(16,)``.
            images: Mapping from camera name to an ``HxWx3`` image; every name
                in ``camera_names`` must be present (extra keys are ignored).

        Raises:
            RuntimeError: If the writer is already closed.
            IndexError: If ``max_timesteps`` frames have already been written.
            ValueError: If ``qpos``/``action``/an image has the wrong shape.
            KeyError: If an image for a configured camera is missing.
        """
        if self._closed:
            raise RuntimeError("writer is already closed")
        if self.frame_index >= self.max_timesteps:
            raise IndexError("frame index exceeds max_timesteps")
        qpos = np.asarray(qpos, dtype=np.float32)
        action = np.asarray(action, dtype=np.float32)
        if qpos.shape != (16,):
            raise ValueError(f"qpos shape must be (16,), got {qpos.shape}")
        if action.shape != (16,):
            raise ValueError(f"action shape must be (16,), got {action.shape}")

        # Validate/resize every camera frame up front so a missing or malformed
        # image cannot leave a partially written row behind.
        frames = {}
        for cam_name in self.camera_names:
            if cam_name not in images:
                raise KeyError(f"missing image for camera '{cam_name}'")
            frames[cam_name] = self._resize_image(images[cam_name])

        row = self.frame_index
        self._file["observations/qpos"][row] = qpos
        self._file["action"][row] = action
        for cam_name, frame in frames.items():
            self._file[f"observations/images/{cam_name}"][row] = frame
        self.frame_index += 1

    def commit(self) -> None:
        """Flush and close the file, then move it to ``dataset_path`` via ``os.replace``.

        Does nothing if the writer is already closed (committed or discarded).

        NOTE(review): rows at index >= ``frame_index`` are never written and
        read back as the HDF5 fill value (zero by default); the file does not
        record how many frames were actually appended.
        """
        if self._closed:
            return
        self._file.flush()
        self._file.close()
        self._closed = True
        os.replace(self.tmp_path, self.dataset_path)
        self._committed = True

    def discard(self) -> None:
        """Close the file if it is still open and delete the temporary file.

        Safe to call repeatedly. After a successful :meth:`commit` this is a
        no-op, so it never deletes a temp file that another writer may since
        have created for the same path.
        """
        if self._committed:
            return
        if not self._closed:
            self._file.close()
            self._closed = True
        self.tmp_path.unlink(missing_ok=True)

    def __enter__(self) -> StreamingEpisodeWriter:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Anything not explicitly committed is treated as a failed episode.
        if not self._committed:
            self.discard()

    def _resize_image(self, image: np.ndarray) -> np.ndarray:
        """Return ``image`` as ``uint8`` resized to ``(image_height, image_width)``.

        INTER_AREA is used when shrinking in both dimensions, and INTER_LINEAR
        when either dimension needs upscaling. Images already at the target
        size are returned without resizing.

        Raises:
            ValueError: If ``image`` is not ``HxWx3``.
        """
        # NOTE(review): this cast is unchecked, so float images in [0, 1] or
        # ints outside 0..255 are silently truncated/wrapped.
        image = np.asarray(image, dtype=np.uint8)
        if image.ndim != 3 or image.shape[2] != 3:
            raise ValueError(f"image shape must be HxWx3, got {image.shape}")
        if image.shape[:2] == (self.image_height, self.image_width):
            return image
        interpolation = cv2.INTER_AREA
        if image.shape[0] < self.image_height or image.shape[1] < self.image_width:
            interpolation = cv2.INTER_LINEAR
        # cv2.resize takes dsize as (width, height).
        return cv2.resize(image, (self.image_width, self.image_height), interpolation=interpolation)