114 lines
4.1 KiB
Python
114 lines
4.1 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
import h5py
|
|
import numpy as np
|
|
|
|
|
|
class StreamingEpisodeWriter:
    """Write one episode to HDF5 frame by frame, then commit or discard it.

    Data is streamed into ``<dataset_path>.tmp``. ``commit()`` closes that
    file and moves it onto ``dataset_path`` with ``os.replace``;
    ``discard()`` closes it and deletes the temporary file instead.

    All datasets are preallocated with ``max_timesteps`` rows and ``commit()``
    does not trim them, so ``frame_index`` is the only record of how many
    rows were actually written.

    The writer can be used as a context manager. Leaving the ``with`` block
    without having called ``commit()`` (normally or via an exception)
    discards the temporary file; nothing is ever committed implicitly.

    File layout::

        /action                      float32 (max_timesteps, ACTION_DIM)
        /observations/qpos           float32 (max_timesteps, QPOS_DIM)
        /observations/images/<cam>   uint8   (max_timesteps, H, W, 3)
    """

    # Length of one qpos / action vector.
    QPOS_DIM = 16
    ACTION_DIM = 16
    # Upper bound on rows per HDF5 chunk for the qpos / action datasets.
    _VECTOR_CHUNK_ROWS = 128
    # Raw-data chunk cache size handed to h5py (2 MiB).
    _CHUNK_CACHE_BYTES = 2 * 1024**2

    def __init__(
        self,
        dataset_path: str | os.PathLike[str],
        max_timesteps: int,
        camera_names: list[str],
        image_size: tuple[int, int] = (256, 256),
    ) -> None:
        """Create the temporary HDF5 file and preallocate every dataset.

        Args:
            dataset_path: Final location of the episode file. The parent
                directory is created if it does not exist.
            max_timesteps: Number of rows to preallocate; also the maximum
                number of ``append`` calls.
            camera_names: One image dataset is created per name.
            image_size: ``(height, width)`` that stored frames are resized to.

        Raises:
            ValueError: If ``max_timesteps`` or either image dimension is not
                positive, or if ``camera_names`` contains duplicates.
        """
        self.dataset_path = Path(dataset_path)
        self.tmp_path = Path(f"{self.dataset_path}.tmp")
        self.max_timesteps = int(max_timesteps)
        self.camera_names = list(camera_names)
        self.image_height = int(image_size[0])
        self.image_width = int(image_size[1])
        self.frame_index = 0
        self._committed = False
        self._closed = False

        # Validate before touching the filesystem so that bad arguments
        # leave no directory or temp file behind.
        if self.max_timesteps <= 0:
            raise ValueError(
                f"max_timesteps must be positive, got {self.max_timesteps}"
            )
        if self.image_height <= 0 or self.image_width <= 0:
            raise ValueError(
                "image_size must be positive, got "
                f"({self.image_height}, {self.image_width})"
            )
        if len(set(self.camera_names)) != len(self.camera_names):
            raise ValueError(
                f"camera_names must be unique, got {self.camera_names}"
            )

        self.dataset_path.parent.mkdir(parents=True, exist_ok=True)
        # Drop any stale temp file left over from an earlier run.
        self._remove_tmp_file()

        self._file = h5py.File(
            self.tmp_path, "w", rdcc_nbytes=self._CHUNK_CACHE_BYTES
        )
        try:
            self._create_layout()
        except BaseException:
            # Do not leak the open handle or a half-initialised temp file.
            self.discard()
            raise

    def __enter__(self) -> "StreamingEpisodeWriter":
        """Return the writer itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        """Discard the episode unless ``commit()`` was already called.

        Exceptions raised inside the ``with`` block are never suppressed.
        """
        if not self._committed:
            self.discard()

    def _create_layout(self) -> None:
        """Write the file attributes and preallocate all datasets."""
        root = self._file
        root.attrs["sim"] = True
        root.attrs["action_repr"] = "ee_pose_xyz_quat_gripper"
        root.attrs["image_height"] = self.image_height
        root.attrs["image_width"] = self.image_width
        root.attrs["camera_names"] = np.asarray(self.camera_names, dtype="S")

        observations = root.create_group("observations")
        images = observations.create_group("images")
        frame_shape = (self.image_height, self.image_width, 3)
        for cam_name in self.camera_names:
            images.create_dataset(
                cam_name,
                (self.max_timesteps, *frame_shape),
                dtype="uint8",
                # One frame per chunk: each append touches exactly one chunk.
                chunks=(1, *frame_shape),
            )

        chunk_rows = min(self._VECTOR_CHUNK_ROWS, self.max_timesteps)
        observations.create_dataset(
            "qpos",
            (self.max_timesteps, self.QPOS_DIM),
            dtype="float32",
            chunks=(chunk_rows, self.QPOS_DIM),
        )
        root.create_dataset(
            "action",
            (self.max_timesteps, self.ACTION_DIM),
            dtype="float32",
            chunks=(chunk_rows, self.ACTION_DIM),
        )

    def append(
        self,
        qpos: np.ndarray,
        action: np.ndarray,
        images: dict[str, np.ndarray],
    ) -> None:
        """Store one frame at ``frame_index`` and advance the index.

        All inputs are validated (and images resized) before anything is
        written, so a rejected frame leaves the file untouched.

        Args:
            qpos: Vector of shape ``(QPOS_DIM,)``; converted to float32.
            action: Vector of shape ``(ACTION_DIM,)``; converted to float32.
            images: Mapping from camera name to an ``HxWx3`` image. Every
                name in ``camera_names`` must be present; extra keys are
                ignored.

        Raises:
            RuntimeError: If the writer was already committed or discarded.
            IndexError: If ``max_timesteps`` frames were already written.
            ValueError: If ``qpos``, ``action`` or an image has a bad shape.
            KeyError: If an image for a configured camera is missing.
        """
        if self._closed:
            raise RuntimeError("writer is already closed")
        if self.frame_index >= self.max_timesteps:
            raise IndexError("frame index exceeds max_timesteps")

        qpos = np.asarray(qpos, dtype=np.float32)
        action = np.asarray(action, dtype=np.float32)
        qpos_shape = (self.QPOS_DIM,)
        action_shape = (self.ACTION_DIM,)
        if qpos.shape != qpos_shape:
            raise ValueError(f"qpos shape must be {qpos_shape}, got {qpos.shape}")
        if action.shape != action_shape:
            raise ValueError(
                f"action shape must be {action_shape}, got {action.shape}"
            )

        # Prepare every camera frame first; only write once all are valid.
        frames: dict[str, np.ndarray] = {}
        for cam_name in self.camera_names:
            if cam_name not in images:
                raise KeyError(f"missing image for camera '{cam_name}'")
            frames[cam_name] = self._resize_image(images[cam_name])

        index = self.frame_index
        self._file["observations/qpos"][index] = qpos
        self._file["action"][index] = action
        for cam_name, frame in frames.items():
            self._file[f"observations/images/{cam_name}"][index] = frame

        self.frame_index = index + 1

    def commit(self) -> None:
        """Close the temp file and move it onto ``dataset_path``.

        Calling this on a writer that is already closed (committed or
        discarded) is a silent no-op. If the final ``os.replace`` fails the
        exception propagates and ``discard()`` can still remove the temp file.
        """
        if self._closed:
            return
        self._file.flush()
        self._file.close()
        self._closed = True
        os.replace(self.tmp_path, self.dataset_path)
        self._committed = True

    def discard(self) -> None:
        """Close the file if needed and delete the temporary file.

        Safe to call repeatedly. After a successful ``commit()`` this does
        nothing: the temp file has already been moved, and any file now at
        ``tmp_path`` was not created by this writer.
        """
        if not self._closed:
            self._file.close()
            self._closed = True
        if self._committed:
            return
        self._remove_tmp_file()

    def _remove_tmp_file(self) -> None:
        """Delete ``tmp_path``, ignoring the case where it does not exist."""
        try:
            self.tmp_path.unlink()
        except FileNotFoundError:
            pass

    def _resize_image(self, image: np.ndarray) -> np.ndarray:
        """Return ``image`` as a uint8 ``(image_height, image_width, 3)`` array.

        The input is converted with ``np.asarray(..., dtype=np.uint8)``; no
        value rescaling is applied. Images that already have the target size
        are returned without resampling.

        Raises:
            ValueError: If ``image`` is not a 3-D array with 3 channels.
        """
        image = np.asarray(image, dtype=np.uint8)
        if image.ndim != 3 or image.shape[2] != 3:
            raise ValueError(f"image shape must be HxWx3, got {image.shape}")

        src_height, src_width = image.shape[:2]
        if (src_height, src_width) == (self.image_height, self.image_width):
            return image

        # Area averaging when shrinking in both dimensions; linear
        # interpolation as soon as either dimension has to grow.
        upscaling = src_height < self.image_height or src_width < self.image_width
        interpolation = cv2.INTER_LINEAR if upscaling else cv2.INTER_AREA
        # cv2.resize takes the target size as (width, height).
        return cv2.resize(
            image,
            (self.image_width, self.image_height),
            interpolation=interpolation,
        )