feat(data): stream sim episodes with raw ee actions
This commit is contained in:
113
roboimi/utils/streaming_episode_writer.py
Normal file
113
roboimi/utils/streaming_episode_writer.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
|
||||
class StreamingEpisodeWriter:
    """Write episode data frame by frame; commit on success, discard on failure.

    Frames are streamed into ``<dataset_path>.tmp``. ``commit()`` closes that
    file and moves it onto ``dataset_path`` with ``os.replace``; ``discard()``
    closes it and deletes it. The writer can also be used as a context
    manager: leaving the ``with`` block normally commits, leaving it through
    an exception discards.

    File layout created by the constructor:

    * ``observations/images/<camera>``: uint8, ``(max_timesteps, H, W, 3)``
    * ``observations/qpos``: float32, ``(max_timesteps, STATE_DIM)``
    * ``action``: float32, ``(max_timesteps, ACTION_DIM)``
    """

    # Length of the per-frame state / action vectors accepted by ``append``.
    STATE_DIM = 16
    ACTION_DIM = 16

    def __init__(
        self,
        dataset_path: str | os.PathLike[str],
        max_timesteps: int,
        camera_names: list[str],
        image_size: tuple[int, int] = (256, 256),
    ) -> None:
        """Open the temporary HDF5 file and pre-allocate all datasets.

        Args:
            dataset_path: Final location of the episode file.
            max_timesteps: Number of frames to pre-allocate; must be positive.
            camera_names: One image dataset is created per name.
            image_size: ``(height, width)`` every stored image is resized to.

        Raises:
            ValueError: If ``max_timesteps`` or an image dimension is not
                positive.
        """
        self.dataset_path = Path(dataset_path)
        self.tmp_path = Path(f"{self.dataset_path}.tmp")
        self.max_timesteps = int(max_timesteps)
        self.camera_names = list(camera_names)
        self.image_height = int(image_size[0])
        self.image_width = int(image_size[1])
        self.frame_index = 0
        self._committed = False
        self._closed = False

        # Reject unusable sizes before touching the disk, so a bad argument
        # never leaves a directory or a stray temp file behind.
        if self.max_timesteps <= 0:
            raise ValueError(f"max_timesteps must be positive, got {self.max_timesteps}")
        if self.image_height <= 0 or self.image_width <= 0:
            raise ValueError(
                f"image_size must be positive, got ({self.image_height}, {self.image_width})"
            )

        self.dataset_path.parent.mkdir(parents=True, exist_ok=True)
        # Drop any leftover temp file from an earlier, interrupted run.
        self._remove_tmp_file()

        self._file = h5py.File(self.tmp_path, "w", rdcc_nbytes=1024**2 * 2)
        try:
            self._create_layout()
        except BaseException:
            # Do not leak the open handle or the half-initialised temp file
            # when layout creation fails; the original error is re-raised.
            try:
                self._file.close()
            finally:
                self._closed = True
                self._remove_tmp_file()
            raise

    def __enter__(self) -> StreamingEpisodeWriter:
        """Return the writer itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        """Commit on a clean exit, discard when an exception is propagating.

        Both calls are no-ops if the caller already committed or discarded
        inside the block. Exceptions are never suppressed.
        """
        if exc_type is None:
            self.commit()
        else:
            self.discard()
        return False

    def append(self, qpos: np.ndarray, action: np.ndarray, images: dict[str, np.ndarray]) -> None:
        """Store one frame at the current ``frame_index`` and advance it.

        All inputs are validated and all images are resized *before* anything
        is written, so a rejected frame leaves the file untouched.

        Args:
            qpos: Vector of shape ``(STATE_DIM,)``; stored as float32.
            action: Vector of shape ``(ACTION_DIM,)``; stored as float32.
            images: Mapping from camera name to an ``HxWx3`` image. It must
                contain every name in ``camera_names``; extra keys are
                ignored.

        Raises:
            RuntimeError: If the writer was already committed or discarded.
            IndexError: If ``max_timesteps`` frames were already written.
            ValueError: If ``qpos``, ``action`` or an image has a wrong shape.
            KeyError: If an image for one of ``camera_names`` is missing.
        """
        if self._closed:
            raise RuntimeError("writer is already closed")
        if self.frame_index >= self.max_timesteps:
            raise IndexError("frame index exceeds max_timesteps")

        qpos = np.asarray(qpos, dtype=np.float32)
        action = np.asarray(action, dtype=np.float32)
        if qpos.shape != (self.STATE_DIM,):
            raise ValueError(f"qpos shape must be ({self.STATE_DIM},), got {qpos.shape}")
        if action.shape != (self.ACTION_DIM,):
            raise ValueError(f"action shape must be ({self.ACTION_DIM},), got {action.shape}")

        # Prepare every camera frame first: a missing or malformed image must
        # fail before qpos/action are written, otherwise the slot would be
        # left half-filled.
        frames = {}
        for cam_name in self.camera_names:
            if cam_name not in images:
                raise KeyError(f"missing image for camera '{cam_name}'")
            frames[cam_name] = self._resize_image(images[cam_name])

        index = self.frame_index
        self._file["observations/qpos"][index] = qpos
        self._file["action"][index] = action
        for cam_name, frame in frames.items():
            self._file[f"observations/images/{cam_name}"][index] = frame

        self.frame_index += 1

    def commit(self) -> None:
        """Close the temp file and move it onto ``dataset_path``.

        Does nothing if the writer is already closed (committed or discarded).
        """
        if self._closed:
            return
        self._file.flush()
        self._file.close()
        self._closed = True
        os.replace(self.tmp_path, self.dataset_path)
        self._committed = True

    def discard(self) -> None:
        """Close the temp file (if still open) and delete it.

        Does nothing after a successful ``commit()``.
        """
        if self._committed:
            # The temp file was already moved into place. Anything now sitting
            # at tmp_path was not produced by this writer and must be kept.
            return
        if not self._closed:
            self._file.close()
            self._closed = True
        self._remove_tmp_file()

    def _create_layout(self) -> None:
        """Write the file-level attributes and pre-allocate every dataset."""
        self._file.attrs["sim"] = True
        self._file.attrs["action_repr"] = "ee_pose_xyz_quat_gripper"
        self._file.attrs["image_height"] = self.image_height
        self._file.attrs["image_width"] = self.image_width
        self._file.attrs["camera_names"] = np.asarray(self.camera_names, dtype="S")

        observations = self._file.create_group("observations")
        images = observations.create_group("images")
        for cam_name in self.camera_names:
            images.create_dataset(
                cam_name,
                (self.max_timesteps, self.image_height, self.image_width, 3),
                dtype="uint8",
                # One chunk per frame, matching the one-frame-at-a-time writes.
                chunks=(1, self.image_height, self.image_width, 3),
            )

        row_chunk = min(128, self.max_timesteps)
        observations.create_dataset(
            "qpos",
            (self.max_timesteps, self.STATE_DIM),
            dtype="float32",
            chunks=(row_chunk, self.STATE_DIM),
        )
        self._file.create_dataset(
            "action",
            (self.max_timesteps, self.ACTION_DIM),
            dtype="float32",
            chunks=(row_chunk, self.ACTION_DIM),
        )

    def _remove_tmp_file(self) -> None:
        """Delete the temp file, ignoring the case where it does not exist."""
        try:
            self.tmp_path.unlink()
        except FileNotFoundError:
            pass

    def _resize_image(self, image: np.ndarray) -> np.ndarray:
        """Return ``image`` as a uint8 ``(image_height, image_width, 3)`` array.

        An image that already has the target size is returned without
        resampling. Otherwise it is resized with OpenCV: ``INTER_AREA`` by
        default, ``INTER_LINEAR`` when either dimension has to grow.

        Raises:
            ValueError: If ``image`` is not ``HxWx3``.
        """
        image = np.asarray(image, dtype=np.uint8)
        if image.ndim != 3 or image.shape[2] != 3:
            raise ValueError(f"image shape must be HxWx3, got {image.shape}")
        if image.shape[:2] == (self.image_height, self.image_width):
            return image

        interpolation = cv2.INTER_AREA
        if image.shape[0] < self.image_height or image.shape[1] < self.image_width:
            interpolation = cv2.INTER_LINEAR
        # cv2.resize takes the target size as (width, height).
        return cv2.resize(image, (self.image_width, self.image_height), interpolation=interpolation)
|
||||
Reference in New Issue
Block a user