From 83cd55e67b868b13132782de2cb169c3fffa5536 Mon Sep 17 00:00:00 2001
From: gouhanke <12219217+gouhanke@user.noreply.gitee.com>
Date: Wed, 11 Feb 2026 20:33:26 +0800
Subject: [PATCH] Add pad_loss
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 roboimi/demos/vla_scripts/train_vla.py           |  3 ++-
 roboimi/vla/agent.py                             | 16 +++++++++++++---
 roboimi/vla/conf/config.yaml                     |  4 ++--
 roboimi/vla/data/simpe_robot_dataset.py          |  4 ++--
 roboimi/vla/models/backbones/resnet_diffusion.py |  8 ++++++++
 5 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/roboimi/demos/vla_scripts/train_vla.py b/roboimi/demos/vla_scripts/train_vla.py
index 13c91bd..f5fbcb1 100644
--- a/roboimi/demos/vla_scripts/train_vla.py
+++ b/roboimi/demos/vla_scripts/train_vla.py
@@ -248,7 +248,8 @@ def main(cfg: DictConfig):
         return {
             'images': images,
             'qpos': batch_data['observation.state'],  # SimpleRobotDataset uses observation.state
-            'action': batch_data['action']
+            'action': batch_data['action'],
+            'action_is_pad': batch_data.get('action_is_pad', None)  # pass the padding mask through
         }

     def run_validation():
diff --git a/roboimi/vla/agent.py b/roboimi/vla/agent.py
index 1172f9e..c1ac1cd 100644
--- a/roboimi/vla/agent.py
+++ b/roboimi/vla/agent.py
@@ -87,9 +87,10 @@ class VLAAgent(nn.Module):
        Compute the training loss.

        Args:
-            batch: dict containing images, qpos (proprioception), and action
+            batch: dict containing images, qpos (proprioception), action, and action_is_pad
        """
        actions, states, images = batch['action'], batch['qpos'], batch['images']
+        action_is_pad = batch.get('action_is_pad', None)  # fetch the padding mask
        B = actions.shape[0]

        # Normalize states (qpos) and actions
@@ -131,8 +132,17 @@ class VLAAgent(nn.Module):
            global_cond=global_cond
        )

-        # 6. Compute the loss (MSE)
-        loss = nn.functional.mse_loss(pred_noise, noise)
+        # 6. Compute the loss (MSE), with padding-mask support
+        loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')
+
+        # If action_is_pad is provided, mask out the padded positions
+        if action_is_pad is not None:
+            # action_is_pad: (B, pred_horizon); expand to (B, pred_horizon, action_dim)
+            mask = (~action_is_pad).unsqueeze(-1).expand_as(loss)  # True marks valid entries
+            loss = (loss * mask).sum() / mask.sum().clamp(min=1)  # mean over valid entries only
+        else:
+            loss = loss.mean()
+
        return loss

    # ==========================
diff --git a/roboimi/vla/conf/config.yaml b/roboimi/vla/conf/config.yaml
index 2072ed7..b4cf8c0 100644
--- a/roboimi/vla/conf/config.yaml
+++ b/roboimi/vla/conf/config.yaml
@@ -9,13 +9,13 @@ defaults:
 # ====================
 train:
   # Basic training parameters
-  batch_size: 32              # batch size
+  batch_size: 8               # batch size
   lr: 1e-4                    # learning rate
   max_steps: 100000           # maximum number of training steps
   device: "cuda"              # device: "cuda" or "cpu"

   # Data loading
-  num_workers: 40             # number of DataLoader workers (0 for debugging, 8 in production)
+  num_workers: 8              # number of DataLoader workers (0 for debugging, 8 in production)
   val_split: 0.1              # validation split fraction
   seed: 42                    # random seed (used for the data split)

diff --git a/roboimi/vla/data/simpe_robot_dataset.py b/roboimi/vla/data/simpe_robot_dataset.py
index ca690f4..7650a37 100644
--- a/roboimi/vla/data/simpe_robot_dataset.py
+++ b/roboimi/vla/data/simpe_robot_dataset.py
@@ -86,8 +86,8 @@ class SimpleRobotDataset(Dataset):
                h5_path = f'observations/images/{cam_name}'
                if h5_path in f:
                    img = f[h5_path][meta["frame_idx"]]
-                    img = torch.from_numpy(img)
-                    # Keep uint8 to save transfer bandwidth; normalization happens on the GPU (handled in train_vla.py)
+                    # Convert to float and normalize to [0, 1]
+                    img = torch.from_numpy(img).float() / 255.0
                    frame[f"observation.{cam_name}"] = img.permute(2, 0, 1)  # HWC -> CHW

        return frame
diff --git a/roboimi/vla/models/backbones/resnet_diffusion.py b/roboimi/vla/models/backbones/resnet_diffusion.py
index 695496d..b5c898f 100644
--- a/roboimi/vla/models/backbones/resnet_diffusion.py
+++ b/roboimi/vla/models/backbones/resnet_diffusion.py
@@ -151,9 +151,17 @@ class _SingleRgbEncoder(nn.Module):
        self.out = nn.Linear(spatial_softmax_num_keypoints * 2, self.feature_dim)
        self.relu = nn.ReLU()

+        # Register the ImageNet normalization stats as buffers (moved to the GPU automatically)
+        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
    def forward_single_image(self, x: torch.Tensor) -> torch.Tensor:
        if self.do_crop:
            x = self.maybe_random_crop(x) if self.training else self.center_crop(x)
+
+        # ImageNet normalization (the input distribution the pretrained weights expect)
+        x = (x - self.mean) / self.std
+
        x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
        return x
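
Note on the masked loss: a minimal standalone sketch of the masked-MSE
behaviour added in roboimi/vla/agent.py, assuming a boolean action_is_pad of
shape (B, pred_horizon) where True marks padded steps. The shapes and values
below are illustrative only, not taken from the repo.

    import torch
    import torch.nn.functional as F

    B, T, D = 2, 4, 3                      # batch, pred_horizon, action_dim
    pred_noise = torch.randn(B, T, D)
    noise = torch.randn(B, T, D)
    action_is_pad = torch.zeros(B, T, dtype=torch.bool)
    action_is_pad[:, -1] = True            # last step of each chunk is padding

    loss = F.mse_loss(pred_noise, noise, reduction='none')
    mask = (~action_is_pad).unsqueeze(-1).expand_as(loss)
    masked = (loss * mask).sum() / mask.sum().clamp(min=1)

    # Sanity check: identical to a plain MSE over the valid steps alone.
    assert torch.allclose(masked, F.mse_loss(pred_noise[:, :-1], noise[:, :-1]))

Without expand_as, mask.sum() would count valid timesteps rather than valid
elements, inflating the masked loss by a factor of action_dim relative to the
unmasked loss.mean() branch.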
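Note on the normalization change: register_buffer is used (rather than plain
tensor attributes) so the ImageNet statistics follow the module through
.to(device) and are saved in state_dict without being treated as trainable
parameters. A minimal sketch of the pattern follows, with a hypothetical
module name; it assumes the dataset already delivers float images in [0, 1],
as the simpe_robot_dataset.py change above now guarantees.

    import torch
    import torch.nn as nn

    class ImageNetNorm(nn.Module):
        def __init__(self):
            super().__init__()
            # Buffers, not Parameters: moved with the module, never optimized.
            self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return (x - self.mean) / self.std

    norm = ImageNetNorm()
    x = torch.rand(2, 3, 224, 224)         # float images already scaled to [0, 1]
    y = norm(x)                            # distribution the pretrained ResNet expects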