disperse loss
This commit is contained in:
@@ -48,7 +48,7 @@ class DisperseTrainer(BaseTrainer):
|
|||||||
self.align_layer = align_layer
|
self.align_layer = align_layer
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
|
|
||||||
def _impl_trainstep(self, net, ema_net, solver, x, y, metadata=None):
|
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||||
batch_size, c, height, width = x.shape
|
batch_size, c, height, width = x.shape
|
||||||
if self.lognorm_t:
|
if self.lognorm_t:
|
||||||
base_t = torch.randn((batch_size), device=x.device, dtype=torch.float32).sigmoid()
|
base_t = torch.randn((batch_size), device=x.device, dtype=torch.float32).sigmoid()
|
||||||
@@ -78,15 +78,12 @@ class DisperseTrainer(BaseTrainer):
|
|||||||
|
|
||||||
out = net(x_t, t, y)
|
out = net(x_t, t, y)
|
||||||
handle.remove()
|
handle.remove()
|
||||||
disperse_distance = 0.0
|
disperse_loss = 0.0
|
||||||
for sf in src_feature:
|
for sf in src_feature:
|
||||||
sf = torch.mean(sf, dim=1, keepdim=False)
|
sf = sf.view(batch_size, -1)
|
||||||
distance = (sf[None, :, :] - sf[:, None, :])**2
|
distance = torch.nn.functional.pdist(sf, p=2)**2
|
||||||
distance = distance.sum(dim=-1)
|
sf_disperse_distance = torch.exp(-distance/self.temperature) + 1e-5
|
||||||
sf_disperse_loss = torch.exp(-distance/self.temperature)
|
disperse_loss += sf_disperse_distance.mean().log()
|
||||||
mask = 1-torch.eye(batch_size, device=distance.device, dtype=distance.dtype)
|
|
||||||
disperse_distance += (sf_disperse_loss*mask).sum()/mask.numel() + 1e-6
|
|
||||||
disperse_loss = disperse_distance.log()
|
|
||||||
|
|
||||||
|
|
||||||
weight = self.loss_weight_fn(alpha, sigma)
|
weight = self.loss_weight_fn(alpha, sigma)
|
||||||
@@ -94,8 +91,7 @@ class DisperseTrainer(BaseTrainer):
|
|||||||
|
|
||||||
out = dict(
|
out = dict(
|
||||||
fm_loss=fm_loss.mean(),
|
fm_loss=fm_loss.mean(),
|
||||||
cos_loss=disperse_loss.mean(),
|
disperse_loss=disperse_loss.mean(),
|
||||||
loss=fm_loss.mean() + self.feat_loss_weight*disperse_loss.mean(),
|
loss=fm_loss.mean() + self.feat_loss_weight*disperse_loss.mean(),
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user