feat(data): stream sim episodes with raw ee actions

This commit is contained in:
Logic
2026-03-31 15:44:53 +08:00
parent d84bc6876e
commit d5d5b53f71
4 changed files with 257 additions and 49 deletions

View File

@@ -0,0 +1,42 @@
# Streaming HDF5 EE Action Dataset Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** 将 Diana 仿真采集改为流式写入 HDF5图像保存为 256x256 的四路相机视角,并把 `/action` 改为 IK 前的原始末端位姿动作。
**Architecture:** 新增一个独立的流式 HDF5 episode writer负责逐帧写入 qpos、原始 action 和 resize 后图像,并在 episode 成功时原子提交、失败时删除临时文件。采集脚本只负责 rollout 和把每一步观测/动作交给 writer避免整集数据先堆在内存里。
**Tech Stack:** Python, h5py, numpy, cv2, unittest, MuJoCo demo scripts
---
### Task 1: 为流式 writer 建立测试边界
**Files:**
- Create: `tests/test_streaming_episode_writer.py`
- Create: `roboimi/utils/streaming_episode_writer.py`
- [ ] **Step 1: Write the failing test**
- [ ] **Step 2: Run `python -m unittest tests.test_streaming_episode_writer -v` and confirm it fails because the writer module does not exist**
- [ ] **Step 3: Implement the minimal streaming writer with temp-file commit/discard, per-frame append, and 256x256 image resize**
- [ ] **Step 4: Re-run `python -m unittest tests.test_streaming_episode_writer -v` and confirm it passes**
### Task 2: 接入 Diana 采集脚本
**Files:**
- Modify: `roboimi/demos/diana_record_sim_episodes.py`
- Reuse: `roboimi/utils/streaming_episode_writer.py`
- [ ] **Step 1: Replace in-memory `data_dict` / `obs` accumulation with per-episode streaming writer lifecycle**
- [ ] **Step 2: Keep four cameras (`angle`, `r_vis`, `top`, `front`) and resize to 256x256 before persistence**
- [ ] **Step 3: Capture raw policy output before IK and write that to `/action`**
- [ ] **Step 4: On success commit to `episode_{idx}.hdf5`; on failure remove temp file**
### Task 3: 验证改动
**Files:**
- Verify only
- [ ] **Step 1: Run unit tests for the writer**
- [ ] **Step 2: Run one end-to-end collection episode and stop after `episode_0.hdf5` becomes readable**
- [ ] **Step 3: Verify HDF5 keys and shapes: `action=(700,16)`, image datasets are `(700,256,256,3)`, and `/action` matches raw EE action semantics**

View File

@@ -1,11 +1,11 @@
import time import time
import os,collections,sys import os
import numpy as np import numpy as np
import h5py
from roboimi.envs.double_pos_ctrl_env import make_sim_env from roboimi.envs.double_pos_ctrl_env import make_sim_env
from diana_policy import TestPickAndTransferPolicy from diana_policy import TestPickAndTransferPolicy
import cv2 import cv2
from roboimi.utils.act_ex_utils import sample_transfer_pose from roboimi.utils.act_ex_utils import sample_transfer_pose
from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter
import pathlib import pathlib
HOME_PATH = str(pathlib.Path(__file__).parent.resolve()) HOME_PATH = str(pathlib.Path(__file__).parent.resolve())
@@ -16,14 +16,12 @@ def main():
task_name = 'sim_transfer' task_name = 'sim_transfer'
dataset_dir = DATASET_DIR + '/sim_transfer' #SIM_TASK_CONFIGS[task_name]['dataset_dir'] dataset_dir = DATASET_DIR + '/sim_transfer' #SIM_TASK_CONFIGS[task_name]['dataset_dir']
num_episodes = 100 #SIM_TASK_CONFIGS[task_name]['num_episodes'] num_episodes = 100 #SIM_TASK_CONFIGS[task_name]['num_episodes']
onscreen_render = None #config['onscreen_render']
inject_noise = False inject_noise = False
render_cam_name = 'angle'
episode_len = 700 #SIM_TASK_CONFIGS[task_name]['episode_len'] episode_len = 700 #SIM_TASK_CONFIGS[task_name]['episode_len']
camera_names = ['angle','r_vis', 'top', 'front'] #SIM_TASK_CONFIGS[task_name]['camera_names'] camera_names = ['angle','r_vis', 'top', 'front'] #SIM_TASK_CONFIGS[task_name]['camera_names']
image_size = (256, 256)
if task_name == 'sim_transfer': if task_name == 'sim_transfer':
policy = TestPickAndTransferPolicy(inject_noise)
print(task_name) print(task_name)
else: else:
raise NotImplementedError raise NotImplementedError
@@ -39,62 +37,38 @@ def main():
print("osmesa已就绪开始收集数据...") print("osmesa已就绪开始收集数据...")
for episode_idx in range(num_episodes): for episode_idx in range(num_episodes):
obs = [] sum_reward = 0.0
reward_ee = [] max_reward = float('-inf')
print(f'\n{episode_idx=}') print(f'\n{episode_idx=}')
print('Rollout out EE space scripted policy') print('Rollout out EE space scripted policy')
box_pos = sample_transfer_pose() box_pos = sample_transfer_pose()
env.reset(box_pos) env.reset(box_pos)
episode_writer = StreamingEpisodeWriter(
dataset_path=os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5'),
max_timesteps=episode_len,
camera_names=camera_names,
image_size=image_size,
)
for step in range(episode_len): for step in range(episode_len):
raw_action = policy.predict(box_pos,step)
env.step(raw_action)
action = policy.predict(box_pos,step)
env.step(action)
env.render() env.render()
reward_ee.append(env.rew) sum_reward += env.rew
obs.append(env.obs) max_reward = max(max_reward, env.rew)
sum_reward = np.sum(reward_ee) episode_writer.append(
max_reward = np.max(reward_ee) qpos=env.obs['qpos'],
action=raw_action,
images=env.obs['images'],
)
if max_reward == env.max_reward: if max_reward == env.max_reward:
success.append(1) success.append(1)
print(f"{episode_idx=} Successful, {sum_reward=}") print(f"{episode_idx=} Successful, {sum_reward=}")
t0 = time.time() episode_writer.commit()
data_dict = {
'/observations/qpos': [],
'/action': [],
}
for cam_name in camera_names:
data_dict[f'/observations/images/{cam_name}'] = []
for i in range(episode_len):
print("type qpos==",obs[i]['qpos'])
data_dict['/observations/qpos'].append(obs[i]['qpos'])
data_dict['/action'].append(obs[i]['action'])
for cam_name in camera_names:
data_dict[f'/observations/images/{cam_name}'].append(obs[i]['images'][cam_name])
dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
max_timesteps = episode_len
root.attrs['sim'] = True
obs_ = root.create_group('observations')
image = obs_.create_group('images')
for cam_name in camera_names:
_ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
chunks=(1, 480, 640, 3), )
qpos = obs_.create_dataset('qpos', (max_timesteps, 16))
action = root.create_dataset('action', (max_timesteps, 16))
for name, array in data_dict.items():
root[name][...] = np.array(array)
else: else:
success.append(0) success.append(0)
print(f"{episode_idx=} Failed") print(f"{episode_idx=} Failed")
print(max_reward) print(max_reward)
del obs episode_writer.discard()
del reward_ee
del sum_reward
del max_reward
# del policy # del policy
# env.viewer.close() # env.viewer.close()
@@ -108,4 +82,4 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@@ -0,0 +1,113 @@
from __future__ import annotations
import os
from pathlib import Path
import cv2
import h5py
import numpy as np
class StreamingEpisodeWriter:
    """Stream per-frame episode data into HDF5; commit on success, discard on failure.

    Frames (qpos, raw end-effector action, multi-camera images) are appended
    one at a time into ``<dataset_path>.tmp``.  On success, :meth:`commit`
    atomically renames the temp file onto the final path; on failure,
    :meth:`discard` removes it, so a partially written episode can never be
    mistaken for a finished dataset.
    """

    def __init__(
        self,
        dataset_path: str | os.PathLike[str],
        max_timesteps: int,
        camera_names: list[str],
        image_size: tuple[int, int] = (256, 256),
    ) -> None:
        """Create the temp HDF5 file and pre-allocate every dataset.

        Args:
            dataset_path: Final episode path (e.g. ``.../episode_0.hdf5``).
            max_timesteps: Number of frames the episode will contain; datasets
                are allocated to this fixed length up front.
            camera_names: Cameras whose images are stored per frame.
            image_size: Target ``(height, width)`` every image is resized to.
        """
        self.dataset_path = Path(dataset_path)
        self.tmp_path = Path(f"{self.dataset_path}.tmp")
        self.max_timesteps = int(max_timesteps)
        self.camera_names = list(camera_names)
        self.image_height = int(image_size[0])
        self.image_width = int(image_size[1])
        self.frame_index = 0  # next frame slot to be written
        self._committed = False
        self._closed = False
        self.dataset_path.parent.mkdir(parents=True, exist_ok=True)
        # A stale temp file from a crashed run would make h5py refuse to
        # create the file; remove it first.
        if self.tmp_path.exists():
            self.tmp_path.unlink()
        self._file = h5py.File(self.tmp_path, "w", rdcc_nbytes=1024**2 * 2)
        self._file.attrs["sim"] = True
        self._file.attrs["action_repr"] = "ee_pose_xyz_quat_gripper"
        self._file.attrs["image_height"] = self.image_height
        self._file.attrs["image_width"] = self.image_width
        self._file.attrs["camera_names"] = np.asarray(self.camera_names, dtype="S")
        observations = self._file.create_group("observations")
        images = observations.create_group("images")
        for cam_name in self.camera_names:
            # One image per chunk: frames are written (and typically read
            # back) one timestep at a time.
            images.create_dataset(
                cam_name,
                (self.max_timesteps, self.image_height, self.image_width, 3),
                dtype="uint8",
                chunks=(1, self.image_height, self.image_width, 3),
            )
        observations.create_dataset(
            "qpos",
            (self.max_timesteps, 16),
            dtype="float32",
            chunks=(min(128, self.max_timesteps), 16),
        )
        self._file.create_dataset(
            "action",
            (self.max_timesteps, 16),
            dtype="float32",
            chunks=(min(128, self.max_timesteps), 16),
        )

    def append(self, qpos: np.ndarray, action: np.ndarray, images: dict[str, np.ndarray]) -> None:
        """Write one frame at the current index and advance it.

        Args:
            qpos: Joint positions, shape ``(16,)``.
            action: Raw (pre-IK) end-effector action, shape ``(16,)``.
            images: Mapping of camera name to HxWx3 uint8 image; must contain
                every camera passed to the constructor.

        Raises:
            RuntimeError: The writer was already committed or discarded.
            IndexError: More than ``max_timesteps`` frames were appended.
            ValueError: ``qpos``/``action`` shape mismatch, or a bad image.
            KeyError: A configured camera is missing from ``images``.
        """
        if self._closed:
            raise RuntimeError("writer is already closed")
        if self.frame_index >= self.max_timesteps:
            raise IndexError("frame index exceeds max_timesteps")
        qpos = np.asarray(qpos, dtype=np.float32)
        action = np.asarray(action, dtype=np.float32)
        if qpos.shape != (16,):
            raise ValueError(f"qpos shape must be (16,), got {qpos.shape}")
        if action.shape != (16,):
            raise ValueError(f"action shape must be (16,), got {action.shape}")
        # Validate and resize every camera image *before* touching the file,
        # so a bad frame never leaves partially written data at this index.
        resized: dict[str, np.ndarray] = {}
        for cam_name in self.camera_names:
            if cam_name not in images:
                raise KeyError(f"missing image for camera '{cam_name}'")
            resized[cam_name] = self._resize_image(images[cam_name])
        self._file["observations/qpos"][self.frame_index] = qpos
        self._file["action"][self.frame_index] = action
        for cam_name, frame in resized.items():
            self._file[f"observations/images/{cam_name}"][self.frame_index] = frame
        self.frame_index += 1

    def commit(self) -> None:
        """Close the temp file and atomically rename it to the final path.

        No-op if the writer is already closed.  NOTE: committing with fewer
        than ``max_timesteps`` appended frames leaves a zero-filled tail in
        every dataset (shapes are fixed at creation time).
        """
        if self._closed:
            return
        self._file.flush()
        self._file.close()
        self._closed = True
        # os.replace is atomic on POSIX: readers only ever see a complete file.
        os.replace(self.tmp_path, self.dataset_path)
        self._committed = True

    def discard(self) -> None:
        """Close (if needed) and delete the temp file; safe to call anytime."""
        if not self._closed:
            self._file.close()
            self._closed = True
        if self.tmp_path.exists():
            self.tmp_path.unlink()

    def _resize_image(self, image: np.ndarray) -> np.ndarray:
        """Return *image* as uint8 HxWx3 at the configured size.

        Uses INTER_AREA when shrinking (better for downsampling) and
        INTER_LINEAR when either dimension is being enlarged.
        """
        image = np.asarray(image, dtype=np.uint8)
        if image.ndim != 3 or image.shape[2] != 3:
            raise ValueError(f"image shape must be HxWx3, got {image.shape}")
        if image.shape[:2] == (self.image_height, self.image_width):
            return image
        interpolation = cv2.INTER_AREA
        if image.shape[0] < self.image_height or image.shape[1] < self.image_width:
            interpolation = cv2.INTER_LINEAR
        # cv2.resize takes (width, height), the reverse of numpy shape order.
        return cv2.resize(image, (self.image_width, self.image_height), interpolation=interpolation)

View File

@@ -0,0 +1,79 @@
import tempfile
import unittest
from pathlib import Path
import h5py
import numpy as np
from roboimi.utils.streaming_episode_writer import StreamingEpisodeWriter
class StreamingEpisodeWriterTest(unittest.TestCase):
    """Behavioral tests for StreamingEpisodeWriter commit/discard semantics."""

    _CAMERA_NAMES = ["angle", "r_vis", "top", "front"]

    def _make_images(self, base):
        """One constant 480x640 uint8 image per camera, valued base + camera index."""
        return {
            cam: np.full((480, 640, 3), fill_value=base + i, dtype=np.uint8)
            for i, cam in enumerate(self._CAMERA_NAMES)
        }

    def test_commit_persists_raw_action_and_resized_images(self):
        actions = [
            np.arange(16, dtype=np.float32),
            np.arange(16, dtype=np.float32) + 100.0,
        ]
        qposes = [
            np.arange(16, dtype=np.float32) + 200.0,
            np.arange(16, dtype=np.float32) + 300.0,
        ]
        with tempfile.TemporaryDirectory() as tmpdir:
            target = Path(tmpdir) / "episode_0.hdf5"
            writer = StreamingEpisodeWriter(
                dataset_path=target,
                max_timesteps=2,
                camera_names=list(self._CAMERA_NAMES),
                image_size=(256, 256),
            )
            # Frame 0 images are valued idx+1, frame 1 images idx+11.
            for frame, (qpos, action) in enumerate(zip(qposes, actions)):
                writer.append(
                    qpos=qpos,
                    action=action,
                    images=self._make_images(1 + 10 * frame),
                )
            writer.commit()

            self.assertTrue(target.exists())
            self.assertFalse(Path(str(target) + ".tmp").exists())
            with h5py.File(target, "r") as root:
                self.assertEqual(root["action"].shape, (2, 16))
                self.assertEqual(root["observations/qpos"].shape, (2, 16))
                for frame in range(2):
                    np.testing.assert_allclose(root["action"][frame], actions[frame])
                    np.testing.assert_allclose(root["observations/qpos"][frame], qposes[frame])
                for idx, cam_name in enumerate(self._CAMERA_NAMES):
                    dataset = root[f"observations/images/{cam_name}"]
                    self.assertEqual(dataset.shape, (2, 256, 256, 3))
                    self.assertEqual(dataset.dtype, np.uint8)
                    self.assertTrue(np.all(dataset[0] == idx + 1))
                    self.assertTrue(np.all(dataset[1] == idx + 11))

    def test_discard_removes_temporary_file(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            target = Path(tmpdir) / "episode_0.hdf5"
            writer = StreamingEpisodeWriter(
                dataset_path=target,
                max_timesteps=1,
                camera_names=list(self._CAMERA_NAMES),
                image_size=(256, 256),
            )
            writer.discard()
            self.assertFalse(target.exists())
            self.assertFalse(Path(str(target) + ".tmp").exists())
if __name__ == "__main__":
unittest.main()