添加pad_loss
This commit is contained in:
@@ -248,7 +248,8 @@ def main(cfg: DictConfig):
|
||||
return {
|
||||
'images': images,
|
||||
'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state
|
||||
'action': batch_data['action']
|
||||
'action': batch_data['action'],
|
||||
'action_is_pad': batch_data.get('action_is_pad', None) # 传递padding mask
|
||||
}
|
||||
|
||||
def run_validation():
|
||||
|
||||
@@ -87,9 +87,10 @@ class VLAAgent(nn.Module):
|
||||
计算训练损失
|
||||
|
||||
Args:
|
||||
batch: 包含 images, qpos (本体感知), action 的字典
|
||||
batch: 包含 images, qpos (本体感知), action, action_is_pad 的字典
|
||||
"""
|
||||
actions, states, images = batch['action'], batch['qpos'], batch['images']
|
||||
action_is_pad = batch.get('action_is_pad', None) # 获取padding mask
|
||||
B = actions.shape[0]
|
||||
|
||||
# 归一化 states (qpos) 和 actions
|
||||
@@ -131,8 +132,17 @@ class VLAAgent(nn.Module):
|
||||
global_cond=global_cond
|
||||
)
|
||||
|
||||
# 6. 计算 Loss (MSE)
|
||||
loss = nn.functional.mse_loss(pred_noise, noise)
|
||||
# 6. Compute the MSE loss, optionally ignoring padded action steps.
# Keep the per-element loss (no reduction) so padded positions can be
# excluded before averaging.
loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')

if action_is_pad is not None:
    # NOTE(review): action_is_pad is assumed to be (B, pred_horizon) with
    # True marking padded steps, and loss to be (B, pred_horizon, action_dim)
    # -- confirm against the dataset / collate function.
    # `.bool()` guards against an integer mask, where `~` would be a
    # bitwise NOT rather than a logical one.
    # Expanding the mask to the full loss shape makes the denominator count
    # every valid *element*; with a (B, pred_horizon, 1) mask the sum would
    # be divided by the number of valid steps only, inflating the loss by a
    # factor of action_dim relative to the unmasked `loss.mean()` branch.
    valid = (~action_is_pad.bool()).unsqueeze(-1).expand_as(loss)
    # clamp(min=1) avoids 0/0 = NaN when every step in the batch is padding;
    # the loss is then exactly 0.
    loss = (loss * valid).sum() / valid.sum().clamp(min=1)
else:
    # No mask supplied: plain mean over all elements.
    loss = loss.mean()
|
||||
|
||||
return loss
|
||||
|
||||
# ==========================
|
||||
|
||||
@@ -9,13 +9,13 @@ defaults:
|
||||
# ====================
|
||||
train:
|
||||
# 基础训练参数
|
||||
batch_size: 32 # 批次大小
|
||||
batch_size: 8 # 批次大小
|
||||
lr: 1e-4 # 学习率
|
||||
max_steps: 100000 # 最大训练步数
|
||||
device: "cuda" # 设备: "cuda" 或 "cpu"
|
||||
|
||||
# 数据加载
|
||||
num_workers: 40 # DataLoader 工作进程数(调试时设为 0,生产环境用 8)
|
||||
num_workers: 8 # DataLoader 工作进程数(调试时设为 0,生产环境用 8)
|
||||
val_split: 0.1 # 验证集比例
|
||||
seed: 42 # 随机种子(用于数据划分)
|
||||
|
||||
|
||||
@@ -86,8 +86,8 @@ class SimpleRobotDataset(Dataset):
|
||||
h5_path = f'observations/images/{cam_name}'
|
||||
if h5_path in f:
|
||||
img = f[h5_path][meta["frame_idx"]]
|
||||
img = torch.from_numpy(img)
|
||||
# 保持 uint8 格式以节省传输带宽,归一化移至 GPU (在 train_vla.py 中处理)
|
||||
# 转换为float并归一化到 [0, 1]
|
||||
img = torch.from_numpy(img).float() / 255.0
|
||||
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
|
||||
|
||||
return frame
|
||||
|
||||
@@ -151,9 +151,17 @@ class _SingleRgbEncoder(nn.Module):
|
||||
self.out = nn.Linear(spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
# 注册ImageNet标准化参数为buffer(会自动移到GPU)
|
||||
self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
||||
self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
||||
|
||||
def forward_single_image(self, x: torch.Tensor) -> torch.Tensor:
    """Encode a batch of images from one camera into feature vectors.

    Steps, in order:
      1. If ``self.do_crop`` is set, crop the input: ``maybe_random_crop``
         while the module is in training mode, ``center_crop`` otherwise.
      2. Standardize with the ``mean`` / ``std`` buffers (ImageNet statistics,
         shaped ``(1, 3, 1, 1)`` so they broadcast over a 3-channel batch).
      3. Run ``backbone`` -> ``pool``, flatten everything after the batch
         dimension, then apply the ``out`` linear layer followed by ReLU.

    Args:
        x: Image batch. NOTE(review): presumably ``(B, 3, H, W)`` float
            values in ``[0, 1]`` -- confirm against the dataset pipeline.

    Returns:
        The ReLU-activated output of ``self.out`` for each image.
    """
    if self.do_crop:
        x = self.maybe_random_crop(x) if self.training else self.center_crop(x)

    # ImageNet standardization: the input distribution the pretrained
    # backbone weights expect.
    x = (x - self.mean) / self.std

    pooled = self.pool(self.backbone(x))
    flat = torch.flatten(pooled, start_dim=1)
    return self.relu(self.out(flat))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user