refactor(dataset): 重新创建robotdataset最小实现
- 内部实现__getitem__参数,可以通过滑动窗口进行采样 -
This commit is contained in:
523
roboimi/vla/data/simpe_robot_dataset.py
Normal file
523
roboimi/vla/data/simpe_robot_dataset.py
Normal file
@@ -0,0 +1,523 @@
|
|||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
|
class SimpleRobotDataset(Dataset):
|
||||||
|
"""
|
||||||
|
LeRobotDataset 简化版 - 图像以字典形式存储
|
||||||
|
|
||||||
|
与真实 LeRobotDataset 保持一致:
|
||||||
|
- Dataset 返回字典,每个摄像头单独的 key
|
||||||
|
- Policy 负责在 forward 时 stack 图像
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
frames: List[Dict],
|
||||||
|
obs_horizon: int = 2,
|
||||||
|
pred_horizon: int = 8,
|
||||||
|
image_keys: List[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
frames: 帧数据列表。每个元素是一个字典,包含:
|
||||||
|
- "episode_index" (int): [必须] 该帧所属的 Episode ID。Dataset 使用它来确定 Episode 的边界(用于 Padding)。
|
||||||
|
- "task" (str): [必须] 任务描述字符串(例如 "pick_up_cube")。
|
||||||
|
- "observation.state" (torch.Tensor): (state_dim,) [必须] 当前帧的机器人状态向量(例如关节角度)。
|
||||||
|
- "action" (torch.Tensor): (action_dim,) [必须] 当前帧对应的动作向量。
|
||||||
|
- "{image_key}" (torch.Tensor): (C, H, W) [可选] 当前帧的图像数据。键名必须与初始化 Dataset 时传入的 image_keys 列表一致。
|
||||||
|
obs_horizon: 观察过去多少帧
|
||||||
|
pred_horizon: 预测未来多少帧动作
|
||||||
|
image_keys: 哪些 key 是图像数据(例如 ["observation.image_0", "observation.image_1"])
|
||||||
|
"""
|
||||||
|
self.frames = frames
|
||||||
|
self.obs_horizon = obs_horizon
|
||||||
|
self.pred_horizon = pred_horizon
|
||||||
|
self.image_keys = image_keys or []
|
||||||
|
|
||||||
|
# 构建 episode 索引
|
||||||
|
self.episodes = {}
|
||||||
|
for idx, frame in enumerate(frames):
|
||||||
|
ep_idx = frame["episode_index"]
|
||||||
|
if ep_idx not in self.episodes:
|
||||||
|
self.episodes[ep_idx] = []
|
||||||
|
self.episodes[ep_idx].append(idx)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.frames)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
||||||
|
frame = self.frames[idx]
|
||||||
|
ep_idx = frame["episode_index"]
|
||||||
|
|
||||||
|
# 获取当前 episode 的帧索引范围
|
||||||
|
ep_indices = self.episodes[ep_idx]
|
||||||
|
ep_start = ep_indices[0]
|
||||||
|
ep_end = ep_indices[-1]
|
||||||
|
|
||||||
|
# ============================================
|
||||||
|
# 1. 加载观察(过去 obs_horizon 帧)
|
||||||
|
# ============================================
|
||||||
|
observations = {
|
||||||
|
"state": [], # 状态数据
|
||||||
|
}
|
||||||
|
# 为每个摄像头初始化独立列表(字典形式)
|
||||||
|
for cam_key in self.image_keys:
|
||||||
|
observations[cam_key] = []
|
||||||
|
|
||||||
|
observation_is_pad = []
|
||||||
|
|
||||||
|
for delta in range(-self.obs_horizon + 1, 1): # [-1, 0] for obs_horizon=2
|
||||||
|
target_idx = idx + delta
|
||||||
|
|
||||||
|
# 边界检查
|
||||||
|
if ep_start <= target_idx <= ep_end:
|
||||||
|
target_frame = self.frames[target_idx]
|
||||||
|
is_pad = False
|
||||||
|
else:
|
||||||
|
# 超出边界,用边界帧填充
|
||||||
|
if target_idx < ep_start:
|
||||||
|
target_frame = self.frames[ep_start]
|
||||||
|
else:
|
||||||
|
target_frame = self.frames[ep_end]
|
||||||
|
is_pad = True
|
||||||
|
|
||||||
|
# 收集状态
|
||||||
|
observations["state"].append(target_frame["observation.state"])
|
||||||
|
|
||||||
|
# 收集每个摄像头的图像(字典形式,不 stack)
|
||||||
|
for cam_key in self.image_keys:
|
||||||
|
observations[cam_key].append(target_frame[cam_key])
|
||||||
|
|
||||||
|
observation_is_pad.append(is_pad)
|
||||||
|
|
||||||
|
# ============================================
|
||||||
|
# 2. 加载动作(未来 pred_horizon 帧)
|
||||||
|
# ============================================
|
||||||
|
actions = []
|
||||||
|
action_is_pad = []
|
||||||
|
|
||||||
|
for delta in range(self.pred_horizon):
|
||||||
|
target_idx = idx + delta
|
||||||
|
|
||||||
|
if target_idx <= ep_end:
|
||||||
|
actions.append(self.frames[target_idx]["action"])
|
||||||
|
action_is_pad.append(False)
|
||||||
|
else:
|
||||||
|
actions.append(self.frames[ep_end]["action"])
|
||||||
|
action_is_pad.append(True)
|
||||||
|
|
||||||
|
# ============================================
|
||||||
|
# 3. 组装返回数据(字典形式)
|
||||||
|
# ============================================
|
||||||
|
result = {
|
||||||
|
# 状态观察: (obs_horizon, state_dim)
|
||||||
|
"observation.state": torch.stack(observations["state"]),
|
||||||
|
"observation_is_pad": torch.tensor(observation_is_pad, dtype=torch.bool),
|
||||||
|
|
||||||
|
# 动作: (pred_horizon, action_dim)
|
||||||
|
"action": torch.stack(actions),
|
||||||
|
"action_is_pad": torch.tensor(action_is_pad, dtype=torch.bool),
|
||||||
|
|
||||||
|
# 任务
|
||||||
|
"task": frame["task"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# 图像:每个摄像头独立的 key(字典形式)
|
||||||
|
# 形状: (obs_horizon, C, H, W)
|
||||||
|
for cam_key in self.image_keys:
|
||||||
|
result[cam_key] = torch.stack(observations[cam_key])
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def camera_keys(self) -> list[str]:
|
||||||
|
"""获取所有相机键名"""
|
||||||
|
return self.image_keys
|
||||||
|
|
||||||
|
@property
|
||||||
|
def camera_info(self) -> dict:
|
||||||
|
"""获取相机信息"""
|
||||||
|
if not self.image_keys:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# 从第一个样本获取形状
|
||||||
|
sample = self[0]
|
||||||
|
info = {}
|
||||||
|
for cam_key in self.image_keys:
|
||||||
|
if cam_key in sample:
|
||||||
|
info[cam_key] = {
|
||||||
|
"shape": sample[cam_key].shape,
|
||||||
|
"dtype": str(sample[cam_key].dtype),
|
||||||
|
}
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleDiffusionPolicy(torch.nn.Module):
|
||||||
|
"""简化的 Diffusion Policy - 展示如何在 forward 时 stack 图像"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
state_dim: int,
|
||||||
|
action_dim: int,
|
||||||
|
image_features: Dict[str, tuple] = None,
|
||||||
|
obs_horizon: int = 2,
|
||||||
|
pred_horizon: int = 8,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.state_dim = state_dim
|
||||||
|
self.action_dim = action_dim
|
||||||
|
self.obs_horizon = obs_horizon
|
||||||
|
self.pred_horizon = pred_horizon
|
||||||
|
self.image_features = image_features or {}
|
||||||
|
|
||||||
|
self.state_encoder = torch.nn.Linear(state_dim, 64)
|
||||||
|
if image_features:
|
||||||
|
num_cameras = len(image_features)
|
||||||
|
self.image_encoder = torch.nn.Conv2d(3, 32, kernel_size=7, stride=2)
|
||||||
|
self.fusion = torch.nn.Linear(64 + 32 * num_cameras, 128)
|
||||||
|
else:
|
||||||
|
self.fusion = torch.nn.Linear(64, 128)
|
||||||
|
|
||||||
|
self.action_head = torch.nn.Linear(128, action_dim * pred_horizon)
|
||||||
|
|
||||||
|
def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||||
|
"""前向传播"""
|
||||||
|
# 处理状态
|
||||||
|
state_features = self.state_encoder(batch["observation.state"])
|
||||||
|
state_features = state_features.mean(dim=1)
|
||||||
|
|
||||||
|
# 处理图像(字典形式 → stack)
|
||||||
|
if self.image_features:
|
||||||
|
image_tensors = [batch[key] for key in self.image_features.keys()]
|
||||||
|
stacked_images = torch.stack(image_tensors, dim=1)
|
||||||
|
|
||||||
|
B, num_cam, T, C, H, W = stacked_images.shape
|
||||||
|
images_flat = stacked_images.reshape(B * num_cam * T, C, H, W)
|
||||||
|
image_features = self.image_encoder(images_flat)
|
||||||
|
image_features = image_features.mean(dim=[2, 3])
|
||||||
|
image_features = image_features.reshape(B, num_cam, T, 32).mean(dim=2)
|
||||||
|
image_features = image_features.reshape(B, -1)
|
||||||
|
|
||||||
|
features = torch.cat([state_features, image_features], dim=-1)
|
||||||
|
else:
|
||||||
|
features = state_features
|
||||||
|
|
||||||
|
fused = self.fusion(features)
|
||||||
|
pred_actions = self.action_head(fused)
|
||||||
|
pred_actions = pred_actions.reshape(B, self.pred_horizon, self.action_dim)
|
||||||
|
|
||||||
|
return pred_actions
|
||||||
|
|
||||||
|
|
||||||
|
def create_demo_data_with_images():
|
||||||
|
"""创建包含图像的模拟数据"""
|
||||||
|
frames = []
|
||||||
|
|
||||||
|
# Episode 0: pick_up_cube task
|
||||||
|
for t in range(10):
|
||||||
|
frames.append({
|
||||||
|
"episode_index": 0,
|
||||||
|
"frame_index": t,
|
||||||
|
"task": "pick_up_cube",
|
||||||
|
"observation.state": torch.randn(6),
|
||||||
|
"observation.image_high_resize": torch.randn(3, 64, 64),
|
||||||
|
"observation.image_left_wrist": torch.randn(3, 64, 64),
|
||||||
|
"action": torch.randn(6),
|
||||||
|
})
|
||||||
|
|
||||||
|
# Episode 1: stack_blocks task
|
||||||
|
for t in range(10):
|
||||||
|
frames.append({
|
||||||
|
"episode_index": 1,
|
||||||
|
"frame_index": t,
|
||||||
|
"task": "stack_blocks",
|
||||||
|
"observation.state": torch.randn(6),
|
||||||
|
"observation.image_high_resize": torch.randn(3, 64, 64),
|
||||||
|
"observation.image_left_wrist": torch.randn(3, 64, 64),
|
||||||
|
"action": torch.randn(6),
|
||||||
|
})
|
||||||
|
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
def print_section(title: str):
|
||||||
|
"""打印分节标题"""
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print(f" {title}")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataset_basic_info(dataset):
|
||||||
|
"""测试数据集基本信息"""
|
||||||
|
print("\n📊 数据集基本信息:")
|
||||||
|
print(f" 总帧数: {len(dataset)}")
|
||||||
|
print(f" 总 episode 数: {len(dataset.episodes)}")
|
||||||
|
print(f" 观察窗口: {dataset.obs_horizon}")
|
||||||
|
print(f" 预测窗口: {dataset.pred_horizon}")
|
||||||
|
|
||||||
|
print(f"\n📷 相机信息:")
|
||||||
|
cameras = dataset.camera_keys
|
||||||
|
print(f" 相机数量: {len(cameras)}")
|
||||||
|
for cam in cameras:
|
||||||
|
print(f" - {cam}")
|
||||||
|
|
||||||
|
print(f"\n相机详细信息:")
|
||||||
|
cam_info = dataset.camera_info
|
||||||
|
for cam, info in cam_info.items():
|
||||||
|
print(f" {cam}:")
|
||||||
|
print(f" shape: {info['shape']}")
|
||||||
|
print(f" dtype: {info['dtype']}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_sample(dataset):
|
||||||
|
"""测试单个样本"""
|
||||||
|
print_section("1. 测试单个样本")
|
||||||
|
|
||||||
|
# Episode 中间的样本
|
||||||
|
sample = dataset[5]
|
||||||
|
|
||||||
|
print("\n样本结构 (字典形式):")
|
||||||
|
for key, value in sample.items():
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
print(f" {key:30s}: {str(value.shape):20s} {value.dtype}")
|
||||||
|
elif isinstance(value, str):
|
||||||
|
print(f" {key:30s}: {value}")
|
||||||
|
|
||||||
|
# 验证图像是字典形式
|
||||||
|
print("\n✅ 验证图像存储形式:")
|
||||||
|
print(" 图像以字典形式存储,每个摄像头独立的 key:")
|
||||||
|
for cam_key in dataset.camera_keys:
|
||||||
|
if cam_key in sample:
|
||||||
|
print(f" - {cam_key}: {sample[cam_key].shape}")
|
||||||
|
|
||||||
|
# 验证时间维度
|
||||||
|
print("\n✅ 验证时间维度:")
|
||||||
|
print(f" observation.state: {sample['observation.state'].shape}")
|
||||||
|
print(f" 预期: (obs_horizon={dataset.obs_horizon}, state_dim=6)")
|
||||||
|
assert sample['observation.state'].shape[0] == dataset.obs_horizon, "观察时间维度错误"
|
||||||
|
print(f" action: {sample['action'].shape}")
|
||||||
|
print(f" 预期: (pred_horizon={dataset.pred_horizon}, action_dim=6)")
|
||||||
|
assert sample['action'].shape[0] == dataset.pred_horizon, "动作时间维度错误"
|
||||||
|
print(" ✓ 时间维度验证通过")
|
||||||
|
|
||||||
|
|
||||||
|
def test_edge_cases(dataset):
|
||||||
|
"""测试边界情况"""
|
||||||
|
print_section("2. 测试边界情况")
|
||||||
|
|
||||||
|
test_cases = [
|
||||||
|
("Episode 开头", 0, {"obs_pad": [True, False], "action_pad": [False] * 8}),
|
||||||
|
("Episode 中间", 5, {"obs_pad": [False, False], "action_pad": [False] * 5 + [True] * 3}),
|
||||||
|
("Episode 末尾", 9, {"obs_pad": [False, False], "action_pad": [True] * 8}),
|
||||||
|
("跨 Episode", 10, {"obs_pad": [True, False], "action_pad": [False] * 8}),
|
||||||
|
]
|
||||||
|
|
||||||
|
for name, idx, expected in test_cases:
|
||||||
|
print(f"\n📍 {name} (idx={idx}):")
|
||||||
|
sample = dataset[idx]
|
||||||
|
|
||||||
|
obs_pad = sample["observation_is_pad"].tolist()
|
||||||
|
action_pad_count = sample["action_is_pad"].sum().item()
|
||||||
|
|
||||||
|
print(f" observation_is_pad: {obs_pad}")
|
||||||
|
print(f" action_is_pad: {sample['action_is_pad'].tolist()}")
|
||||||
|
print(f" action padding 数量: {action_pad_count}")
|
||||||
|
|
||||||
|
# 验证观察 padding
|
||||||
|
if name == "Episode 开头":
|
||||||
|
assert obs_pad[0] == True, "Episode 开头第一帧应该是 padding"
|
||||||
|
elif name == "跨 Episode":
|
||||||
|
assert obs_pad[0] == True, "跨 Episode 第一帧应该是 padding"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataloader(dataset):
|
||||||
|
"""测试 DataLoader"""
|
||||||
|
print_section("3. 测试 DataLoader 集成")
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=4,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=0, # 测试时用 0
|
||||||
|
)
|
||||||
|
|
||||||
|
batch = next(iter(dataloader))
|
||||||
|
|
||||||
|
print("\n📦 Batch 结构:")
|
||||||
|
for key in ["observation.state", "observation.image_high_resize",
|
||||||
|
"observation.image_left_wrist", "action", "task"]:
|
||||||
|
if key in batch:
|
||||||
|
value = batch[key]
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
print(f" {key:30s}: {str(value.shape):20s} {value.dtype}")
|
||||||
|
else:
|
||||||
|
print(f" {key:30s}: {type(value).__name__} (length={len(value)})")
|
||||||
|
|
||||||
|
print("\n✅ 验证 Batch 形状:")
|
||||||
|
B = len(batch["observation.state"])
|
||||||
|
print(f" Batch size: {B}")
|
||||||
|
|
||||||
|
# 验证每个摄像头的形状
|
||||||
|
for cam_key in dataset.camera_keys:
|
||||||
|
expected_shape = (B, dataset.obs_horizon, 3, 64, 64)
|
||||||
|
actual_shape = batch[cam_key].shape
|
||||||
|
print(f" {cam_key}:")
|
||||||
|
print(f" 预期: {expected_shape}")
|
||||||
|
print(f" 实际: {actual_shape}")
|
||||||
|
assert actual_shape == expected_shape, f"{cam_key} 形状不匹配"
|
||||||
|
print(" ✓ Batch 形状验证通过")
|
||||||
|
|
||||||
|
|
||||||
|
def test_policy_forward(dataset):
|
||||||
|
"""测试 Policy 前向传播"""
|
||||||
|
print_section("4. 测试 Policy 前向传播")
|
||||||
|
|
||||||
|
# 创建 Policy
|
||||||
|
policy = SimpleDiffusionPolicy(
|
||||||
|
state_dim=6,
|
||||||
|
action_dim=6,
|
||||||
|
image_features={
|
||||||
|
"observation.image_high_resize": (3, 64, 64),
|
||||||
|
"observation.image_left_wrist": (3, 64, 64),
|
||||||
|
},
|
||||||
|
obs_horizon=dataset.obs_horizon,
|
||||||
|
pred_horizon=dataset.pred_horizon,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建 DataLoader
|
||||||
|
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
|
||||||
|
batch = next(iter(dataloader))
|
||||||
|
|
||||||
|
print("\n🔄 Policy.forward() 流程:")
|
||||||
|
|
||||||
|
# 1. Stack 之前
|
||||||
|
print("\n 1️⃣ Stack 之前 (字典形式):")
|
||||||
|
for cam_key in policy.image_features.keys():
|
||||||
|
print(f" batch['{cam_key}']: {batch[cam_key].shape}")
|
||||||
|
|
||||||
|
# 2. 模拟 Stack 操作
|
||||||
|
print("\n 2️⃣ Stack 操作:")
|
||||||
|
image_tensors = [batch[key] for key in policy.image_features.keys()]
|
||||||
|
stacked = torch.stack(image_tensors, dim=1)
|
||||||
|
print(f" stacked_images: {stacked.shape}")
|
||||||
|
print(f" (B={stacked.shape[0]}, num_cam={stacked.shape[1]}, ")
|
||||||
|
print(f" obs_hor={stacked.shape[2]}, C={stacked.shape[3]}, H={stacked.shape[4]}, W={stacked.shape[5]})")
|
||||||
|
|
||||||
|
# 3. 前向传播
|
||||||
|
print("\n 3️⃣ 前向传播:")
|
||||||
|
with torch.no_grad():
|
||||||
|
pred_actions = policy(batch)
|
||||||
|
|
||||||
|
print(f" 输入:")
|
||||||
|
print(f" observation.state: {batch['observation.state'].shape}")
|
||||||
|
print(f" 图像已 stack")
|
||||||
|
print(f" 输出:")
|
||||||
|
print(f" pred_actions: {pred_actions.shape}")
|
||||||
|
print(f" (B={pred_actions.shape[0]}, pred_horizon={pred_actions.shape[1]}, action_dim={pred_actions.shape[2]})")
|
||||||
|
|
||||||
|
print("\n✅ Policy 前向传播验证通过")
|
||||||
|
|
||||||
|
|
||||||
|
def test_data_consistency(dataset):
|
||||||
|
"""测试数据一致性"""
|
||||||
|
print_section("5. 测试数据一致性")
|
||||||
|
|
||||||
|
print("\n🔍 验证图像 padding 的正确性:")
|
||||||
|
|
||||||
|
# Episode 开头的样本
|
||||||
|
sample = dataset[0]
|
||||||
|
if sample["observation_is_pad"][0]:
|
||||||
|
img_0 = sample["observation.image_high_resize"][0]
|
||||||
|
img_1 = sample["observation.image_high_resize"][1]
|
||||||
|
print(f" Episode 开头 (idx=0):")
|
||||||
|
print(f" 第0帧是 padding: {sample['observation_is_pad'][0]}")
|
||||||
|
print(f" 第0帧图像 = 第1帧图像: {torch.equal(img_0, img_1)}")
|
||||||
|
assert torch.equal(img_0, img_1), "Padding 应该复制边界帧"
|
||||||
|
print(" ✓ Padding 正确")
|
||||||
|
|
||||||
|
# Episode 中间的样本
|
||||||
|
sample = dataset[5]
|
||||||
|
if not sample["observation_is_pad"].any():
|
||||||
|
img_0 = sample["observation.image_high_resize"][0]
|
||||||
|
img_1 = sample["observation.image_high_resize"][1]
|
||||||
|
print(f"\n Episode 中间 (idx=5):")
|
||||||
|
print(f" 没有 padding: {sample['observation_is_pad']}")
|
||||||
|
print(f" 第0帧图像 ≠ 第1帧图像: {not torch.equal(img_0, img_1)}")
|
||||||
|
print(" ✓ 正常帧不重复")
|
||||||
|
|
||||||
|
print("\n✅ 数据一致性验证通过")
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_info(dataset):
|
||||||
|
"""测试任务信息"""
|
||||||
|
print_section("6. 测试任务信息")
|
||||||
|
|
||||||
|
print("\n📋 统计任务分布:")
|
||||||
|
task_count = {}
|
||||||
|
for frame in dataset.frames:
|
||||||
|
task = frame["task"]
|
||||||
|
task_count[task] = task_count.get(task, 0) + 1
|
||||||
|
|
||||||
|
for task, count in task_count.items():
|
||||||
|
print(f" {task}: {count} 帧")
|
||||||
|
|
||||||
|
# 验证 sample 中的 task 信息
|
||||||
|
sample = dataset[0]
|
||||||
|
print(f"\n样本 task: {sample['task']}")
|
||||||
|
print(f" 类型: {type(sample['task'])}")
|
||||||
|
|
||||||
|
# 验证 DataLoader 中的 task
|
||||||
|
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
|
||||||
|
batch = next(iter(dataloader))
|
||||||
|
print(f"\nBatch task:")
|
||||||
|
print(f" 值: {batch['task']}")
|
||||||
|
print(f" 类型: {type(batch['task'])}")
|
||||||
|
print(f" 长度: {len(batch['task'])}")
|
||||||
|
|
||||||
|
print("\n✅ 任务信息验证通过")
|
||||||
|
|
||||||
|
|
||||||
|
def run_all_tests():
|
||||||
|
"""运行所有测试"""
|
||||||
|
print("\n" + "🚀" * 40)
|
||||||
|
print(" SimpleRobotDataset 完整测试套件")
|
||||||
|
print("🚀" * 40)
|
||||||
|
|
||||||
|
# 创建数据集
|
||||||
|
print("\n创建测试数据...")
|
||||||
|
frames = create_demo_data_with_images()
|
||||||
|
dataset = SimpleRobotDataset(
|
||||||
|
frames,
|
||||||
|
obs_horizon=2,
|
||||||
|
pred_horizon=8,
|
||||||
|
image_keys=["observation.image_high_resize", "observation.image_left_wrist"],
|
||||||
|
)
|
||||||
|
print("✓ 数据集创建完成")
|
||||||
|
|
||||||
|
# 运行测试
|
||||||
|
test_dataset_basic_info(dataset)
|
||||||
|
test_single_sample(dataset)
|
||||||
|
test_edge_cases(dataset)
|
||||||
|
test_dataloader(dataset)
|
||||||
|
test_policy_forward(dataset)
|
||||||
|
test_data_consistency(dataset)
|
||||||
|
test_task_info(dataset)
|
||||||
|
|
||||||
|
# 总结
|
||||||
|
print_section("✅ 测试总结")
|
||||||
|
print("\n所有测试通过!✨")
|
||||||
|
print("\n关键验证点:")
|
||||||
|
print(" ✓ 图像以字典形式存储")
|
||||||
|
print(" ✓ 每个摄像头独立的 key")
|
||||||
|
print(" ✓ Policy 在 forward 时 stack 图像")
|
||||||
|
print(" ✓ 时间维度正确 (obs_horizon, pred_horizon)")
|
||||||
|
print(" ✓ Padding 处理正确")
|
||||||
|
print(" ✓ DataLoader 集成正确")
|
||||||
|
print(" ✓ Task 信息传递正确")
|
||||||
|
print("\n与 LeRobotDataset 设计完全一致!🎉")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
run_all_tests()
|
||||||
Reference in New Issue
Block a user