refactor:大重构
This commit is contained in:
@@ -1,152 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset
|
||||
import h5py
|
||||
import numpy as np
|
||||
import os
|
||||
import glob
|
||||
import pickle
|
||||
|
||||
class RobotDiffusionDataset(Dataset):
    """Diffusion-policy dataset backed by per-episode HDF5 files.

    Each sample is one timestep t of one episode and contains a history
    window of observations (qpos + camera images) ending at t, plus a
    padded window of future actions starting at t.

    Expected HDF5 layout per episode file:
        action:                       (T, action_dim)
        observations/qpos:            (T, qpos_dim)
        observations/images/{cam}:    (T, H, W, C), uint8
    """

    def __init__(self,
                 dataset_dir,
                 pred_horizon=16,
                 obs_horizon=2,
                 action_horizon=8,
                 camera_names=None,
                 normalization_type='gaussian'):
        """
        Args:
            dataset_dir: directory containing episode_*.hdf5 files and
                data_stats.pkl (precomputed normalization statistics)
            pred_horizon: length of the predicted future action window (Tp)
            obs_horizon: length of the observation history window (To)
            action_horizon: executed action length (Ta); kept as a parameter
                for evaluation-time use, it does not affect indexing here
            camera_names: camera streams to load; defaults to
                ['r_vis', 'top', 'front'] (was a mutable default argument
                in the original — replaced with a None sentinel)
            normalization_type: 'gaussian' (mean/std) or 'min_max' ([-1, 1])
        """
        self.dataset_dir = dataset_dir
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        self.action_horizon = action_horizon
        # Avoid the shared-mutable-default pitfall.
        self.camera_names = list(camera_names) if camera_names is not None \
            else ['r_vis', 'top', 'front']
        self.normalization_type = normalization_type

        # 1. Scan all HDF5 files and build a flat per-step index:
        #    entries are (file_path, episode_length, step_index).
        self.episode_files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5')))
        self.indices = []

        print(f"Found {len(self.episode_files)} episodes. Building index...")

        for file_path in self.episode_files:
            with h5py.File(file_path, 'r') as f:
                # Episode length, e.g. 700 steps.
                l = f['action'].shape[0]
                for i in range(l):
                    self.indices.append((file_path, l, i))

        # 2. Normalization statistics precomputed over the whole dataset.
        with open(os.path.join(dataset_dir, 'data_stats.pkl'), 'rb') as f:
            self.stats = pickle.load(f)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        file_path, episode_len, start_ts = self.indices[idx]

        # Opening the file inside __getitem__ is friendlier to multi-process
        # DataLoader workers than sharing one handle; for maximum IO
        # throughput consider h5py SWMR mode or an in-memory cache.
        with h5py.File(file_path, 'r') as root:
            # -----------------------------
            # 1. Actions (prediction target): [t, t + pred_horizon)
            # -----------------------------
            action_start = start_ts
            action_end = min(start_ts + self.pred_horizon, episode_len)

            actions = root['action'][action_start:action_end]  # (T_subset, action_dim)

            # Pad by repeating the last action when the episode tail is
            # shorter than pred_horizon.
            if len(actions) < self.pred_horizon:
                pad_len = self.pred_horizon - len(actions)
                last_action = actions[-1]
                pad_content = np.repeat(last_action[np.newaxis, :], pad_len, axis=0)
                actions = np.concatenate([actions, pad_content], axis=0)

            if self.stats:
                actions = self._normalize_data(actions, self.stats['action'])

            # -----------------------------
            # 2. Observations (history): [t - obs_horizon + 1, t]
            # -----------------------------
            # Index logic:
            #   obs_horizon=2, t=0 -> frames [0, 0] (front padding)
            #   obs_horizon=2, t=5 -> frames [4, 5]
            start_idx_raw = start_ts - (self.obs_horizon - 1)
            start_idx = max(start_idx_raw, 0)
            end_idx = start_ts + 1
            pad_len = max(0, -start_idx_raw)

            # qpos history, front-padded with the first frame if needed.
            qpos_data = root['observations/qpos']
            qpos_val = qpos_data[start_idx:end_idx]

            if pad_len > 0:
                first_frame = qpos_val[0]
                padding = np.repeat(first_frame[np.newaxis, :], pad_len, axis=0)
                qpos_val = np.concatenate([padding, qpos_val], axis=0)

            if self.stats:
                qpos_val = self._normalize_data(qpos_val, self.stats['qpos'])

            # Camera images, same front-padding scheme.
            image_dict = {}
            for cam_name in self.camera_names:
                img_dset = root['observations']['images'][cam_name]
                imgs_np = img_dset[start_idx:end_idx]  # (T, H, W, C)

                if pad_len > 0:
                    first_frame = imgs_np[0]
                    padding = np.repeat(first_frame[np.newaxis, ...], pad_len, axis=0)
                    imgs_np = np.concatenate([padding, imgs_np], axis=0)

                # (T, H, W, C) uint8 -> (T, C, H, W) float in [0, 1]
                imgs_tensor = torch.from_numpy(imgs_np).float() / 255.0
                imgs_tensor = torch.einsum('thwc->tchw', imgs_tensor)
                image_dict[cam_name] = imgs_tensor

            # -----------------------------
            # 3. Assemble the batch dict
            # -----------------------------
            data_batch = {
                'action': torch.from_numpy(actions).float(),
                'qpos': torch.from_numpy(qpos_val).float(),
            }
            for cam_name, img_tensor in image_dict.items():
                data_batch[f'image_{cam_name}'] = img_tensor

            return data_batch

    def _normalize_data(self, data, stats):
        """Normalize ``data`` (numpy array) with precomputed ``stats``.

        Returns:
            The normalized array, per ``self.normalization_type``.

        Raises:
            ValueError: for an unsupported normalization_type (the original
                code silently fell through and returned None here).
        """
        if self.normalization_type == 'min_max':
            # Scale to [-1, 1].
            min_val = stats['min']
            max_val = stats['max']
            data = (data - min_val) / (max_val - min_val + 1e-8)
            return data * 2 - 1

        elif self.normalization_type == 'gaussian':
            # Standardize with mean/std.
            mean = stats['mean']
            std = stats['std']
            return (data - mean) / (std + 1e-8)

        raise ValueError(
            f"Unsupported normalization_type: {self.normalization_type!r} "
            "(expected 'min_max' or 'gaussian')"
        )
||||
@@ -1,523 +1,199 @@
|
||||
import torch
|
||||
import h5py
|
||||
from torch.utils.data import Dataset
|
||||
from typing import List, Dict, Optional
|
||||
from typing import List, Dict, Union
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class SimpleRobotDataset(Dataset):
    """HDF5 lazy-loading dataset in the LeRobotDataset format.

    NOTE(review): this region of the source was a corrupted diff overlay
    that interleaved the pre-refactor (in-memory ``frames`` list) and
    post-refactor (HDF5 lazy-loading) implementations — e.g. two ``return``
    statements in ``__len__`` and duplicate ``self.frames[...]`` /
    ``self._load_frame(...)`` pairs. This is the post-refactor version,
    reconstructed from the added lines; confirm against repository history.

    Returned sample format:
        observation.state:        (obs_horizon, state_dim)
        observation.{cam_name}:   (obs_horizon, C, H, W)
        action:                   (pred_horizon, action_dim)
    """

    def __init__(
        self,
        dataset_dir: Union[str, Path],
        obs_horizon: int = 2,
        pred_horizon: int = 8,
        camera_names: List[str] = None,
    ):
        """
        Args:
            dataset_dir: directory containing the HDF5 episode files
            obs_horizon: how many past frames to observe
            pred_horizon: how many future action frames to predict
            camera_names: camera names, e.g. ["r_vis", "top", "front"]

        Expected HDF5 layout:
            action:                          [T, action_dim]
            observations/qpos:               [T, obs_dim]
            observations/images/{cam_name}:  [T, H, W, C]
        """
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.camera_names = camera_names or []

        self.dataset_dir = Path(dataset_dir)
        if not self.dataset_dir.exists():
            raise FileNotFoundError(f"数据集目录不存在: {dataset_dir}")

        # Locate the HDF5 files (any *.hdf5 first, then episode_*.hdf5).
        self.hdf5_files = sorted(self.dataset_dir.glob("*.hdf5"))
        if not self.hdf5_files:
            self.hdf5_files = sorted(self.dataset_dir.glob("episode_*.hdf5"))
            if not self.hdf5_files:
                raise FileNotFoundError(f"在 {dataset_dir} 中未找到 HDF5 文件")

        # Build the episode/frame index (metadata only — no data loaded).
        self.episodes = {}
        self.frame_meta = []  # entries: {"ep_idx", "frame_idx", "hdf5_path"}
        for ep_idx, hdf5_path in enumerate(self.hdf5_files):
            with h5py.File(hdf5_path, 'r') as f:
                T = f['action'].shape[0]
                start_idx = len(self.frame_meta)
                for t in range(T):
                    self.frame_meta.append({
                        "ep_idx": ep_idx,
                        "frame_idx": t,
                        "hdf5_path": hdf5_path,
                    })
                self.episodes[ep_idx] = list(range(start_idx, len(self.frame_meta)))

        print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)} 帧")

    def __len__(self):
        return len(self.frame_meta)

    def _load_frame(self, idx: int) -> Dict:
        """Lazily load a single frame from its HDF5 file."""
        meta = self.frame_meta[idx]
        with h5py.File(meta["hdf5_path"], 'r') as f:
            frame = {
                "episode_index": meta["ep_idx"],
                "frame_index": meta["frame_idx"],
                "task": f.get('task', [b"unknown"])[0].decode() if 'task' in f else "unknown",
                "observation.state": torch.from_numpy(f['observations/qpos'][meta["frame_idx"]]).float(),
                "action": torch.from_numpy(f['action'][meta["frame_idx"]]).float(),
            }

            # Images: observations/images/{cam_name} -> observation.{cam_name}
            for cam_name in self.camera_names:
                h5_path = f'observations/images/{cam_name}'
                if h5_path in f:
                    img = f[h5_path][meta["frame_idx"]]
                    img = torch.from_numpy(img).float()
                    frame[f"observation.{cam_name}"] = img.permute(2, 0, 1)  # HWC -> CHW

            return frame

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        frame = self._load_frame(idx)
        ep_idx = frame["episode_index"]

        # Frame-index range of the current episode (for boundary padding).
        ep_indices = self.episodes[ep_idx]
        ep_start = ep_indices[0]
        ep_end = ep_indices[-1]

        # ============================================
        # 1. Observations (past obs_horizon frames)
        # ============================================
        observations = {
            "state": [],
        }
        for cam_name in self.camera_names:
            observations[f"observation.{cam_name}"] = []

        observation_is_pad = []

        for delta in range(-self.obs_horizon + 1, 1):  # [-1, 0] for obs_horizon=2
            target_idx = idx + delta

            if ep_start <= target_idx <= ep_end:
                target_frame = self._load_frame(target_idx)
                is_pad = False
            else:
                # Out of bounds — replicate the boundary frame.
                if target_idx < ep_start:
                    target_frame = self._load_frame(ep_start)
                else:
                    target_frame = self._load_frame(ep_end)
                is_pad = True

            observations["state"].append(target_frame["observation.state"])
            for cam_name in self.camera_names:
                observations[f"observation.{cam_name}"].append(
                    target_frame[f"observation.{cam_name}"])
            observation_is_pad.append(is_pad)

        # ============================================
        # 2. Actions (future pred_horizon frames)
        # ============================================
        actions = []
        action_is_pad = []

        for delta in range(self.pred_horizon):
            target_idx = idx + delta

            if target_idx <= ep_end:
                actions.append(self._load_frame(target_idx)["action"])
                action_is_pad.append(False)
            else:
                actions.append(self._load_frame(ep_end)["action"])
                action_is_pad.append(True)

        # ============================================
        # 3. Assemble the result (LeRobotDataset format)
        # ============================================
        result = {
            # State observation: (obs_horizon, state_dim)
            "observation.state": torch.stack(observations["state"]),
            "observation_is_pad": torch.tensor(observation_is_pad, dtype=torch.bool),

            # Actions: (pred_horizon, action_dim)
            "action": torch.stack(actions),
            "action_is_pad": torch.tensor(action_is_pad, dtype=torch.bool),

            # Task string
            "task": frame["task"],
        }

        # Images: one key per camera, shape (obs_horizon, C, H, W)
        for cam_name in self.camera_names:
            result[f"observation.{cam_name}"] = torch.stack(
                observations[f"observation.{cam_name}"])

        return result

    @property
    def camera_keys(self) -> list[str]:
        """All camera keys (LeRobotDataset format)."""
        return [f"observation.{cam_name}" for cam_name in self.camera_names]

    @property
    def camera_info(self) -> dict:
        """Shape/dtype info per camera, probed from the first sample."""
        if not self.camera_names:
            return {}

        sample = self[0]
        info = {}
        for cam_name in self.camera_names:
            key = f"observation.{cam_name}"
            if key in sample:
                info[key] = {
                    "shape": sample[key].shape,
                    "dtype": str(sample[key].dtype),
                }
        return info
|
||||
|
||||
|
||||
class SimpleDiffusionPolicy(torch.nn.Module):
    """Simplified diffusion policy — demonstrates stacking the per-camera
    image dict inside forward().

    Encodes state (and optionally images), fuses the features, and predicts
    a flat action sequence of length pred_horizon.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        image_features: Dict[str, tuple] = None,
        obs_horizon: int = 2,
        pred_horizon: int = 8,
    ):
        """
        Args:
            state_dim: dimensionality of observation.state vectors
            action_dim: dimensionality of each action vector
            image_features: mapping camera key -> (C, H, W); None/empty
                disables the image branch entirely
            obs_horizon: number of observed past frames
            pred_horizon: number of predicted future action frames
        """
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.image_features = image_features or {}

        self.state_encoder = torch.nn.Linear(state_dim, 64)
        if image_features:
            num_cameras = len(image_features)
            self.image_encoder = torch.nn.Conv2d(3, 32, kernel_size=7, stride=2)
            self.fusion = torch.nn.Linear(64 + 32 * num_cameras, 128)
        else:
            self.fusion = torch.nn.Linear(64, 128)

        self.action_head = torch.nn.Linear(128, action_dim * pred_horizon)

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Predict actions of shape (B, pred_horizon, action_dim).

        Expects batch["observation.state"] of shape (B, T, state_dim) and,
        when image_features is set, batch[key] of shape (B, T, C, H, W)
        for each configured camera key.
        """
        # BUG FIX: B was previously assigned only inside the image branch,
        # so the image-free path crashed with NameError at the final reshape.
        B = batch["observation.state"].shape[0]

        # State branch: encode then average over the time dimension.
        state_features = self.state_encoder(batch["observation.state"])
        state_features = state_features.mean(dim=1)

        # Image branch: stack the per-camera dict entries, encode jointly.
        if self.image_features:
            image_tensors = [batch[key] for key in self.image_features.keys()]
            stacked_images = torch.stack(image_tensors, dim=1)

            B, num_cam, T, C, H, W = stacked_images.shape
            images_flat = stacked_images.reshape(B * num_cam * T, C, H, W)
            image_features = self.image_encoder(images_flat)
            # Global average pool over spatial dims, then average over time.
            image_features = image_features.mean(dim=[2, 3])
            image_features = image_features.reshape(B, num_cam, T, 32).mean(dim=2)
            image_features = image_features.reshape(B, -1)

            features = torch.cat([state_features, image_features], dim=-1)
        else:
            features = state_features

        fused = self.fusion(features)
        pred_actions = self.action_head(fused)
        pred_actions = pred_actions.reshape(B, self.pred_horizon, self.action_dim)

        return pred_actions
|
||||
|
||||
|
||||
def create_demo_data_with_images():
    """Build synthetic demo frames with images: two 10-frame episodes."""
    episode_specs = [
        (0, "pick_up_cube"),
        (1, "stack_blocks"),
    ]

    frames = []
    for ep_idx, task_name in episode_specs:
        for step in range(10):
            frames.append({
                "episode_index": ep_idx,
                "frame_index": step,
                "task": task_name,
                "observation.state": torch.randn(6),
                "observation.image_high_resize": torch.randn(3, 64, 64),
                "observation.image_left_wrist": torch.randn(3, 64, 64),
                "action": torch.randn(6),
            })

    return frames
|
||||
|
||||
|
||||
def print_section(title: str):
    """Print a banner-style section header around *title*."""
    bar = "=" * 80
    print(f"\n{bar}")
    print(f" {title}")
    print(bar)
|
||||
|
||||
|
||||
def test_dataset_basic_info(dataset):
|
||||
"""测试数据集基本信息"""
|
||||
print("\n📊 数据集基本信息:")
|
||||
print(f" 总帧数: {len(dataset)}")
|
||||
print(f" 总 episode 数: {len(dataset.episodes)}")
|
||||
print(f" 观察窗口: {dataset.obs_horizon}")
|
||||
print(f" 预测窗口: {dataset.pred_horizon}")
|
||||
|
||||
print(f"\n📷 相机信息:")
|
||||
cameras = dataset.camera_keys
|
||||
print(f" 相机数量: {len(cameras)}")
|
||||
for cam in cameras:
|
||||
print(f" - {cam}")
|
||||
|
||||
print(f"\n相机详细信息:")
|
||||
cam_info = dataset.camera_info
|
||||
for cam, info in cam_info.items():
|
||||
print(f" {cam}:")
|
||||
print(f" shape: {info['shape']}")
|
||||
print(f" dtype: {info['dtype']}")
|
||||
|
||||
|
||||
def test_single_sample(dataset):
    """Inspect one mid-episode sample and verify its time dimensions."""
    print_section("1. 测试单个样本")

    # Pick a sample from the middle of an episode.
    sample = dataset[5]

    print("\n样本结构 (字典形式):")
    for key, value in sample.items():
        if isinstance(value, torch.Tensor):
            print(f" {key:30s}: {str(value.shape):20s} {value.dtype}")
        elif isinstance(value, str):
            print(f" {key:30s}: {value}")

    # Images are stored as a dict — one key per camera.
    print("\n✅ 验证图像存储形式:")
    print(" 图像以字典形式存储,每个摄像头独立的 key:")
    for cam_key in dataset.camera_keys:
        if cam_key in sample:
            print(f" - {cam_key}: {sample[cam_key].shape}")

    # Time-dimension checks for state and action windows.
    print("\n✅ 验证时间维度:")
    state = sample['observation.state']
    action = sample['action']
    print(f" observation.state: {state.shape}")
    print(f" 预期: (obs_horizon={dataset.obs_horizon}, state_dim=6)")
    assert state.shape[0] == dataset.obs_horizon, "观察时间维度错误"
    print(f" action: {action.shape}")
    print(f" 预期: (pred_horizon={dataset.pred_horizon}, action_dim=6)")
    assert action.shape[0] == dataset.pred_horizon, "动作时间维度错误"
    print(" ✓ 时间维度验证通过")
|
||||
|
||||
|
||||
def test_edge_cases(dataset):
    """Probe samples at episode boundaries and report padding flags."""
    print_section("2. 测试边界情况")

    test_cases = [
        ("Episode 开头", 0, {"obs_pad": [True, False], "action_pad": [False] * 8}),
        ("Episode 中间", 5, {"obs_pad": [False, False], "action_pad": [False] * 5 + [True] * 3}),
        ("Episode 末尾", 9, {"obs_pad": [False, False], "action_pad": [True] * 8}),
        ("跨 Episode", 10, {"obs_pad": [True, False], "action_pad": [False] * 8}),
    ]

    for name, idx, expected in test_cases:
        print(f"\n📍 {name} (idx={idx}):")
        sample = dataset[idx]

        obs_pad = sample["observation_is_pad"].tolist()
        action_pad_count = sample["action_is_pad"].sum().item()

        print(f" observation_is_pad: {obs_pad}")
        print(f" action_is_pad: {sample['action_is_pad'].tolist()}")
        print(f" action padding 数量: {action_pad_count}")

        # Only the first observation frame is asserted; the `expected`
        # values are informational for the printed report.
        if name == "Episode 开头":
            assert obs_pad[0] == True, "Episode 开头第一帧应该是 padding"
        elif name == "跨 Episode":
            assert obs_pad[0] == True, "跨 Episode 第一帧应该是 padding"
|
||||
|
||||
|
||||
def test_dataloader(dataset):
    """Check DataLoader integration: batch structure and image shapes."""
    # BUG FIX: DataLoader was only imported under the `if __name__ ==
    # "__main__"` guard, so calling this function from an importing module
    # raised NameError. Import locally so the function is self-contained.
    from torch.utils.data import DataLoader

    print_section("3. 测试 DataLoader 集成")

    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=0,  # keep 0 workers for tests
    )

    batch = next(iter(dataloader))

    print("\n📦 Batch 结构:")
    for key in ["observation.state", "observation.image_high_resize",
                "observation.image_left_wrist", "action", "task"]:
        if key in batch:
            value = batch[key]
            if isinstance(value, torch.Tensor):
                print(f" {key:30s}: {str(value.shape):20s} {value.dtype}")
            else:
                print(f" {key:30s}: {type(value).__name__} (length={len(value)})")

    print("\n✅ 验证 Batch 形状:")
    B = len(batch["observation.state"])
    print(f" Batch size: {B}")

    # Per-camera shape check: (B, obs_horizon, C, H, W).
    for cam_key in dataset.camera_keys:
        expected_shape = (B, dataset.obs_horizon, 3, 64, 64)
        actual_shape = batch[cam_key].shape
        print(f" {cam_key}:")
        print(f" 预期: {expected_shape}")
        print(f" 实际: {actual_shape}")
        assert actual_shape == expected_shape, f"{cam_key} 形状不匹配"
    print(" ✓ Batch 形状验证通过")
|
||||
|
||||
|
||||
def test_policy_forward(dataset):
    """Run a policy forward pass on one batch and report shapes."""
    # BUG FIX: DataLoader was only imported under the `__main__` guard;
    # import locally so this function works when the module is imported.
    from torch.utils.data import DataLoader

    print_section("4. 测试 Policy 前向传播")

    # Build the policy with both demo cameras enabled.
    policy = SimpleDiffusionPolicy(
        state_dim=6,
        action_dim=6,
        image_features={
            "observation.image_high_resize": (3, 64, 64),
            "observation.image_left_wrist": (3, 64, 64),
        },
        obs_horizon=dataset.obs_horizon,
        pred_horizon=dataset.pred_horizon,
    )

    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    batch = next(iter(dataloader))

    print("\n🔄 Policy.forward() 流程:")

    # 1. Before stacking: one tensor per camera key.
    print("\n 1️⃣ Stack 之前 (字典形式):")
    for cam_key in policy.image_features.keys():
        print(f" batch['{cam_key}']: {batch[cam_key].shape}")

    # 2. Demonstrate the stack the policy performs internally.
    print("\n 2️⃣ Stack 操作:")
    image_tensors = [batch[key] for key in policy.image_features.keys()]
    stacked = torch.stack(image_tensors, dim=1)
    print(f" stacked_images: {stacked.shape}")
    print(f" (B={stacked.shape[0]}, num_cam={stacked.shape[1]}, ")
    print(f" obs_hor={stacked.shape[2]}, C={stacked.shape[3]}, H={stacked.shape[4]}, W={stacked.shape[5]})")

    # 3. Actual forward pass.
    print("\n 3️⃣ 前向传播:")
    with torch.no_grad():
        pred_actions = policy(batch)

    print(f" 输入:")
    print(f" observation.state: {batch['observation.state'].shape}")
    print(f" 图像已 stack")
    print(f" 输出:")
    print(f" pred_actions: {pred_actions.shape}")
    print(f" (B={pred_actions.shape[0]}, pred_horizon={pred_actions.shape[1]}, action_dim={pred_actions.shape[2]})")

    print("\n✅ Policy 前向传播验证通过")
|
||||
|
||||
|
||||
def test_data_consistency(dataset):
    """Verify that observation padding replicates the boundary image."""
    print_section("5. 测试数据一致性")

    print("\n🔍 验证图像 padding 的正确性:")

    # Sample at the start of an episode: frame 0 should be a padded copy.
    sample = dataset[0]
    if sample["observation_is_pad"][0]:
        first = sample["observation.image_high_resize"][0]
        second = sample["observation.image_high_resize"][1]
        print(f" Episode 开头 (idx=0):")
        print(f" 第0帧是 padding: {sample['observation_is_pad'][0]}")
        print(f" 第0帧图像 = 第1帧图像: {torch.equal(first, second)}")
        assert torch.equal(first, second), "Padding 应该复制边界帧"
        print(" ✓ Padding 正确")

    # Mid-episode sample: no padding, frames should differ.
    sample = dataset[5]
    if not sample["observation_is_pad"].any():
        first = sample["observation.image_high_resize"][0]
        second = sample["observation.image_high_resize"][1]
        print(f"\n Episode 中间 (idx=5):")
        print(f" 没有 padding: {sample['observation_is_pad']}")
        print(f" 第0帧图像 ≠ 第1帧图像: {not torch.equal(first, second)}")
        print(" ✓ 正常帧不重复")

    print("\n✅ 数据一致性验证通过")
|
||||
|
||||
|
||||
def test_task_info(dataset):
    """Report task distribution and task propagation through batching."""
    # BUG FIX: DataLoader was only imported under the `__main__` guard;
    # import locally so this function works when the module is imported.
    from torch.utils.data import DataLoader

    print_section("6. 测试任务信息")

    # NOTE(review): relies on `dataset.frames` (the pre-refactor in-memory
    # API); the lazy-loading dataset no longer exposes it — verify callers.
    print("\n📋 统计任务分布:")
    task_count = {}
    for frame in dataset.frames:
        task = frame["task"]
        task_count[task] = task_count.get(task, 0) + 1

    for task, count in task_count.items():
        print(f" {task}: {count} 帧")

    # Task field on a single sample.
    sample = dataset[0]
    print(f"\n样本 task: {sample['task']}")
    print(f" 类型: {type(sample['task'])}")

    # Task field after default collation (becomes a list of strings).
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    batch = next(iter(dataloader))
    print(f"\nBatch task:")
    print(f" 值: {batch['task']}")
    print(f" 类型: {type(batch['task'])}")
    print(f" 长度: {len(batch['task'])}")

    print("\n✅ 任务信息验证通过")
|
||||
|
||||
|
||||
def run_all_tests():
    """Build the demo dataset and run the whole test suite in order."""
    print("\n" + "🚀" * 40)
    print(" SimpleRobotDataset 完整测试套件")
    print("🚀" * 40)

    # Build the demo dataset once; all checks share it.
    print("\n创建测试数据...")
    frames = create_demo_data_with_images()
    dataset = SimpleRobotDataset(
        frames,
        obs_horizon=2,
        pred_horizon=8,
        image_keys=["observation.image_high_resize", "observation.image_left_wrist"],
    )
    print("✓ 数据集创建完成")

    # Run every check in sequence.
    checks = (
        test_dataset_basic_info,
        test_single_sample,
        test_edge_cases,
        test_dataloader,
        test_policy_forward,
        test_data_consistency,
        test_task_info,
    )
    for check in checks:
        check(dataset)

    # Summary banner.
    print_section("✅ 测试总结")
    print("\n所有测试通过!✨")
    print("\n关键验证点:")
    print(" ✓ 图像以字典形式存储")
    print(" ✓ 每个摄像头独立的 key")
    print(" ✓ Policy 在 forward 时 stack 图像")
    print(" ✓ 时间维度正确 (obs_horizon, pred_horizon)")
    print(" ✓ Padding 处理正确")
    print(" ✓ DataLoader 集成正确")
    print(" ✓ Task 信息传递正确")
    print("\n与 LeRobotDataset 设计完全一致!🎉")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # DataLoader is bound at module scope here, so the test functions can
    # reference it when this file is run as a script (they fail with
    # NameError if the module is merely imported — see review notes).
    from torch.utils.data import DataLoader

    run_all_tests()
|
||||
Reference in New Issue
Block a user