Files
roboimi/roboimi/vla/data/simpe_robot_dataset.py

204 lines
7.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import h5py
from torch.utils.data import Dataset
from typing import List, Dict, Union
from pathlib import Path
class SimpleRobotDataset(Dataset):
    """Lazy-loading HDF5 dataset in LeRobotDataset format.

    Each item returned by ``__getitem__`` contains:
        - observation.state: (obs_horizon, state_dim)
        - observation.{cam_name}: (obs_horizon, C, H, W), float in [0, 1]
        - action: (pred_horizon, action_dim)
        - observation_is_pad / action_is_pad: bool masks marking frames that
          were clamped to an episode boundary.

    Expected HDF5 layout per episode file:
        - action: [T, action_dim]
        - observations/qpos: [T, obs_dim]
        - observations/images/{cam_name}: [T, H, W, C]  (uint8 assumed — TODO confirm)
    """

    def __init__(
        self,
        dataset_dir: Union[str, Path],
        obs_horizon: int = 2,
        pred_horizon: int = 8,
        camera_names: Union[List[str], None] = None,
    ):
        """
        Args:
            dataset_dir: directory containing the episode ``*.hdf5`` files.
            obs_horizon: number of past frames to observe (including current).
            pred_horizon: number of future action frames to predict.
            camera_names: camera names, e.g. ["r_vis", "top", "front"].

        Raises:
            FileNotFoundError: if the directory does not exist or holds no
                HDF5 files.
        """
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.camera_names = camera_names or []
        self.dataset_dir = Path(dataset_dir)
        if not self.dataset_dir.exists():
            raise FileNotFoundError(f"数据集目录不存在: {dataset_dir}")

        # Discover episode files.  The original code retried with
        # "episode_*.hdf5" when "*.hdf5" matched nothing, but every such file
        # already matches "*.hdf5", so that fallback was dead code and is gone.
        self.hdf5_files = sorted(self.dataset_dir.glob("*.hdf5"))
        if not self.hdf5_files:
            raise FileNotFoundError(f"{dataset_dir} 中未找到 HDF5 文件")

        # Build the per-frame index (metadata only — no frame data is read).
        self.episodes: Dict[int, List[int]] = {}   # ep_idx -> flat frame indices
        self.frame_meta: List[Dict] = []           # per frame: ep_idx, frame_idx, hdf5_path
        for ep_idx, hdf5_path in enumerate(self.hdf5_files):
            # Only the episode length is needed here; close the file promptly.
            with h5py.File(hdf5_path, 'r') as f:
                T = f['action'].shape[0]
            start_idx = len(self.frame_meta)
            for t in range(T):
                self.frame_meta.append({
                    "ep_idx": ep_idx,
                    "frame_idx": t,
                    "hdf5_path": hdf5_path,
                })
            self.episodes[ep_idx] = list(range(start_idx, len(self.frame_meta)))
        print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)}")

    def __len__(self):
        return len(self.frame_meta)

    def _load_action(self, idx: int) -> torch.Tensor:
        """Read ONLY the action vector of frame ``idx`` from its HDF5 file.

        Used for the prediction horizon: the original implementation called
        ``_load_frame`` here, decoding and resizing every camera image just to
        fetch a future action — pure wasted I/O.
        """
        meta = self.frame_meta[idx]
        with h5py.File(meta["hdf5_path"], 'r') as f:
            return torch.from_numpy(f['action'][meta["frame_idx"]]).float()

    def _load_frame(self, idx: int) -> Dict:
        """Lazily load one full frame (state, action, images) from HDF5."""
        meta = self.frame_meta[idx]
        with h5py.File(meta["hdf5_path"], 'r') as f:
            frame = {
                "episode_index": meta["ep_idx"],
                "frame_index": meta["frame_idx"],
                # The old f.get(..., [b"unknown"]) default was unreachable
                # behind the 'task' in f guard; plain indexing is equivalent.
                "task": f['task'][0].decode() if 'task' in f else "unknown",
                "observation.state": torch.from_numpy(f['observations/qpos'][meta["frame_idx"]]).float(),
                "action": torch.from_numpy(f['action'][meta["frame_idx"]]).float(),
            }
            # observations/images/{cam_name} -> observation.{cam_name}
            for cam_name in self.camera_names:
                h5_key = f'observations/images/{cam_name}'
                if h5_key not in f:
                    continue  # missing camera is silently skipped (original behavior)
                # cv2 stays a lazy, image-path-only dependency, as before.
                import cv2
                # Resize to 224x224 to cut memory and I/O load.
                img = cv2.resize(f[h5_key][meta["frame_idx"]], (224, 224),
                                 interpolation=cv2.INTER_LINEAR)
                # uint8 HWC -> float CHW normalized to [0, 1]
                img_t = torch.from_numpy(img).float() / 255.0
                frame[f"observation.{cam_name}"] = img_t.permute(2, 0, 1)
        return frame

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        frame = self._load_frame(idx)
        ep_indices = self.episodes[frame["episode_index"]]
        ep_start, ep_end = ep_indices[0], ep_indices[-1]

        # ------------------------------------------------------------------
        # 1. Observations: the past obs_horizon frames (current frame last)
        # ------------------------------------------------------------------
        state_seq: List[torch.Tensor] = []
        cam_seq: Dict[str, List[torch.Tensor]] = {c: [] for c in self.camera_names}
        observation_is_pad: List[bool] = []
        for delta in range(-self.obs_horizon + 1, 1):  # e.g. [-1, 0] for obs_horizon=2
            target_idx = idx + delta
            if target_idx == idx:
                # Reuse the already-loaded current frame (the original
                # reloaded it from disk a second time here).
                target_frame, is_pad = frame, False
            elif ep_start <= target_idx <= ep_end:
                target_frame, is_pad = self._load_frame(target_idx), False
            else:
                # Out of episode range: clamp to the boundary frame and mark pad.
                clamped = ep_start if target_idx < ep_start else ep_end
                target_frame, is_pad = self._load_frame(clamped), True
            state_seq.append(target_frame["observation.state"])
            for cam_name in self.camera_names:
                cam_seq[cam_name].append(target_frame[f"observation.{cam_name}"])
            observation_is_pad.append(is_pad)

        # ------------------------------------------------------------------
        # 2. Actions: the next pred_horizon frames (action-only reads)
        # ------------------------------------------------------------------
        actions: List[torch.Tensor] = []
        action_is_pad: List[bool] = []
        pad_action = None  # boundary action, fetched at most once for padding
        for delta in range(self.pred_horizon):
            target_idx = idx + delta
            if target_idx <= ep_end:
                # delta == 0 is the current frame whose action is already loaded.
                actions.append(frame["action"] if target_idx == idx
                               else self._load_action(target_idx))
                action_is_pad.append(False)
            else:
                if pad_action is None:
                    pad_action = self._load_action(ep_end)
                actions.append(pad_action)
                action_is_pad.append(True)

        # ------------------------------------------------------------------
        # 3. Assemble the LeRobotDataset-format sample
        # ------------------------------------------------------------------
        result = {
            # State observations: (obs_horizon, state_dim)
            "observation.state": torch.stack(state_seq),
            "observation_is_pad": torch.tensor(observation_is_pad, dtype=torch.bool),
            # Actions: (pred_horizon, action_dim)
            "action": torch.stack(actions),
            "action_is_pad": torch.tensor(action_is_pad, dtype=torch.bool),
            "task": frame["task"],
        }
        # One key per camera, shape (obs_horizon, C, H, W)
        for cam_name in self.camera_names:
            result[f"observation.{cam_name}"] = torch.stack(cam_seq[cam_name])
        return result

    @property
    def camera_keys(self) -> List[str]:
        """Camera keys in LeRobotDataset format (``observation.{cam_name}``)."""
        return [f"observation.{cam_name}" for cam_name in self.camera_names]

    @property
    def camera_info(self) -> dict:
        """Shape/dtype info per camera, probed from the first sample."""
        if not self.camera_names:
            return {}
        sample = self[0]  # triggers one real load to discover shapes
        info = {}
        for cam_name in self.camera_names:
            key = f"observation.{cam_name}"
            if key in sample:
                info[key] = {
                    "shape": sample[key].shape,
                    "dtype": str(sample[key].dtype),
                }
        return info