refactor:大重构

This commit is contained in:
gouhanke
2026-02-11 15:53:55 +08:00
parent 1e95d40bf9
commit 130d4bb3c5
19 changed files with 1411 additions and 1223 deletions

View File

@@ -1,152 +0,0 @@
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import h5py
import numpy as np
import os
import glob
import pickle
class RobotDiffusionDataset(Dataset):
    """HDF5-backed dataset for training a robot diffusion policy.

    Scans ``dataset_dir`` for ``episode_*.hdf5`` files, builds a flat
    per-step index over all episodes, and serves samples containing a
    short observation history (qpos + per-camera images) plus a padded
    chunk of future actions, normalized with the statistics loaded from
    ``data_stats.pkl``.
    """

    def __init__(self,
                 dataset_dir,
                 pred_horizon=16,
                 obs_horizon=2,
                 action_horizon=8,
                 camera_names=('r_vis', 'top', 'front'),
                 normalization_type='gaussian'):
        """
        Args:
            dataset_dir: directory containing the episode_*.hdf5 files
            pred_horizon: length of the predicted future action chunk (Tp)
            obs_horizon: length of the observation history (To)
            action_horizon: number of actions executed per inference (Ta);
                kept as a parameter here, mainly relevant at evaluation time
            camera_names: camera keys under ``observations/images/`` in each
                file. NOTE: default changed from a mutable list to a tuple
                to avoid the shared-mutable-default pitfall; values unchanged.
            normalization_type: 'gaussian' (mean/std) or 'min_max' ([-1, 1])
        """
        self.dataset_dir = dataset_dir
        self.pred_horizon = pred_horizon
        self.obs_horizon = obs_horizon
        self.action_horizon = action_horizon
        self.camera_names = camera_names
        self.normalization_type = normalization_type
        # 1. Scan all HDF5 files and build a flat step index.
        #    Each entry is (file_path, episode_length, current_step_index).
        self.episode_files = sorted(glob.glob(os.path.join(dataset_dir, 'episode_*.hdf5')))
        self.indices = []
        print(f"Found {len(self.episode_files)} episodes. Building index...")
        for file_path in self.episode_files:
            with h5py.File(file_path, 'r') as f:
                # Episode length (e.g. 700 steps).
                l = f['action'].shape[0]
                for i in range(l):
                    self.indices.append((file_path, l, i))
        # 2. Normalization statistics precomputed offline.
        with open(os.path.join(dataset_dir, 'data_stats.pkl'), 'rb') as f:
            self.stats = pickle.load(f)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        file_path, episode_len, start_ts = self.indices[idx]
        # Opening the file inside __getitem__ is friendlier to multi-process
        # DataLoader workers; for maximum IO throughput consider h5py SWMR
        # mode or an in-memory cache.
        with h5py.File(file_path, 'r') as root:
            # ---- 1. Action chunk (prediction target): [t, t + pred_horizon) ----
            action_start = start_ts
            action_end = min(start_ts + self.pred_horizon, episode_len)
            actions = root['action'][action_start:action_end]  # (T_subset, action_dim)
            # Pad by repeating the last action when the episode tail is short.
            if len(actions) < self.pred_horizon:
                pad_len = self.pred_horizon - len(actions)
                last_action = actions[-1]
                pad_content = np.repeat(last_action[np.newaxis, :], pad_len, axis=0)
                actions = np.concatenate([actions, pad_content], axis=0)
            if self.stats:
                actions = self._normalize_data(actions, self.stats['action'])
            # ---- 2. Observation history: [t - obs_horizon + 1, t] ----
            # Index logic:
            #   obs_horizon=2, t=0 -> frames [0, 0] (front padding)
            #   obs_horizon=2, t=5 -> frames [4, 5]
            start_idx_raw = start_ts - (self.obs_horizon - 1)
            start_idx = max(start_idx_raw, 0)
            end_idx = start_ts + 1
            pad_len = max(0, -start_idx_raw)
            # Qpos history, front-padded with the first available frame.
            qpos_data = root['observations/qpos']
            qpos_val = qpos_data[start_idx:end_idx]
            if pad_len > 0:
                first_frame = qpos_val[0]
                padding = np.repeat(first_frame[np.newaxis, :], pad_len, axis=0)
                qpos_val = np.concatenate([padding, qpos_val], axis=0)
            if self.stats:
                qpos_val = self._normalize_data(qpos_val, self.stats['qpos'])
            # Image history per camera, same front-padding rule.
            image_dict = {}
            for cam_name in self.camera_names:
                img_dset = root['observations']['images'][cam_name]
                imgs_np = img_dset[start_idx:end_idx]  # (T, H, W, C)
                if pad_len > 0:
                    first_frame = imgs_np[0]
                    padding = np.repeat(first_frame[np.newaxis, ...], pad_len, axis=0)
                    imgs_np = np.concatenate([padding, imgs_np], axis=0)
                # uint8 (T, H, W, C) -> float (T, C, H, W) scaled to [0, 1].
                imgs_tensor = torch.from_numpy(imgs_np).float() / 255.0
                imgs_tensor = torch.einsum('thwc->tchw', imgs_tensor)
                image_dict[cam_name] = imgs_tensor
        # ---- 3. Assemble the sample dict ----
        data_batch = {
            'action': torch.from_numpy(actions).float(),
            'qpos': torch.from_numpy(qpos_val).float(),
        }
        for cam_name, img_tensor in image_dict.items():
            data_batch[f'image_{cam_name}'] = img_tensor
        return data_batch

    def _normalize_data(self, data, stats):
        """Normalize a numpy array using the configured scheme.

        Args:
            data: numpy array to normalize
            stats: dict with 'min'/'max' or 'mean'/'std' keys

        Raises:
            ValueError: on an unknown ``normalization_type`` (the original
                silently fell off the end and returned None here).
        """
        if self.normalization_type == 'min_max':
            # Scale to [-1, 1].
            min_val = stats['min']
            max_val = stats['max']
            data = (data - min_val) / (max_val - min_val + 1e-8)
            return data * 2 - 1
        elif self.normalization_type == 'gaussian':
            # Standardize with mean/std.
            mean = stats['mean']
            std = stats['std']
            return (data - mean) / (std + 1e-8)
        raise ValueError(f"Unknown normalization_type: {self.normalization_type!r}")

View File

@@ -1,523 +1,199 @@
import torch
import h5py
from torch.utils.data import Dataset
from typing import List, Dict, Optional
from typing import List, Dict, Union
from pathlib import Path
class SimpleRobotDataset(Dataset):
"""
LeRobotDataset 简化版 - 图像以字典形式存储
与真实 LeRobotDataset 保持一致:
- Dataset 返回字典,每个摄像头单独的 key
- Policy 负责在 forward 时 stack 图像
HDF5 懒加载数据集 - LeRobotDataset 格式
返回格式:
- observation.state: (obs_horizon, state_dim)
- observation.{cam_name}: (obs_horizon, C, H, W)
- action: (pred_horizon, action_dim)
"""
def __init__(
self,
frames: List[Dict],
dataset_dir: Union[str, Path],
obs_horizon: int = 2,
pred_horizon: int = 8,
image_keys: List[str] = None,
camera_names: List[str] = None,
):
"""
Args:
frames: 帧数据列表。每个元素是一个字典,包含:
- "episode_index" (int): [必须] 该帧所属的 Episode ID。Dataset 使用它来确定 Episode 的边界(用于 Padding
- "task" (str): [必须] 任务描述字符串(例如 "pick_up_cube")。
- "observation.state" (torch.Tensor): (state_dim,) [必须] 当前帧的机器人状态向量(例如关节角度)。
- "action" (torch.Tensor): (action_dim,) [必须] 当前帧对应的动作向量。
- "{image_key}" (torch.Tensor): (C, H, W) [可选] 当前帧的图像数据。键名必须与初始化 Dataset 时传入的 image_keys 列表一致。
dataset_dir: HDF5 文件目录路径
obs_horizon: 观察过去多少帧
pred_horizon: 预测未来多少帧动作
image_keys: 哪些 key 是图像数据(例如 ["observation.image_0", "observation.image_1"]
camera_names: 相机名称列表,如 ["r_vis", "top", "front"]
HDF5 文件格式:
- action: [T, action_dim]
- observations/qpos: [T, obs_dim]
- observations/images/{cam_name}: [T, H, W, C]
"""
self.frames = frames
self.obs_horizon = obs_horizon
self.pred_horizon = pred_horizon
self.image_keys = image_keys or []
# 构建 episode 索引
self.camera_names = camera_names or []
self.dataset_dir = Path(dataset_dir)
if not self.dataset_dir.exists():
raise FileNotFoundError(f"数据集目录不存在: {dataset_dir}")
# 查找 HDF5 文件
self.hdf5_files = sorted(self.dataset_dir.glob("*.hdf5"))
if not self.hdf5_files:
self.hdf5_files = sorted(self.dataset_dir.glob("episode_*.hdf5"))
if not self.hdf5_files:
raise FileNotFoundError(f"{dataset_dir} 中未找到 HDF5 文件")
# 构建 episode 索引(只存储元数据,不加载数据)
self.episodes = {}
for idx, frame in enumerate(frames):
ep_idx = frame["episode_index"]
if ep_idx not in self.episodes:
self.episodes[ep_idx] = []
self.episodes[ep_idx].append(idx)
self.frame_meta = [] # 存储 (ep_idx, frame_idx, hdf5_path)
for ep_idx, hdf5_path in enumerate(self.hdf5_files):
with h5py.File(hdf5_path, 'r') as f:
T = f['action'].shape[0]
start_idx = len(self.frame_meta)
for t in range(T):
self.frame_meta.append({
"ep_idx": ep_idx,
"frame_idx": t,
"hdf5_path": hdf5_path,
})
self.episodes[ep_idx] = list(range(start_idx, len(self.frame_meta)))
print(f"懒加载模式: {len(self.hdf5_files)} 个 episodes, 共 {len(self.frame_meta)}")
def __len__(self):
return len(self.frames)
return len(self.frame_meta)
def _load_frame(self, idx: int) -> Dict:
"""从 HDF5 文件懒加载单帧数据"""
meta = self.frame_meta[idx]
with h5py.File(meta["hdf5_path"], 'r') as f:
frame = {
"episode_index": meta["ep_idx"],
"frame_index": meta["frame_idx"],
"task": f.get('task', [b"unknown"])[0].decode() if 'task' in f else "unknown",
"observation.state": torch.from_numpy(f['observations/qpos'][meta["frame_idx"]]).float(),
"action": torch.from_numpy(f['action'][meta["frame_idx"]]).float(),
}
# 加载图像数据: observations/images/{cam_name} -> observation.{cam_name}
for cam_name in self.camera_names:
h5_path = f'observations/images/{cam_name}'
if h5_path in f:
img = f[h5_path][meta["frame_idx"]]
img = torch.from_numpy(img).float()
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
return frame
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
frame = self.frames[idx]
frame = self._load_frame(idx)
ep_idx = frame["episode_index"]
# 获取当前 episode 的帧索引范围
ep_indices = self.episodes[ep_idx]
ep_start = ep_indices[0]
ep_end = ep_indices[-1]
# ============================================
# 1. 加载观察(过去 obs_horizon 帧)
# ============================================
observations = {
"state": [], # 状态数据
}
# 为每个摄像头初始化独立列表(字典形式)
for cam_key in self.image_keys:
observations[cam_key] = []
# 为每个摄像头初始化独立列表
for cam_name in self.camera_names:
observations[f"observation.{cam_name}"] = []
observation_is_pad = []
for delta in range(-self.obs_horizon + 1, 1): # [-1, 0] for obs_horizon=2
target_idx = idx + delta
# 边界检查
if ep_start <= target_idx <= ep_end:
target_frame = self.frames[target_idx]
target_frame = self._load_frame(target_idx)
is_pad = False
else:
# 超出边界,用边界帧填充
if target_idx < ep_start:
target_frame = self.frames[ep_start]
target_frame = self._load_frame(ep_start)
else:
target_frame = self.frames[ep_end]
target_frame = self._load_frame(ep_end)
is_pad = True
# 收集状态
observations["state"].append(target_frame["observation.state"])
# 收集每个摄像头的图像(字典形式,不 stack
for cam_key in self.image_keys:
observations[cam_key].append(target_frame[cam_key])
# 收集每个摄像头的图像
for cam_name in self.camera_names:
observations[f"observation.{cam_name}"].append(target_frame[f"observation.{cam_name}"])
observation_is_pad.append(is_pad)
# ============================================
# 2. 加载动作(未来 pred_horizon 帧)
# ============================================
actions = []
action_is_pad = []
for delta in range(self.pred_horizon):
target_idx = idx + delta
if target_idx <= ep_end:
actions.append(self.frames[target_idx]["action"])
actions.append(self._load_frame(target_idx)["action"])
action_is_pad.append(False)
else:
actions.append(self.frames[ep_end]["action"])
actions.append(self._load_frame(ep_end)["action"])
action_is_pad.append(True)
# ============================================
# 3. 组装返回数据(字典形式)
# 3. 组装返回数据(LeRobotDataset 格式)
# ============================================
result = {
# 状态观察: (obs_horizon, state_dim)
"observation.state": torch.stack(observations["state"]),
"observation_is_pad": torch.tensor(observation_is_pad, dtype=torch.bool),
# 动作: (pred_horizon, action_dim)
"action": torch.stack(actions),
"action_is_pad": torch.tensor(action_is_pad, dtype=torch.bool),
# 任务
"task": frame["task"],
}
# 图像:每个摄像头独立的 key(字典形式)
# 图像:每个摄像头独立的 key
# 形状: (obs_horizon, C, H, W)
for cam_key in self.image_keys:
result[cam_key] = torch.stack(observations[cam_key])
for cam_name in self.camera_names:
result[f"observation.{cam_name}"] = torch.stack(observations[f"observation.{cam_name}"])
return result
@property
def camera_keys(self) -> list[str]:
"""获取所有相机键名"""
return self.image_keys
"""获取所有相机键名 (LeRobotDataset 格式)"""
return [f"observation.{cam_name}" for cam_name in self.camera_names]
@property
def camera_info(self) -> dict:
"""获取相机信息"""
if not self.image_keys:
if not self.camera_names:
return {}
# 从第一个样本获取形状
sample = self[0]
info = {}
for cam_key in self.image_keys:
if cam_key in sample:
info[cam_key] = {
"shape": sample[cam_key].shape,
"dtype": str(sample[cam_key].dtype),
for cam_name in self.camera_names:
key = f"observation.{cam_name}"
if key in sample:
info[key] = {
"shape": sample[key].shape,
"dtype": str(sample[key].dtype),
}
return info
class SimpleDiffusionPolicy(torch.nn.Module):
    """Minimal diffusion-policy-style model showing how per-camera image
    dict entries are stacked inside ``forward``.

    Bug fixed: the batch size ``B`` used for the final reshape was only
    bound on the image branch, so calling ``forward`` without cameras
    raised ``NameError``; ``B`` is now derived from the state tensor.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        image_features: Dict[str, tuple] = None,
        obs_horizon: int = 2,
        pred_horizon: int = 8,
    ):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.image_features = image_features or {}
        self.state_encoder = torch.nn.Linear(state_dim, 64)
        if image_features:
            num_cameras = len(image_features)
            # All cameras share one tiny conv encoder.
            self.image_encoder = torch.nn.Conv2d(3, 32, kernel_size=7, stride=2)
            self.fusion = torch.nn.Linear(64 + 32 * num_cameras, 128)
        else:
            self.fusion = torch.nn.Linear(64, 128)
        self.action_head = torch.nn.Linear(128, action_dim * pred_horizon)

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Predict actions of shape (B, pred_horizon, action_dim).

        Args:
            batch: dict with "observation.state" of shape (B, T, state_dim)
                and, when configured, one (B, T, C, H, W) tensor per camera
                key in ``image_features``.
        """
        # State: (B, T, state_dim) -> encode -> mean over time -> (B, 64).
        state_features = self.state_encoder(batch["observation.state"])
        state_features = state_features.mean(dim=1)
        # Fix: B defined on both branches, not only the image path.
        B = state_features.shape[0]
        if self.image_features:
            # Stack per-camera tensors: (B, num_cam, T, C, H, W).
            image_tensors = [batch[key] for key in self.image_features.keys()]
            stacked_images = torch.stack(image_tensors, dim=1)
            _, num_cam, T, C, H, W = stacked_images.shape
            images_flat = stacked_images.reshape(B * num_cam * T, C, H, W)
            image_features = self.image_encoder(images_flat)
            # Global average pool over spatial dims.
            image_features = image_features.mean(dim=[2, 3])
            # Average over time, then concat cameras: (B, 32 * num_cam).
            image_features = image_features.reshape(B, num_cam, T, 32).mean(dim=2)
            image_features = image_features.reshape(B, -1)
            features = torch.cat([state_features, image_features], dim=-1)
        else:
            features = state_features
        fused = self.fusion(features)
        pred_actions = self.action_head(fused)
        pred_actions = pred_actions.reshape(B, self.pred_horizon, self.action_dim)
        return pred_actions
def create_demo_data_with_images():
    """Create mock frames (with images) for two demo episodes.

    Returns:
        list[dict]: 20 frames — episode 0 ("pick_up_cube") and episode 1
        ("stack_blocks"), 10 frames each, with random state/action vectors
        and two random 3x64x64 camera images per frame.

    Refactor: the two episode loops were byte-identical except for the
    episode index and task name; deduplicated into a helper.
    """
    def _episode(ep_idx: int, task: str, length: int = 10):
        # One synthetic episode of `length` random frames.
        return [
            {
                "episode_index": ep_idx,
                "frame_index": t,
                "task": task,
                "observation.state": torch.randn(6),
                "observation.image_high_resize": torch.randn(3, 64, 64),
                "observation.image_left_wrist": torch.randn(3, 64, 64),
                "action": torch.randn(6),
            }
            for t in range(length)
        ]

    # Episode 0: pick_up_cube task; Episode 1: stack_blocks task.
    return _episode(0, "pick_up_cube") + _episode(1, "stack_blocks")
def print_section(title: str):
    """Print a section heading framed by 80-char '=' rules."""
    rule = "=" * 80
    print(f"\n{rule}\n {title}\n{rule}")
def test_dataset_basic_info(dataset):
    """Print the dataset's basic info: sizes, horizons, camera details.

    Args:
        dataset: a SimpleRobotDataset-like object exposing __len__,
            .episodes, .obs_horizon, .pred_horizon, .camera_keys and
            .camera_info — TODO confirm both dataset variants expose all.
    """
    print("\n📊 数据集基本信息:")
    print(f" 总帧数: {len(dataset)}")
    print(f" 总 episode 数: {len(dataset.episodes)}")
    print(f" 观察窗口: {dataset.obs_horizon}")
    print(f" 预测窗口: {dataset.pred_horizon}")
    print(f"\n📷 相机信息:")
    cameras = dataset.camera_keys
    print(f" 相机数量: {len(cameras)}")
    for cam in cameras:
        print(f" - {cam}")
    print(f"\n相机详细信息:")
    # NOTE(review): camera_info samples dataset[0] internally, so this line
    # triggers a real sample load.
    cam_info = dataset.camera_info
    for cam, info in cam_info.items():
        print(f" {cam}:")
        print(f" shape: {info['shape']}")
        print(f" dtype: {info['dtype']}")
def test_single_sample(dataset):
    """Inspect one mid-episode sample: dump keys/shapes and assert that the
    time dimensions match the configured horizons.

    Args:
        dataset: indexable dataset returning dict samples; the asserts
            assume state_dim=6 / action_dim=6 demo data — TODO confirm.
    """
    print_section("1. 测试单个样本")
    # A sample from the middle of an episode (no padding expected).
    sample = dataset[5]
    print("\n样本结构 (字典形式):")
    for key, value in sample.items():
        if isinstance(value, torch.Tensor):
            print(f" {key:30s}: {str(value.shape):20s} {value.dtype}")
        elif isinstance(value, str):
            print(f" {key:30s}: {value}")
    # Images should appear as one dict entry per camera, not stacked.
    print("\n✅ 验证图像存储形式:")
    print(" 图像以字典形式存储,每个摄像头独立的 key:")
    for cam_key in dataset.camera_keys:
        if cam_key in sample:
            print(f" - {cam_key}: {sample[cam_key].shape}")
    # Leading dims must equal obs_horizon / pred_horizon respectively.
    print("\n✅ 验证时间维度:")
    print(f" observation.state: {sample['observation.state'].shape}")
    print(f" 预期: (obs_horizon={dataset.obs_horizon}, state_dim=6)")
    assert sample['observation.state'].shape[0] == dataset.obs_horizon, "观察时间维度错误"
    print(f" action: {sample['action'].shape}")
    print(f" 预期: (pred_horizon={dataset.pred_horizon}, action_dim=6)")
    assert sample['action'].shape[0] == dataset.pred_horizon, "动作时间维度错误"
    print(" ✓ 时间维度验证通过")
def test_edge_cases(dataset):
    """Probe padding flags at episode start/middle/end and across episodes.

    Assumes the demo layout: obs_horizon=2, pred_horizon=8, two 10-frame
    episodes (so idx=10 is the first frame of episode 1) — TODO confirm.
    """
    print_section("2. 测试边界情况")
    # NOTE(review): the `expected` dicts are built but never asserted
    # against, and some entries look inconsistent with the padding logic
    # (e.g. "Episode 末尾" lists action_pad as all True although the
    # action at the last frame itself exists) — verify before enforcing.
    test_cases = [
        ("Episode 开头", 0, {"obs_pad": [True, False], "action_pad": [False] * 8}),
        ("Episode 中间", 5, {"obs_pad": [False, False], "action_pad": [False] * 5 + [True] * 3}),
        ("Episode 末尾", 9, {"obs_pad": [False, False], "action_pad": [True] * 8}),
        ("跨 Episode", 10, {"obs_pad": [True, False], "action_pad": [False] * 8}),
    ]
    for name, idx, expected in test_cases:
        print(f"\n📍 {name} (idx={idx}):")
        sample = dataset[idx]
        obs_pad = sample["observation_is_pad"].tolist()
        action_pad_count = sample["action_is_pad"].sum().item()
        print(f" observation_is_pad: {obs_pad}")
        print(f" action_is_pad: {sample['action_is_pad'].tolist()}")
        print(f" action padding 数量: {action_pad_count}")
        # At an episode boundary only the oldest observation slot is padding.
        if name == "Episode 开头":
            assert obs_pad[0] == True, "Episode 开头第一帧应该是 padding"
        elif name == "跨 Episode":
            assert obs_pad[0] == True, "跨 Episode 第一帧应该是 padding"
def test_dataloader(dataset):
    """Check DataLoader batching: key shapes/dtypes and per-camera shapes.

    Args:
        dataset: a SimpleRobotDataset-like object; the shape assert assumes
            3x64x64 demo images — TODO confirm for other data.

    Fix: ``DataLoader`` was previously only imported at module scope under
    ``__main__`` (bottom of file), so calling this function from an
    importing module raised NameError; it is now imported locally.
    """
    from torch.utils.data import DataLoader
    print_section("3. 测试 DataLoader 集成")
    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=0,  # keep 0 in tests to avoid worker processes
    )
    batch = next(iter(dataloader))
    print("\n📦 Batch 结构:")
    for key in ["observation.state", "observation.image_high_resize",
                "observation.image_left_wrist", "action", "task"]:
        if key in batch:
            value = batch[key]
            if isinstance(value, torch.Tensor):
                print(f" {key:30s}: {str(value.shape):20s} {value.dtype}")
            else:
                # task collates to a list of strings, not a tensor.
                print(f" {key:30s}: {type(value).__name__} (length={len(value)})")
    print("\n✅ 验证 Batch 形状:")
    B = len(batch["observation.state"])
    print(f" Batch size: {B}")
    # Every camera batch should be (B, obs_horizon, 3, 64, 64) for demo data.
    for cam_key in dataset.camera_keys:
        expected_shape = (B, dataset.obs_horizon, 3, 64, 64)
        actual_shape = batch[cam_key].shape
        print(f" {cam_key}:")
        print(f" 预期: {expected_shape}")
        print(f" 实际: {actual_shape}")
        assert actual_shape == expected_shape, f"{cam_key} 形状不匹配"
    print(" ✓ Batch 形状验证通过")
def test_policy_forward(dataset):
    """Run SimpleDiffusionPolicy.forward on one batch and report shapes.

    Args:
        dataset: dataset yielding the demo schema (state_dim=6, two 64x64
            cameras) — TODO confirm for other data.

    Fix: ``DataLoader`` was only bound at module scope under ``__main__``;
    import it locally so this also works when the module is imported.
    """
    from torch.utils.data import DataLoader
    print_section("4. 测试 Policy 前向传播")
    # Build a policy matching the demo data dimensions.
    policy = SimpleDiffusionPolicy(
        state_dim=6,
        action_dim=6,
        image_features={
            "observation.image_high_resize": (3, 64, 64),
            "observation.image_left_wrist": (3, 64, 64),
        },
        obs_horizon=dataset.obs_horizon,
        pred_horizon=dataset.pred_horizon,
    )
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    batch = next(iter(dataloader))
    print("\n🔄 Policy.forward() 流程:")
    # 1. Per-camera dict entries before stacking.
    print("\n 1⃣ Stack 之前 (字典形式):")
    for cam_key in policy.image_features.keys():
        print(f" batch['{cam_key}']: {batch[cam_key].shape}")
    # 2. Illustrate the stack the policy performs internally:
    #    (B, num_cam, T, C, H, W).
    print("\n 2⃣ Stack 操作:")
    image_tensors = [batch[key] for key in policy.image_features.keys()]
    stacked = torch.stack(image_tensors, dim=1)
    print(f" stacked_images: {stacked.shape}")
    print(f" (B={stacked.shape[0]}, num_cam={stacked.shape[1]}, ")
    print(f" obs_hor={stacked.shape[2]}, C={stacked.shape[3]}, H={stacked.shape[4]}, W={stacked.shape[5]})")
    # 3. Forward pass; no grad needed for a smoke test.
    print("\n 3⃣ 前向传播:")
    with torch.no_grad():
        pred_actions = policy(batch)
    print(f" 输入:")
    print(f" observation.state: {batch['observation.state'].shape}")
    print(f" 图像已 stack")
    print(f" 输出:")
    print(f" pred_actions: {pred_actions.shape}")
    print(f" (B={pred_actions.shape[0]}, pred_horizon={pred_actions.shape[1]}, action_dim={pred_actions.shape[2]})")
    print("\n✅ Policy 前向传播验证通过")
def test_data_consistency(dataset):
    """Verify that observation padding duplicates the boundary frame.

    Args:
        dataset: indexable dataset returning dict samples with
            "observation_is_pad" and per-camera image tensors.
    """
    print_section("5. 测试数据一致性")
    print("\n🔍 验证图像 padding 的正确性:")
    # Episode start: when slot 0 is padding it must be a copy of slot 1.
    sample = dataset[0]
    if sample["observation_is_pad"][0]:
        img_0 = sample["observation.image_high_resize"][0]
        img_1 = sample["observation.image_high_resize"][1]
        print(f" Episode 开头 (idx=0):")
        print(f" 第0帧是 padding: {sample['observation_is_pad'][0]}")
        print(f" 第0帧图像 = 第1帧图像: {torch.equal(img_0, img_1)}")
        assert torch.equal(img_0, img_1), "Padding 应该复制边界帧"
        print(" ✓ Padding 正确")
    # Mid-episode: consecutive frames should differ (demo data is random;
    # printed only, not asserted).
    sample = dataset[5]
    if not sample["observation_is_pad"].any():
        img_0 = sample["observation.image_high_resize"][0]
        img_1 = sample["observation.image_high_resize"][1]
        print(f"\n Episode 中间 (idx=5):")
        print(f" 没有 padding: {sample['observation_is_pad']}")
        print(f" 第0帧图像 ≠ 第1帧图像: {not torch.equal(img_0, img_1)}")
        print(" ✓ 正常帧不重复")
    print("\n✅ 数据一致性验证通过")
def test_task_info(dataset):
    """Check task strings survive sampling and DataLoader collation.

    Fix: ``DataLoader`` was only bound at module scope under ``__main__``;
    import it locally so this also works when the module is imported.

    NOTE(review): iterates ``dataset.frames``, which only exists on the
    in-memory dataset variant (the lazy HDF5 variant exposes frame_meta)
    — confirm which API this file targets.
    """
    from torch.utils.data import DataLoader
    print_section("6. 测试任务信息")
    print("\n📋 统计任务分布:")
    task_count = {}
    for frame in dataset.frames:
        task = frame["task"]
        task_count[task] = task_count.get(task, 0) + 1
    for task, count in task_count.items():
        print(f" {task}: {count}")
    # Task string on a single sample.
    sample = dataset[0]
    print(f"\n样本 task: {sample['task']}")
    print(f" 类型: {type(sample['task'])}")
    # Default collate turns per-sample strings into a list per batch.
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    batch = next(iter(dataloader))
    print(f"\nBatch task:")
    print(f" 值: {batch['task']}")
    print(f" 类型: {type(batch['task'])}")
    print(f" 长度: {len(batch['task'])}")
    print("\n✅ 任务信息验证通过")
def run_all_tests():
    """Build the demo dataset and run the full test suite end-to-end.

    NOTE(review): constructs SimpleRobotDataset with the
    frames/image_keys signature; the lazy HDF5 variant takes
    dataset_dir/camera_names instead — confirm which API is current.
    """
    print("\n" + "🚀" * 40)
    print(" SimpleRobotDataset 完整测试套件")
    print("🚀" * 40)
    # Build the in-memory demo dataset (two episodes, two cameras).
    print("\n创建测试数据...")
    frames = create_demo_data_with_images()
    dataset = SimpleRobotDataset(
        frames,
        obs_horizon=2,
        pred_horizon=8,
        image_keys=["observation.image_high_resize", "observation.image_left_wrist"],
    )
    print("✓ 数据集创建完成")
    # Run each scenario in order; any failed assert aborts the suite.
    test_dataset_basic_info(dataset)
    test_single_sample(dataset)
    test_edge_cases(dataset)
    test_dataloader(dataset)
    test_policy_forward(dataset)
    test_data_consistency(dataset)
    test_task_info(dataset)
    # Summary banner.
    print_section("✅ 测试总结")
    print("\n所有测试通过!✨")
    print("\n关键验证点:")
    print(" ✓ 图像以字典形式存储")
    print(" ✓ 每个摄像头独立的 key")
    print(" ✓ Policy 在 forward 时 stack 图像")
    print(" ✓ 时间维度正确 (obs_horizon, pred_horizon)")
    print(" ✓ Padding 处理正确")
    print(" ✓ DataLoader 集成正确")
    print(" ✓ Task 信息传递正确")
    print("\n与 LeRobotDataset 设计完全一致!🎉")
if __name__ == "__main__":
    # DataLoader is bound at module scope only when run as a script; the
    # test_* helpers above reference it as a global, so they only work
    # via this entry point — NOTE(review): importing this module and
    # calling them directly would raise NameError.
    from torch.utils.data import DataLoader
    run_all_tests()