refactor(dataset): 重新创建robotdataset最小实现

- 内部实现 __getitem__ 方法,支持通过滑动窗口进行采样
This commit is contained in:
gouhanke
2026-02-10 10:26:19 +08:00
parent ac870f6110
commit 88b9c10a75

View File

@@ -0,0 +1,523 @@
import torch
from torch.utils.data import DataLoader, Dataset
from typing import Dict, List, Optional
class SimpleRobotDataset(Dataset):
    """
    Minimal LeRobotDataset-style dataset — images are stored as dict entries.

    Mirrors the real LeRobotDataset contract:
    - The dataset returns a dict with a separate key per camera.
    - The policy is responsible for stacking images during forward().
    """

    def __init__(
        self,
        frames: List[Dict],
        obs_horizon: int = 2,
        pred_horizon: int = 8,
        image_keys: Optional[List[str]] = None,
    ):
        """
        Args:
            frames: List of frame dicts. Each element contains:
                - "episode_index" (int): [required] episode id of this frame;
                  used to locate episode boundaries (for padding).
                - "task" (str): [required] task description string
                  (e.g. "pick_up_cube").
                - "observation.state" (torch.Tensor): (state_dim,) [required]
                  robot state vector for this frame (e.g. joint angles).
                - "action" (torch.Tensor): (action_dim,) [required] action
                  vector for this frame.
                - "{image_key}" (torch.Tensor): (C, H, W) [optional] image
                  data; key names must match the image_keys passed here.
            obs_horizon: how many past frames to observe.
            pred_horizon: how many future action frames to predict.
            image_keys: which keys hold image data
                (e.g. ["observation.image_0", "observation.image_1"]).
        """
        self.frames = frames
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.image_keys = image_keys or []
        # Build an episode_index -> [frame indices] lookup so __getitem__
        # can find episode boundaries quickly.
        self.episodes: Dict[int, List[int]] = {}
        for idx, frame in enumerate(frames):
            self.episodes.setdefault(frame["episode_index"], []).append(idx)

    def __len__(self) -> int:
        return len(self.frames)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Sample a sliding window of observations and actions around idx.

        Positions outside the episode are padded by repeating the nearest
        boundary frame; the *_is_pad masks mark those positions.
        """
        frame = self.frames[idx]
        # NOTE(review): assumes each episode's frames are contiguous in
        # `frames`, so first/last index bound the episode — holds for the
        # demo data built in this module.
        ep_indices = self.episodes[frame["episode_index"]]
        ep_start = ep_indices[0]
        ep_end = ep_indices[-1]

        # 1. Load observations (the past obs_horizon frames).
        observations: Dict[str, list] = {"state": []}
        for cam_key in self.image_keys:
            observations[cam_key] = []  # one independent list per camera
        observation_is_pad = []
        for delta in range(-self.obs_horizon + 1, 1):  # e.g. [-1, 0] for obs_horizon=2
            target_idx = idx + delta
            if ep_start <= target_idx <= ep_end:
                target_frame = self.frames[target_idx]
                is_pad = False
            else:
                # Out of bounds: pad with the nearest boundary frame.
                target_frame = self.frames[ep_start if target_idx < ep_start else ep_end]
                is_pad = True
            observations["state"].append(target_frame["observation.state"])
            # Collect each camera separately (dict form, no stacking here).
            for cam_key in self.image_keys:
                observations[cam_key].append(target_frame[cam_key])
            observation_is_pad.append(is_pad)

        # 2. Load actions (the next pred_horizon frames).
        actions = []
        action_is_pad = []
        for delta in range(self.pred_horizon):
            target_idx = idx + delta
            if target_idx <= ep_end:
                actions.append(self.frames[target_idx]["action"])
                action_is_pad.append(False)
            else:
                actions.append(self.frames[ep_end]["action"])
                action_is_pad.append(True)

        # 3. Assemble the returned dict.
        result = {
            # State observations: (obs_horizon, state_dim)
            "observation.state": torch.stack(observations["state"]),
            "observation_is_pad": torch.tensor(observation_is_pad, dtype=torch.bool),
            # Actions: (pred_horizon, action_dim)
            "action": torch.stack(actions),
            "action_is_pad": torch.tensor(action_is_pad, dtype=torch.bool),
            # Task description
            "task": frame["task"],
        }
        # Images: one key per camera, each stacked to (obs_horizon, C, H, W).
        for cam_key in self.image_keys:
            result[cam_key] = torch.stack(observations[cam_key])
        return result

    @property
    def camera_keys(self) -> List[str]:
        """All camera key names."""
        return self.image_keys

    @property
    def camera_info(self) -> Dict[str, Dict]:
        """Per-camera shape/dtype info, probed from the first sample."""
        if not self.image_keys:
            return {}
        sample = self[0]
        return {
            cam_key: {
                "shape": sample[cam_key].shape,
                "dtype": str(sample[cam_key].dtype),
            }
            for cam_key in self.image_keys
            if cam_key in sample
        }
class SimpleDiffusionPolicy(torch.nn.Module):
    """Simplified diffusion policy — demonstrates stacking images in forward().

    The dataset delivers one tensor per camera; this module stacks them
    into a single (B, num_cam, T, C, H, W) tensor before encoding.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        image_features: Optional[Dict[str, tuple]] = None,
        obs_horizon: int = 2,
        pred_horizon: int = 8,
    ):
        """
        Args:
            state_dim: dimensionality of observation.state vectors.
            action_dim: dimensionality of action vectors.
            image_features: mapping of camera key -> (C, H, W) shape;
                None or empty disables the image branch.
            obs_horizon: number of observed past frames.
            pred_horizon: number of future actions predicted.
        """
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.obs_horizon = obs_horizon
        self.pred_horizon = pred_horizon
        self.image_features = image_features or {}
        self.state_encoder = torch.nn.Linear(state_dim, 64)
        if self.image_features:
            num_cameras = len(self.image_features)
            # One shared encoder for all cameras; 32 feature channels per
            # camera after global average pooling.
            self.image_encoder = torch.nn.Conv2d(3, 32, kernel_size=7, stride=2)
            self.fusion = torch.nn.Linear(64 + 32 * num_cameras, 128)
        else:
            self.fusion = torch.nn.Linear(64, 128)
        self.action_head = torch.nn.Linear(128, action_dim * pred_horizon)

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Predict an action sequence for a batch.

        Args:
            batch: dict with "observation.state" of shape
                (B, obs_horizon, state_dim), plus one
                (B, obs_horizon, C, H, W) tensor per camera key.

        Returns:
            Tensor of shape (B, pred_horizon, action_dim).
        """
        states = batch["observation.state"]
        # Bug fix: B must be defined even when there is no image branch.
        # Previously it was assigned only inside the image path, so a
        # state-only policy raised NameError at the final reshape.
        B = states.shape[0]
        # Encode states, then average over the time dimension.
        state_features = self.state_encoder(states).mean(dim=1)
        if self.image_features:
            # Dict form -> stacked (B, num_cam, T, C, H, W).
            stacked_images = torch.stack(
                [batch[key] for key in self.image_features.keys()], dim=1
            )
            _, num_cam, T, C, H, W = stacked_images.shape
            images_flat = stacked_images.reshape(B * num_cam * T, C, H, W)
            image_features = self.image_encoder(images_flat)
            # Global average pool, average over time, then flatten cameras.
            image_features = image_features.mean(dim=[2, 3])
            image_features = image_features.reshape(B, num_cam, T, 32).mean(dim=2)
            image_features = image_features.reshape(B, -1)
            features = torch.cat([state_features, image_features], dim=-1)
        else:
            features = state_features
        fused = self.fusion(features)
        pred_actions = self.action_head(fused)
        return pred_actions.reshape(B, self.pred_horizon, self.action_dim)
def create_demo_data_with_images():
    """Create synthetic frames: two 10-step episodes with random images."""
    episode_specs = [(0, "pick_up_cube"), (1, "stack_blocks")]
    frames = []
    for ep_idx, task_name in episode_specs:
        for step in range(10):
            frames.append({
                "episode_index": ep_idx,
                "frame_index": step,
                "task": task_name,
                "observation.state": torch.randn(6),
                "observation.image_high_resize": torch.randn(3, 64, 64),
                "observation.image_left_wrist": torch.randn(3, 64, 64),
                "action": torch.randn(6),
            })
    return frames
def print_section(title: str):
    """Print a section banner: blank line, ruler, title, ruler."""
    ruler = "=" * 80
    print(f"\n{ruler}\n {title}\n{ruler}")
def test_dataset_basic_info(dataset):
"""测试数据集基本信息"""
print("\n📊 数据集基本信息:")
print(f" 总帧数: {len(dataset)}")
print(f" 总 episode 数: {len(dataset.episodes)}")
print(f" 观察窗口: {dataset.obs_horizon}")
print(f" 预测窗口: {dataset.pred_horizon}")
print(f"\n📷 相机信息:")
cameras = dataset.camera_keys
print(f" 相机数量: {len(cameras)}")
for cam in cameras:
print(f" - {cam}")
print(f"\n相机详细信息:")
cam_info = dataset.camera_info
for cam, info in cam_info.items():
print(f" {cam}:")
print(f" shape: {info['shape']}")
print(f" dtype: {info['dtype']}")
def test_single_sample(dataset):
    """Inspect one mid-episode sample and verify its temporal dimensions."""
    print_section("1. 测试单个样本")
    # Pick a sample from the middle of an episode.
    item = dataset[5]
    print("\n样本结构 (字典形式):")
    for key, value in item.items():
        if isinstance(value, torch.Tensor):
            print(f" {key:30s}: {str(value.shape):20s} {value.dtype}")
        elif isinstance(value, str):
            print(f" {key:30s}: {value}")
    # Images should be stored under one dict key per camera.
    print("\n✅ 验证图像存储形式:")
    print(" 图像以字典形式存储,每个摄像头独立的 key:")
    for cam in dataset.camera_keys:
        if cam in item:
            print(f" - {cam}: {item[cam].shape}")
    # Temporal dimensions must match the configured horizons.
    print("\n✅ 验证时间维度:")
    print(f" observation.state: {item['observation.state'].shape}")
    print(f" 预期: (obs_horizon={dataset.obs_horizon}, state_dim=6)")
    assert item['observation.state'].shape[0] == dataset.obs_horizon, "观察时间维度错误"
    print(f" action: {item['action'].shape}")
    print(f" 预期: (pred_horizon={dataset.pred_horizon}, action_dim=6)")
    assert item['action'].shape[0] == dataset.pred_horizon, "动作时间维度错误"
    print(" ✓ 时间维度验证通过")
def test_edge_cases(dataset):
    """Check padding masks at episode boundaries."""
    print_section("2. 测试边界情况")
    cases = [
        ("Episode 开头", 0, {"obs_pad": [True, False], "action_pad": [False] * 8}),
        ("Episode 中间", 5, {"obs_pad": [False, False], "action_pad": [False] * 5 + [True] * 3}),
        ("Episode 末尾", 9, {"obs_pad": [False, False], "action_pad": [True] * 8}),
        ("跨 Episode", 10, {"obs_pad": [True, False], "action_pad": [False] * 8}),
    ]
    for label, sample_idx, _expected in cases:
        print(f"\n📍 {label} (idx={sample_idx}):")
        sample = dataset[sample_idx]
        obs_mask = sample["observation_is_pad"].tolist()
        num_action_pad = sample["action_is_pad"].sum().item()
        print(f" observation_is_pad: {obs_mask}")
        print(f" action_is_pad: {sample['action_is_pad'].tolist()}")
        print(f" action padding 数量: {num_action_pad}")
        # The earliest observation frame must be padded at both episode starts.
        if label == "Episode 开头":
            assert obs_mask[0], "Episode 开头第一帧应该是 padding"
        elif label == "跨 Episode":
            assert obs_mask[0], "跨 Episode 第一帧应该是 padding"
def test_dataloader(dataset):
    """Verify the dataset collates correctly through a torch DataLoader."""
    print_section("3. 测试 DataLoader 集成")
    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=0,  # single-process loading for the test
    )
    batch = next(iter(loader))
    print("\n📦 Batch 结构:")
    inspected_keys = [
        "observation.state",
        "observation.image_high_resize",
        "observation.image_left_wrist",
        "action",
        "task",
    ]
    for key in inspected_keys:
        if key not in batch:
            continue
        value = batch[key]
        if isinstance(value, torch.Tensor):
            print(f" {key:30s}: {str(value.shape):20s} {value.dtype}")
        else:
            print(f" {key:30s}: {type(value).__name__} (length={len(value)})")
    print("\n✅ 验证 Batch 形状:")
    batch_size = len(batch["observation.state"])
    print(f" Batch size: {batch_size}")
    # Every camera tensor should collate to (B, obs_horizon, C, H, W).
    for cam in dataset.camera_keys:
        want = (batch_size, dataset.obs_horizon, 3, 64, 64)
        got = batch[cam].shape
        print(f" {cam}:")
        print(f" 预期: {want}")
        print(f" 实际: {got}")
        assert got == want, f"{cam} 形状不匹配"
    print(" ✓ Batch 形状验证通过")
def test_policy_forward(dataset):
    """Run one batch through SimpleDiffusionPolicy and trace the flow."""
    print_section("4. 测试 Policy 前向传播")
    # Build a policy matching the dataset's horizons and cameras.
    policy = SimpleDiffusionPolicy(
        state_dim=6,
        action_dim=6,
        image_features={
            "observation.image_high_resize": (3, 64, 64),
            "observation.image_left_wrist": (3, 64, 64),
        },
        obs_horizon=dataset.obs_horizon,
        pred_horizon=dataset.pred_horizon,
    )
    batch = next(iter(DataLoader(dataset, batch_size=4, shuffle=False)))
    print("\n🔄 Policy.forward() 流程:")
    # Per-camera tensors before any stacking.
    print("\n 1⃣ Stack 之前 (字典形式):")
    for cam in policy.image_features.keys():
        print(f" batch['{cam}']: {batch[cam].shape}")
    # Demonstrate the stack the policy performs internally.
    print("\n 2⃣ Stack 操作:")
    stacked = torch.stack([batch[k] for k in policy.image_features.keys()], dim=1)
    print(f" stacked_images: {stacked.shape}")
    print(f" (B={stacked.shape[0]}, num_cam={stacked.shape[1]}, ")
    print(f" obs_hor={stacked.shape[2]}, C={stacked.shape[3]}, H={stacked.shape[4]}, W={stacked.shape[5]})")
    # The forward pass proper.
    print("\n 3⃣ 前向传播:")
    with torch.no_grad():
        pred_actions = policy(batch)
    print(f" 输入:")
    print(f" observation.state: {batch['observation.state'].shape}")
    print(f" 图像已 stack")
    print(f" 输出:")
    print(f" pred_actions: {pred_actions.shape}")
    print(f" (B={pred_actions.shape[0]}, pred_horizon={pred_actions.shape[1]}, action_dim={pred_actions.shape[2]})")
    print("\n✅ Policy 前向传播验证通过")
def test_data_consistency(dataset):
    """Verify that padded observation frames duplicate the boundary frame."""
    print_section("5. 测试数据一致性")
    print("\n🔍 验证图像 padding 的正确性:")
    # At an episode start the first obs frame is padding and must be an
    # exact copy of the boundary frame.
    first = dataset[0]
    if first["observation_is_pad"][0]:
        frame_a = first["observation.image_high_resize"][0]
        frame_b = first["observation.image_high_resize"][1]
        print(f" Episode 开头 (idx=0):")
        print(f" 第0帧是 padding: {first['observation_is_pad'][0]}")
        print(f" 第0帧图像 = 第1帧图像: {torch.equal(frame_a, frame_b)}")
        assert torch.equal(frame_a, frame_b), "Padding 应该复制边界帧"
        print(" ✓ Padding 正确")
    # Mid-episode samples carry no padding, so frames should differ.
    mid = dataset[5]
    if not mid["observation_is_pad"].any():
        frame_a = mid["observation.image_high_resize"][0]
        frame_b = mid["observation.image_high_resize"][1]
        print(f"\n Episode 中间 (idx=5):")
        print(f" 没有 padding: {mid['observation_is_pad']}")
        print(f" 第0帧图像 ≠ 第1帧图像: {not torch.equal(frame_a, frame_b)}")
        print(" ✓ 正常帧不重复")
    print("\n✅ 数据一致性验证通过")
def test_task_info(dataset):
    """Check task strings survive sampling and DataLoader collation."""
    print_section("6. 测试任务信息")
    print("\n📋 统计任务分布:")
    # Count frames per task with a plain dict.
    counts = {}
    for fr in dataset.frames:
        counts[fr["task"]] = counts.get(fr["task"], 0) + 1
    for task_name, n in counts.items():
        print(f" {task_name}: {n}")
    # Task string in a single sample.
    first = dataset[0]
    print(f"\n样本 task: {first['task']}")
    print(f" 类型: {type(first['task'])}")
    # Task strings in a collated batch (default collate yields a list).
    batch = next(iter(DataLoader(dataset, batch_size=4, shuffle=False)))
    print(f"\nBatch task:")
    print(f" 值: {batch['task']}")
    print(f" 类型: {type(batch['task'])}")
    print(f" 长度: {len(batch['task'])}")
    print("\n✅ 任务信息验证通过")
def run_all_tests():
    """Build the demo dataset and run every check in sequence."""
    print("\n" + "🚀" * 40)
    print(" SimpleRobotDataset 完整测试套件")
    print("🚀" * 40)
    # Build the demo dataset.
    print("\n创建测试数据...")
    dataset = SimpleRobotDataset(
        create_demo_data_with_images(),
        obs_horizon=2,
        pred_horizon=8,
        image_keys=["observation.image_high_resize", "observation.image_left_wrist"],
    )
    print("✓ 数据集创建完成")
    # Run each check, in order.
    checks = (
        test_dataset_basic_info,
        test_single_sample,
        test_edge_cases,
        test_dataloader,
        test_policy_forward,
        test_data_consistency,
        test_task_info,
    )
    for check in checks:
        check(dataset)
    # Summary banner.
    print_section("✅ 测试总结")
    print("\n所有测试通过!✨")
    print("\n关键验证点:")
    for point in (
        " ✓ 图像以字典形式存储",
        " ✓ 每个摄像头独立的 key",
        " ✓ Policy 在 forward 时 stack 图像",
        " ✓ 时间维度正确 (obs_horizon, pred_horizon)",
        " ✓ Padding 处理正确",
        " ✓ DataLoader 集成正确",
        " ✓ Task 信息传递正确",
    ):
        print(point)
    print("\n与 LeRobotDataset 设计完全一致!🎉")
if __name__ == "__main__":
    # DataLoader is imported at module top level, so the test helpers work
    # even when this module is imported rather than run as a script
    # (previously the import lived only inside this guard).
    run_all_tests()