204 lines
7.5 KiB
Python
204 lines
7.5 KiB
Python
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import h5py
import torch
from torch.utils.data import Dataset
|
||
|
||
|
||
class SimpleRobotDataset(Dataset):
    """Lazily loaded HDF5 dataset that yields LeRobotDataset-style samples.

    Only per-frame metadata is kept in memory; tensors are read from disk
    when a sample is requested.

    Each sample is a dict with:
        - observation.state: (obs_horizon, state_dim)
        - observation.{cam_name}: (obs_horizon, C, H, W)
        - observation_is_pad: (obs_horizon,) bool
        - action: (pred_horizon, action_dim)
        - action_is_pad: (pred_horizon,) bool
        - task: str
    """

    def __init__(
        self,
        dataset_dir: Union[str, Path],
        obs_horizon: int = 2,
        pred_horizon: int = 8,
        camera_names: Optional[List[str]] = None,
        image_size: Tuple[int, int] = (224, 224),
    ):
        """Index every ``*.hdf5`` episode found directly inside ``dataset_dir``.

        Args:
            dataset_dir: Directory that contains the HDF5 episode files.
            obs_horizon: Number of frames (current frame included) returned
                as observation history.
            pred_horizon: Number of action frames returned, starting at the
                current frame.
            camera_names: Camera names to load, e.g. ["r_vis", "top", "front"].
            image_size: Target size handed to ``cv2.resize`` as ``dsize``,
                i.e. (width, height).

        Expected HDF5 layout (per file):
            - action: [T, action_dim]
            - observations/qpos: [T, obs_dim]
            - observations/images/{cam_name}: [T, H, W, C]

        Raises:
            ValueError: If ``obs_horizon`` or ``pred_horizon`` is below 1.
            FileNotFoundError: If the directory does not exist or holds no
                ``*.hdf5`` file.
        """
        if obs_horizon < 1 or pred_horizon < 1:
            raise ValueError(
                f"obs_horizon 和 pred_horizon 必须 >= 1: "
                f"obs_horizon={obs_horizon}, pred_horizon={pred_horizon}"
            )

        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.camera_names = camera_names or []
        self.image_size = tuple(image_size)

        self.dataset_dir = Path(dataset_dir)
        if not self.dataset_dir.exists():
            raise FileNotFoundError(f"数据集目录不存在: {dataset_dir}")

        # "*.hdf5" already matches "episode_*.hdf5", so a single glob suffices.
        self.hdf5_files = sorted(self.dataset_dir.glob("*.hdf5"))
        if not self.hdf5_files:
            raise FileNotFoundError(f"在 {dataset_dir} 中未找到 HDF5 文件")

        # Build the index: metadata only, no tensor data is read here.
        self.episodes = {}    # ep_idx -> list of global frame indices
        self.frame_meta = []  # global frame index -> {ep_idx, frame_idx, hdf5_path}
        for ep_idx, hdf5_path in enumerate(self.hdf5_files):
            with h5py.File(hdf5_path, 'r') as f:
                T = f['action'].shape[0]
            start_idx = len(self.frame_meta)
            self.frame_meta.extend(
                {"ep_idx": ep_idx, "frame_idx": t, "hdf5_path": hdf5_path}
                for t in range(T)
            )
            self.episodes[ep_idx] = list(range(start_idx, len(self.frame_meta)))

        print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)} 帧")

    def __len__(self):
        """Return the total number of frames across all episodes."""
        return len(self.frame_meta)

    @staticmethod
    def _read_task(f) -> str:
        """Return the first entry of the file's ``task`` dataset, or "unknown"."""
        if 'task' not in f:
            return "unknown"
        raw = f['task'][0]
        return raw.decode() if isinstance(raw, bytes) else str(raw)

    def _frames_to_tensor(self, frames) -> torch.Tensor:
        """Convert HWC images to a float tensor of shape (N, C, H, W) in [0, 1].

        Every image is resized to ``self.image_size`` with bilinear
        interpolation and divided by 255.
        """
        import cv2  # imported lazily: only needed when camera images are read

        converted = []
        for img in frames:
            img = cv2.resize(img, self.image_size, interpolation=cv2.INTER_LINEAR)
            if img.ndim == 2:
                # cv2.resize drops the channel axis of single-channel images.
                img = img[:, :, None]
            chw = (torch.from_numpy(img).float() / 255.0).permute(2, 0, 1)  # HWC -> CHW
            converted.append(chw)
        return torch.stack(converted)

    def _load_frame(self, idx: int) -> Dict:
        """Read one frame (state, action, task and camera images) from disk.

        Cameras listed in ``camera_names`` but absent from the file are
        skipped, so their keys are missing from the returned dict.
        """
        meta = self.frame_meta[idx]
        frame_idx = meta["frame_idx"]
        with h5py.File(meta["hdf5_path"], 'r') as f:
            frame = {
                "episode_index": meta["ep_idx"],
                "frame_index": frame_idx,
                "task": self._read_task(f),
                "observation.state": torch.from_numpy(f['observations/qpos'][frame_idx]).float(),
                "action": torch.from_numpy(f['action'][frame_idx]).float(),
            }

            # observations/images/{cam_name} -> observation.{cam_name}
            for cam_name in self.camera_names:
                h5_path = f'observations/images/{cam_name}'
                if h5_path in f:
                    raw = f[h5_path][frame_idx:frame_idx + 1]
                    frame[f"observation.{cam_name}"] = self._frames_to_tensor(raw)[0]

        return frame

    def _plan_sample(self, idx: int):
        """Work out which frames of which file make up sample ``idx``.

        Negative indices count from the end, like list indexing. Window
        positions that fall outside the episode are clamped to the nearest
        frame of the same episode and flagged as padding.

        Returns:
            Tuple ``(meta, obs_frames, obs_is_pad, act_frames, act_is_pad)``
            where ``meta`` is the ``frame_meta`` entry of the current frame and
            the frame lists hold in-file frame indices (non-decreasing).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self.frame_meta)
        requested = idx
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"索引超出范围: {requested} (共 {total} 帧)")

        meta = self.frame_meta[idx]
        ep_indices = self.episodes[meta["ep_idx"]]
        ep_start, ep_end = ep_indices[0], ep_indices[-1]

        # Observations look back from the current frame; actions look ahead.
        obs_targets = range(idx - self.obs_horizon + 1, idx + 1)
        act_targets = range(idx, idx + self.pred_horizon)

        obs_is_pad = [t < ep_start for t in obs_targets]
        act_is_pad = [t > ep_end for t in act_targets]
        obs_frames = [self.frame_meta[max(t, ep_start)]["frame_idx"] for t in obs_targets]
        act_frames = [self.frame_meta[min(t, ep_end)]["frame_idx"] for t in act_targets]
        return meta, obs_frames, obs_is_pad, act_frames, act_is_pad

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the observation history and future actions for frame ``idx``.

        The episode file is opened once and each window is read as a single
        contiguous slice; padded positions reuse the clamped boundary frame.

        Raises:
            IndexError: If ``idx`` is out of range.
            KeyError: If a requested camera is missing from the episode file.
        """
        meta, obs_frames, obs_is_pad, act_frames, act_is_pad = self._plan_sample(idx)

        # Frame lists are non-decreasing, so first/last bound the slice and
        # the offsets re-expand it (repeating clamped frames) after reading.
        obs_lo, obs_hi = obs_frames[0], obs_frames[-1]
        act_lo, act_hi = act_frames[0], act_frames[-1]
        obs_offsets = [t - obs_lo for t in obs_frames]
        act_offsets = [t - act_lo for t in act_frames]

        images = {}
        with h5py.File(meta["hdf5_path"], 'r') as f:
            task = self._read_task(f)
            states = torch.from_numpy(f['observations/qpos'][obs_lo:obs_hi + 1]).float()
            actions = torch.from_numpy(f['action'][act_lo:act_hi + 1]).float()

            for cam_name in self.camera_names:
                h5_path = f'observations/images/{cam_name}'
                if h5_path not in f:
                    raise KeyError(f"HDF5 文件 {meta['hdf5_path']} 中缺少相机数据: {h5_path}")
                frames = self._frames_to_tensor(f[h5_path][obs_lo:obs_hi + 1])
                images[f"observation.{cam_name}"] = frames[obs_offsets]

        result = {
            # State history: (obs_horizon, state_dim)
            "observation.state": states[obs_offsets],
            "observation_is_pad": torch.tensor(obs_is_pad, dtype=torch.bool),
            # Future actions: (pred_horizon, action_dim)
            "action": actions[act_offsets],
            "action_is_pad": torch.tensor(act_is_pad, dtype=torch.bool),
            "task": task,
        }
        # One key per camera, each of shape (obs_horizon, C, H, W).
        result.update(images)
        return result

    @property
    def camera_keys(self) -> List[str]:
        """Sample keys of all configured cameras (LeRobotDataset naming)."""
        return [f"observation.{cam_name}" for cam_name in self.camera_names]

    @property
    def camera_info(self) -> dict:
        """Shape and dtype of each camera tensor, taken from the first sample."""
        if not self.camera_names:
            return {}

        sample = self[0]
        info = {}
        for key in self.camera_keys:
            if key in sample:
                info[key] = {
                    "shape": sample[key].shape,
                    "dtype": str(sample[key].dtype),
                }
        return info
|