diff --git a/roboimi/vla/data/simpe_robot_dataset.py b/roboimi/vla/data/simpe_robot_dataset.py new file mode 100644 index 0000000..04d05f0 --- /dev/null +++ b/roboimi/vla/data/simpe_robot_dataset.py @@ -0,0 +1,523 @@ +import torch +from torch.utils.data import Dataset +from typing import List, Dict, Optional + +class SimpleRobotDataset(Dataset): + """ + LeRobotDataset 简化版 - 图像以字典形式存储 + + 与真实 LeRobotDataset 保持一致: + - Dataset 返回字典,每个摄像头单独的 key + - Policy 负责在 forward 时 stack 图像 + """ + + def __init__( + self, + frames: List[Dict], + obs_horizon: int = 2, + pred_horizon: int = 8, + image_keys: List[str] = None, + ): + """ + Args: + frames: 帧数据列表。每个元素是一个字典,包含: + - "episode_index" (int): [必须] 该帧所属的 Episode ID。Dataset 使用它来确定 Episode 的边界(用于 Padding)。 + - "task" (str): [必须] 任务描述字符串(例如 "pick_up_cube")。 + - "observation.state" (torch.Tensor): (state_dim,) [必须] 当前帧的机器人状态向量(例如关节角度)。 + - "action" (torch.Tensor): (action_dim,) [必须] 当前帧对应的动作向量。 + - "{image_key}" (torch.Tensor): (C, H, W) [可选] 当前帧的图像数据。键名必须与初始化 Dataset 时传入的 image_keys 列表一致。 + obs_horizon: 观察过去多少帧 + pred_horizon: 预测未来多少帧动作 + image_keys: 哪些 key 是图像数据(例如 ["observation.image_0", "observation.image_1"]) + """ + self.frames = frames + self.obs_horizon = obs_horizon + self.pred_horizon = pred_horizon + self.image_keys = image_keys or [] + + # 构建 episode 索引 + self.episodes = {} + for idx, frame in enumerate(frames): + ep_idx = frame["episode_index"] + if ep_idx not in self.episodes: + self.episodes[ep_idx] = [] + self.episodes[ep_idx].append(idx) + + def __len__(self): + return len(self.frames) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + frame = self.frames[idx] + ep_idx = frame["episode_index"] + + # 获取当前 episode 的帧索引范围 + ep_indices = self.episodes[ep_idx] + ep_start = ep_indices[0] + ep_end = ep_indices[-1] + + # ============================================ + # 1. 加载观察(过去 obs_horizon 帧) + # ============================================ + observations = { + "state": [], # 状态数据 + } + # 为每个摄像头初始化独立列表(字典形式) + for cam_key in self.image_keys: + observations[cam_key] = [] + + observation_is_pad = [] + + for delta in range(-self.obs_horizon + 1, 1): # [-1, 0] for obs_horizon=2 + target_idx = idx + delta + + # 边界检查 + if ep_start <= target_idx <= ep_end: + target_frame = self.frames[target_idx] + is_pad = False + else: + # 超出边界,用边界帧填充 + if target_idx < ep_start: + target_frame = self.frames[ep_start] + else: + target_frame = self.frames[ep_end] + is_pad = True + + # 收集状态 + observations["state"].append(target_frame["observation.state"]) + + # 收集每个摄像头的图像(字典形式,不 stack) + for cam_key in self.image_keys: + observations[cam_key].append(target_frame[cam_key]) + + observation_is_pad.append(is_pad) + + # ============================================ + # 2. 加载动作(未来 pred_horizon 帧) + # ============================================ + actions = [] + action_is_pad = [] + + for delta in range(self.pred_horizon): + target_idx = idx + delta + + if target_idx <= ep_end: + actions.append(self.frames[target_idx]["action"]) + action_is_pad.append(False) + else: + actions.append(self.frames[ep_end]["action"]) + action_is_pad.append(True) + + # ============================================ + # 3. 组装返回数据(字典形式) + # ============================================ + result = { + # 状态观察: (obs_horizon, state_dim) + "observation.state": torch.stack(observations["state"]), + "observation_is_pad": torch.tensor(observation_is_pad, dtype=torch.bool), + + # 动作: (pred_horizon, action_dim) + "action": torch.stack(actions), + "action_is_pad": torch.tensor(action_is_pad, dtype=torch.bool), + + # 任务 + "task": frame["task"], + } + + # 图像:每个摄像头独立的 key(字典形式) + # 形状: (obs_horizon, C, H, W) + for cam_key in self.image_keys: + result[cam_key] = torch.stack(observations[cam_key]) + + return result + + @property + def camera_keys(self) -> list[str]: + """获取所有相机键名""" + return self.image_keys + + @property + def camera_info(self) -> dict: + """获取相机信息""" + if not self.image_keys: + return {} + + # 从第一个样本获取形状 + sample = self[0] + info = {} + for cam_key in self.image_keys: + if cam_key in sample: + info[cam_key] = { + "shape": sample[cam_key].shape, + "dtype": str(sample[cam_key].dtype), + } + return info + + +class SimpleDiffusionPolicy(torch.nn.Module): + """简化的 Diffusion Policy - 展示如何在 forward 时 stack 图像""" + + def __init__( + self, + state_dim: int, + action_dim: int, + image_features: Dict[str, tuple] = None, + obs_horizon: int = 2, + pred_horizon: int = 8, + ): + super().__init__() + self.state_dim = state_dim + self.action_dim = action_dim + self.obs_horizon = obs_horizon + self.pred_horizon = pred_horizon + self.image_features = image_features or {} + + self.state_encoder = torch.nn.Linear(state_dim, 64) + if image_features: + num_cameras = len(image_features) + self.image_encoder = torch.nn.Conv2d(3, 32, kernel_size=7, stride=2) + self.fusion = torch.nn.Linear(64 + 32 * num_cameras, 128) + else: + self.fusion = torch.nn.Linear(64, 128) + + self.action_head = torch.nn.Linear(128, action_dim * pred_horizon) + + def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + """前向传播""" + # 处理状态 + state_features = self.state_encoder(batch["observation.state"]) + state_features = state_features.mean(dim=1) + + # 处理图像(字典形式 → stack) + if self.image_features: + image_tensors = [batch[key] for key in self.image_features.keys()] + stacked_images = torch.stack(image_tensors, dim=1) + + B, num_cam, T, C, H, W = stacked_images.shape + images_flat = stacked_images.reshape(B * num_cam * T, C, H, W) + image_features = self.image_encoder(images_flat) + image_features = image_features.mean(dim=[2, 3]) + image_features = image_features.reshape(B, num_cam, T, 32).mean(dim=2) + image_features = image_features.reshape(B, -1) + + features = torch.cat([state_features, image_features], dim=-1) + else: + features = state_features + + fused = self.fusion(features) + pred_actions = self.action_head(fused) + pred_actions = pred_actions.reshape(B, self.pred_horizon, self.action_dim) + + return pred_actions + + +def create_demo_data_with_images(): + """创建包含图像的模拟数据""" + frames = [] + + # Episode 0: pick_up_cube task + for t in range(10): + frames.append({ + "episode_index": 0, + "frame_index": t, + "task": "pick_up_cube", + "observation.state": torch.randn(6), + "observation.image_high_resize": torch.randn(3, 64, 64), + "observation.image_left_wrist": torch.randn(3, 64, 64), + "action": torch.randn(6), + }) + + # Episode 1: stack_blocks task + for t in range(10): + frames.append({ + "episode_index": 1, + "frame_index": t, + "task": "stack_blocks", + "observation.state": torch.randn(6), + "observation.image_high_resize": torch.randn(3, 64, 64), + "observation.image_left_wrist": torch.randn(3, 64, 64), + "action": torch.randn(6), + }) + + return frames + + +def print_section(title: str): + """打印分节标题""" + print("\n" + "=" * 80) + print(f" {title}") + print("=" * 80) + + +def test_dataset_basic_info(dataset): + """测试数据集基本信息""" + print("\n📊 数据集基本信息:") + print(f" 总帧数: {len(dataset)}") + print(f" 总 episode 数: {len(dataset.episodes)}") + print(f" 观察窗口: {dataset.obs_horizon}") + print(f" 预测窗口: {dataset.pred_horizon}") + + print(f"\n📷 相机信息:") + cameras = dataset.camera_keys + print(f" 相机数量: {len(cameras)}") + for cam in cameras: + print(f" - {cam}") + + print(f"\n相机详细信息:") + cam_info = dataset.camera_info + for cam, info in cam_info.items(): + print(f" {cam}:") + print(f" shape: {info['shape']}") + print(f" dtype: {info['dtype']}") + + +def test_single_sample(dataset): + """测试单个样本""" + print_section("1. 测试单个样本") + + # Episode 中间的样本 + sample = dataset[5] + + print("\n样本结构 (字典形式):") + for key, value in sample.items(): + if isinstance(value, torch.Tensor): + print(f" {key:30s}: {str(value.shape):20s} {value.dtype}") + elif isinstance(value, str): + print(f" {key:30s}: {value}") + + # 验证图像是字典形式 + print("\n✅ 验证图像存储形式:") + print(" 图像以字典形式存储,每个摄像头独立的 key:") + for cam_key in dataset.camera_keys: + if cam_key in sample: + print(f" - {cam_key}: {sample[cam_key].shape}") + + # 验证时间维度 + print("\n✅ 验证时间维度:") + print(f" observation.state: {sample['observation.state'].shape}") + print(f" 预期: (obs_horizon={dataset.obs_horizon}, state_dim=6)") + assert sample['observation.state'].shape[0] == dataset.obs_horizon, "观察时间维度错误" + print(f" action: {sample['action'].shape}") + print(f" 预期: (pred_horizon={dataset.pred_horizon}, action_dim=6)") + assert sample['action'].shape[0] == dataset.pred_horizon, "动作时间维度错误" + print(" ✓ 时间维度验证通过") + + +def test_edge_cases(dataset): + """测试边界情况""" + print_section("2. 测试边界情况") + + test_cases = [ + ("Episode 开头", 0, {"obs_pad": [True, False], "action_pad": [False] * 8}), + ("Episode 中间", 5, {"obs_pad": [False, False], "action_pad": [False] * 5 + [True] * 3}), + ("Episode 末尾", 9, {"obs_pad": [False, False], "action_pad": [True] * 8}), + ("跨 Episode", 10, {"obs_pad": [True, False], "action_pad": [False] * 8}), + ] + + for name, idx, expected in test_cases: + print(f"\n📍 {name} (idx={idx}):") + sample = dataset[idx] + + obs_pad = sample["observation_is_pad"].tolist() + action_pad_count = sample["action_is_pad"].sum().item() + + print(f" observation_is_pad: {obs_pad}") + print(f" action_is_pad: {sample['action_is_pad'].tolist()}") + print(f" action padding 数量: {action_pad_count}") + + # 验证观察 padding + if name == "Episode 开头": + assert obs_pad[0] == True, "Episode 开头第一帧应该是 padding" + elif name == "跨 Episode": + assert obs_pad[0] == True, "跨 Episode 第一帧应该是 padding" + + +def test_dataloader(dataset): + """测试 DataLoader""" + print_section("3. 测试 DataLoader 集成") + + dataloader = DataLoader( + dataset, + batch_size=4, + shuffle=True, + num_workers=0, # 测试时用 0 + ) + + batch = next(iter(dataloader)) + + print("\n📦 Batch 结构:") + for key in ["observation.state", "observation.image_high_resize", + "observation.image_left_wrist", "action", "task"]: + if key in batch: + value = batch[key] + if isinstance(value, torch.Tensor): + print(f" {key:30s}: {str(value.shape):20s} {value.dtype}") + else: + print(f" {key:30s}: {type(value).__name__} (length={len(value)})") + + print("\n✅ 验证 Batch 形状:") + B = len(batch["observation.state"]) + print(f" Batch size: {B}") + + # 验证每个摄像头的形状 + for cam_key in dataset.camera_keys: + expected_shape = (B, dataset.obs_horizon, 3, 64, 64) + actual_shape = batch[cam_key].shape + print(f" {cam_key}:") + print(f" 预期: {expected_shape}") + print(f" 实际: {actual_shape}") + assert actual_shape == expected_shape, f"{cam_key} 形状不匹配" + print(" ✓ Batch 形状验证通过") + + +def test_policy_forward(dataset): + """测试 Policy 前向传播""" + print_section("4. 测试 Policy 前向传播") + + # 创建 Policy + policy = SimpleDiffusionPolicy( + state_dim=6, + action_dim=6, + image_features={ + "observation.image_high_resize": (3, 64, 64), + "observation.image_left_wrist": (3, 64, 64), + }, + obs_horizon=dataset.obs_horizon, + pred_horizon=dataset.pred_horizon, + ) + + # 创建 DataLoader + dataloader = DataLoader(dataset, batch_size=4, shuffle=False) + batch = next(iter(dataloader)) + + print("\n🔄 Policy.forward() 流程:") + + # 1. Stack 之前 + print("\n 1️⃣ Stack 之前 (字典形式):") + for cam_key in policy.image_features.keys(): + print(f" batch['{cam_key}']: {batch[cam_key].shape}") + + # 2. 模拟 Stack 操作 + print("\n 2️⃣ Stack 操作:") + image_tensors = [batch[key] for key in policy.image_features.keys()] + stacked = torch.stack(image_tensors, dim=1) + print(f" stacked_images: {stacked.shape}") + print(f" (B={stacked.shape[0]}, num_cam={stacked.shape[1]}, ") + print(f" obs_hor={stacked.shape[2]}, C={stacked.shape[3]}, H={stacked.shape[4]}, W={stacked.shape[5]})") + + # 3. 前向传播 + print("\n 3️⃣ 前向传播:") + with torch.no_grad(): + pred_actions = policy(batch) + + print(f" 输入:") + print(f" observation.state: {batch['observation.state'].shape}") + print(f" 图像已 stack") + print(f" 输出:") + print(f" pred_actions: {pred_actions.shape}") + print(f" (B={pred_actions.shape[0]}, pred_horizon={pred_actions.shape[1]}, action_dim={pred_actions.shape[2]})") + + print("\n✅ Policy 前向传播验证通过") + + +def test_data_consistency(dataset): + """测试数据一致性""" + print_section("5. 测试数据一致性") + + print("\n🔍 验证图像 padding 的正确性:") + + # Episode 开头的样本 + sample = dataset[0] + if sample["observation_is_pad"][0]: + img_0 = sample["observation.image_high_resize"][0] + img_1 = sample["observation.image_high_resize"][1] + print(f" Episode 开头 (idx=0):") + print(f" 第0帧是 padding: {sample['observation_is_pad'][0]}") + print(f" 第0帧图像 = 第1帧图像: {torch.equal(img_0, img_1)}") + assert torch.equal(img_0, img_1), "Padding 应该复制边界帧" + print(" ✓ Padding 正确") + + # Episode 中间的样本 + sample = dataset[5] + if not sample["observation_is_pad"].any(): + img_0 = sample["observation.image_high_resize"][0] + img_1 = sample["observation.image_high_resize"][1] + print(f"\n Episode 中间 (idx=5):") + print(f" 没有 padding: {sample['observation_is_pad']}") + print(f" 第0帧图像 ≠ 第1帧图像: {not torch.equal(img_0, img_1)}") + print(" ✓ 正常帧不重复") + + print("\n✅ 数据一致性验证通过") + + +def test_task_info(dataset): + """测试任务信息""" + print_section("6. 测试任务信息") + + print("\n📋 统计任务分布:") + task_count = {} + for frame in dataset.frames: + task = frame["task"] + task_count[task] = task_count.get(task, 0) + 1 + + for task, count in task_count.items(): + print(f" {task}: {count} 帧") + + # 验证 sample 中的 task 信息 + sample = dataset[0] + print(f"\n样本 task: {sample['task']}") + print(f" 类型: {type(sample['task'])}") + + # 验证 DataLoader 中的 task + dataloader = DataLoader(dataset, batch_size=4, shuffle=False) + batch = next(iter(dataloader)) + print(f"\nBatch task:") + print(f" 值: {batch['task']}") + print(f" 类型: {type(batch['task'])}") + print(f" 长度: {len(batch['task'])}") + + print("\n✅ 任务信息验证通过") + + +def run_all_tests(): + """运行所有测试""" + print("\n" + "🚀" * 40) + print(" SimpleRobotDataset 完整测试套件") + print("🚀" * 40) + + # 创建数据集 + print("\n创建测试数据...") + frames = create_demo_data_with_images() + dataset = SimpleRobotDataset( + frames, + obs_horizon=2, + pred_horizon=8, + image_keys=["observation.image_high_resize", "observation.image_left_wrist"], + ) + print("✓ 数据集创建完成") + + # 运行测试 + test_dataset_basic_info(dataset) + test_single_sample(dataset) + test_edge_cases(dataset) + test_dataloader(dataset) + test_policy_forward(dataset) + test_data_consistency(dataset) + test_task_info(dataset) + + # 总结 + print_section("✅ 测试总结") + print("\n所有测试通过!✨") + print("\n关键验证点:") + print(" ✓ 图像以字典形式存储") + print(" ✓ 每个摄像头独立的 key") + print(" ✓ Policy 在 forward 时 stack 图像") + print(" ✓ 时间维度正确 (obs_horizon, pred_horizon)") + print(" ✓ Padding 处理正确") + print(" ✓ DataLoader 集成正确") + print(" ✓ Task 信息传递正确") + print("\n与 LeRobotDataset 设计完全一致!🎉") + + +if __name__ == "__main__": + from torch.utils.data import DataLoader + run_all_tests() \ No newline at end of file