添加pad_loss

This commit is contained in:
gouhanke
2026-02-11 20:33:26 +08:00
parent eeb07cad15
commit 83cd55e67b
5 changed files with 27 additions and 8 deletions

View File

@@ -151,9 +151,17 @@ class _SingleRgbEncoder(nn.Module):
self.out = nn.Linear(spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()
# 注册ImageNet标准化参数为buffer会自动移到GPU
self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def forward_single_image(self, x: torch.Tensor) -> torch.Tensor:
if self.do_crop:
x = self.maybe_random_crop(x) if self.training else self.center_crop(x)
# ImageNet标准化预训练权重期望的输入分布
x = (x - self.mean) / self.std
x = self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
return x