添加pad_loss

This commit is contained in:
gouhanke
2026-02-11 20:33:26 +08:00
parent eeb07cad15
commit 83cd55e67b
5 changed files with 27 additions and 8 deletions

View File

@@ -248,7 +248,8 @@ def main(cfg: DictConfig):
return {
'images': images,
'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state
'action': batch_data['action']
'action': batch_data['action'],
'action_is_pad': batch_data.get('action_is_pad', None) # 传递padding mask
}
def run_validation():

View File

@@ -87,9 +87,10 @@ class VLAAgent(nn.Module):
计算训练损失
Args:
batch: 包含 images, qpos (本体感知), action 的字典
batch: 包含 images, qpos (本体感知), action, action_is_pad 的字典
"""
actions, states, images = batch['action'], batch['qpos'], batch['images']
action_is_pad = batch.get('action_is_pad', None) # 获取padding mask
B = actions.shape[0]
# 归一化 states (qpos) 和 actions
@@ -131,8 +132,17 @@ class VLAAgent(nn.Module):
global_cond=global_cond
)
# 6. 计算 Loss (MSE)
loss = nn.functional.mse_loss(pred_noise, noise)
# 6. 计算 Loss (MSE),支持 padding mask
loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')
# 如果提供了 action_is_pad对padding位置进行mask
if action_is_pad is not None:
# action_is_pad: (B, pred_horizon),扩展到 (B, pred_horizon, action_dim)
mask = ~action_is_pad.unsqueeze(-1) # True表示有效数据
loss = (loss * mask).sum() / mask.sum() # 只对有效位置计算平均
else:
loss = loss.mean()
return loss
# ==========================

View File

@@ -9,13 +9,13 @@ defaults:
# ====================
train:
# 基础训练参数
batch_size: 32 # 批次大小
batch_size: 8 # 批次大小
lr: 1e-4 # 学习率
max_steps: 100000 # 最大训练步数
device: "cuda" # 设备: "cuda" 或 "cpu"
# 数据加载
num_workers: 40 # DataLoader 工作进程数(调试时设为 0生产环境用 8
num_workers: 8 # DataLoader 工作进程数(调试时设为 0生产环境用 8
val_split: 0.1 # 验证集比例
seed: 42 # 随机种子(用于数据划分)

View File

@@ -86,8 +86,8 @@ class SimpleRobotDataset(Dataset):
h5_path = f'observations/images/{cam_name}'
if h5_path in f:
img = f[h5_path][meta["frame_idx"]]
img = torch.from_numpy(img)
# 保持 uint8 格式以节省传输带宽,归一化移至 GPU (在 train_vla.py 中处理)
# 转换为float并归一化到 [0, 1]
img = torch.from_numpy(img).float() / 255.0
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
return frame

View File

@@ -151,9 +151,17 @@ class _SingleRgbEncoder(nn.Module):
self.out = nn.Linear(spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
# 注册ImageNet标准化参数为buffer会自动移到GPU
self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def forward_single_image(self, x: torch.Tensor) -> torch.Tensor:
    """Encode one batch of images into a flat feature vector.

    Pipeline: optional crop -> ImageNet normalization -> backbone ->
    spatial pooling -> flatten -> linear projection -> ReLU.

    Args:
        x: Image tensor, presumably (B, C, H, W) with float values in
           [0, 1] — the dataset divides by 255 before batching; the
           registered `mean`/`std` buffers are broadcast over (1, 3, 1, 1),
           so 3 channels are expected. TODO(review): confirm shape at caller.

    Returns:
        torch.Tensor: features of shape (B, self.feature_dim), the output
        size of the final `self.out` linear layer.
    """
    if self.do_crop:
        # Random crop for augmentation in training mode; deterministic
        # center crop at eval time so validation is reproducible.
        x = self.maybe_random_crop(x) if self.training else self.center_crop(x)
    # ImageNet normalization (the input distribution the pretrained weights expect).
    x = (x - self.mean) / self.std
    x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
    return x