From 598f7b40a2212912d0d2d60136aa9d190bbbaa33 Mon Sep 17 00:00:00 2001
From: wangshuai6
Date: Fri, 13 Jun 2025 10:41:31 +0800
Subject: [PATCH] disperse loss

---
 .../flow_matching/training_disperse.py | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/src/diffusion/flow_matching/training_disperse.py b/src/diffusion/flow_matching/training_disperse.py
index bf8b5d0..a40c4ae 100644
--- a/src/diffusion/flow_matching/training_disperse.py
+++ b/src/diffusion/flow_matching/training_disperse.py
@@ -48,7 +48,7 @@ class DisperseTrainer(BaseTrainer):
         self.align_layer = align_layer
         self.temperature = temperature
 
-    def _impl_trainstep(self, net, ema_net, solver, x, y, metadata=None):
+    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
         batch_size, c, height, width = x.shape
         if self.lognorm_t:
             base_t = torch.randn((batch_size), device=x.device, dtype=torch.float32).sigmoid()
@@ -78,15 +78,12 @@ class DisperseTrainer(BaseTrainer):
         out = net(x_t, t, y)
         handle.remove()
 
-        disperse_distance = 0.0
+        disperse_loss = 0.0
         for sf in src_feature:
-            sf = torch.mean(sf, dim=1, keepdim=False)
-            distance = (sf[None, :, :] - sf[:, None, :])**2
-            distance = distance.sum(dim=-1)
-            sf_disperse_loss = torch.exp(-distance/self.temperature)
-            mask = 1-torch.eye(batch_size, device=distance.device, dtype=distance.dtype)
-            disperse_distance += (sf_disperse_loss*mask).sum()/mask.numel() + 1e-6
-        disperse_loss = disperse_distance.log()
+            sf = sf.view(batch_size, -1)
+            distance = torch.nn.functional.pdist(sf, p=2)**2
+            sf_disperse_distance = torch.exp(-distance/self.temperature) + 1e-5
+            disperse_loss += sf_disperse_distance.mean().log()
 
         weight = self.loss_weight_fn(alpha, sigma)
 
@@ -94,8 +91,7 @@ class DisperseTrainer(BaseTrainer):
 
         out = dict(
             fm_loss=fm_loss.mean(),
-            cos_loss=disperse_loss.mean(),
+            disperse_loss=disperse_loss.mean(),
             loss=fm_loss.mean() + self.feat_loss_weight*disperse_loss.mean(),
         )
         return out
-