添加pad_loss
This commit is contained in:
@@ -248,7 +248,8 @@ def main(cfg: DictConfig):
|
|||||||
return {
|
return {
|
||||||
'images': images,
|
'images': images,
|
||||||
'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state
|
'qpos': batch_data['observation.state'], # SimpleRobotDataset 使用 observation.state
|
||||||
'action': batch_data['action']
|
'action': batch_data['action'],
|
||||||
|
'action_is_pad': batch_data.get('action_is_pad', None) # 传递padding mask
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_validation():
|
def run_validation():
|
||||||
|
|||||||
@@ -87,9 +87,10 @@ class VLAAgent(nn.Module):
|
|||||||
计算训练损失
|
计算训练损失
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch: 包含 images, qpos (本体感知), action 的字典
|
batch: 包含 images, qpos (本体感知), action, action_is_pad 的字典
|
||||||
"""
|
"""
|
||||||
actions, states, images = batch['action'], batch['qpos'], batch['images']
|
actions, states, images = batch['action'], batch['qpos'], batch['images']
|
||||||
|
action_is_pad = batch.get('action_is_pad', None) # 获取padding mask
|
||||||
B = actions.shape[0]
|
B = actions.shape[0]
|
||||||
|
|
||||||
# 归一化 states (qpos) 和 actions
|
# 归一化 states (qpos) 和 actions
|
||||||
@@ -131,8 +132,17 @@ class VLAAgent(nn.Module):
|
|||||||
global_cond=global_cond
|
global_cond=global_cond
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. 计算 Loss (MSE)
|
# 6. 计算 Loss (MSE),支持 padding mask
|
||||||
loss = nn.functional.mse_loss(pred_noise, noise)
|
loss = nn.functional.mse_loss(pred_noise, noise, reduction='none')
|
||||||
|
|
||||||
|
# 如果提供了 action_is_pad,对padding位置进行mask
|
||||||
|
if action_is_pad is not None:
|
||||||
|
# action_is_pad: (B, pred_horizon),扩展到 (B, pred_horizon, action_dim)
|
||||||
|
mask = ~action_is_pad.unsqueeze(-1) # True表示有效数据
|
||||||
|
loss = (loss * mask).sum() / mask.sum() # 只对有效位置计算平均
|
||||||
|
else:
|
||||||
|
loss = loss.mean()
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
# ==========================
|
# ==========================
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ defaults:
|
|||||||
# ====================
|
# ====================
|
||||||
train:
|
train:
|
||||||
# 基础训练参数
|
# 基础训练参数
|
||||||
batch_size: 32 # 批次大小
|
batch_size: 8 # 批次大小
|
||||||
lr: 1e-4 # 学习率
|
lr: 1e-4 # 学习率
|
||||||
max_steps: 100000 # 最大训练步数
|
max_steps: 100000 # 最大训练步数
|
||||||
device: "cuda" # 设备: "cuda" 或 "cpu"
|
device: "cuda" # 设备: "cuda" 或 "cpu"
|
||||||
|
|
||||||
# 数据加载
|
# 数据加载
|
||||||
num_workers: 40 # DataLoader 工作进程数(调试时设为 0,生产环境用 8)
|
num_workers: 8 # DataLoader 工作进程数(调试时设为 0,生产环境用 8)
|
||||||
val_split: 0.1 # 验证集比例
|
val_split: 0.1 # 验证集比例
|
||||||
seed: 42 # 随机种子(用于数据划分)
|
seed: 42 # 随机种子(用于数据划分)
|
||||||
|
|
||||||
|
|||||||
@@ -86,8 +86,8 @@ class SimpleRobotDataset(Dataset):
|
|||||||
h5_path = f'observations/images/{cam_name}'
|
h5_path = f'observations/images/{cam_name}'
|
||||||
if h5_path in f:
|
if h5_path in f:
|
||||||
img = f[h5_path][meta["frame_idx"]]
|
img = f[h5_path][meta["frame_idx"]]
|
||||||
img = torch.from_numpy(img)
|
# 转换为float并归一化到 [0, 1]
|
||||||
# 保持 uint8 格式以节省传输带宽,归一化移至 GPU (在 train_vla.py 中处理)
|
img = torch.from_numpy(img).float() / 255.0
|
||||||
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
|
frame[f"observation.{cam_name}"] = img.permute(2, 0, 1) # HWC -> CHW
|
||||||
|
|
||||||
return frame
|
return frame
|
||||||
|
|||||||
@@ -151,9 +151,17 @@ class _SingleRgbEncoder(nn.Module):
|
|||||||
self.out = nn.Linear(spatial_softmax_num_keypoints * 2, self.feature_dim)
|
self.out = nn.Linear(spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||||
self.relu = nn.ReLU()
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
# 注册ImageNet标准化参数为buffer(会自动移到GPU)
|
||||||
|
self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
||||||
|
self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
||||||
|
|
||||||
def forward_single_image(self, x: torch.Tensor) -> torch.Tensor:
|
def forward_single_image(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
if self.do_crop:
|
if self.do_crop:
|
||||||
x = self.maybe_random_crop(x) if self.training else self.center_crop(x)
|
x = self.maybe_random_crop(x) if self.training else self.center_crop(x)
|
||||||
|
|
||||||
|
# ImageNet标准化(预训练权重期望的输入分布)
|
||||||
|
x = (x - self.mean) / self.std
|
||||||
|
|
||||||
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
|
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user