From 3093a651519873883aacb956191e01073136eacc Mon Sep 17 00:00:00 2001
From: wangshuai6
Date: Thu, 12 Jun 2025 22:18:15 +0800
Subject: [PATCH] disperse loss

---
 .../flow_matching/training_disperse.py | 108 ++++++++++++++++++
 1 file changed, 108 insertions(+)
 create mode 100644 src/diffusion/flow_matching/training_disperse.py

diff --git a/src/diffusion/flow_matching/training_disperse.py b/src/diffusion/flow_matching/training_disperse.py
new file mode 100644
index 0000000..bf8b5d0
--- /dev/null
+++ b/src/diffusion/flow_matching/training_disperse.py
@@ -0,0 +1,108 @@
+import torch
+
+from typing import Callable
+
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+
+
+# Loss weighting functions of (alpha, sigma); `constant` is the default.
+def inverse_sigma(alpha, sigma):
+    return 1 / sigma**2
+
+def snr(alpha, sigma):
+    return alpha / sigma
+
+def minsnr(alpha, sigma, threshold=5):
+    return torch.clip(alpha / sigma, min=threshold)
+
+def maxsnr(alpha, sigma, threshold=5):
+    return torch.clip(alpha / sigma, max=threshold)
+
+def constant(alpha, sigma):
+    return 1
+
+
+def time_shift_fn(t, timeshift=1.0):
+    # Warp the time distribution; timeshift=1.0 is the identity.
+    return t / (t + (1 - t) * timeshift)
+
+
+class DisperseTrainer(BaseTrainer):
+    def __init__(
+        self,
+        scheduler: BaseScheduler,
+        loss_weight_fn: Callable = constant,
+        feat_loss_weight: float = 0.5,
+        lognorm_t=False,
+        timeshift=1.0,
+        align_layer=8,
+        temperature=1.0,
+        *args,
+        **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+        self.lognorm_t = lognorm_t
+        self.scheduler = scheduler
+        self.timeshift = timeshift
+        self.loss_weight_fn = loss_weight_fn
+        self.feat_loss_weight = feat_loss_weight
+        self.align_layer = align_layer
+        self.temperature = temperature
+
+    def _impl_trainstep(self, net, ema_net, solver, x, y, metadata=None):
+        batch_size, c, height, width = x.shape
+        # Sample t from a logit-normal (lognorm_t=True) or uniform distribution.
+        if self.lognorm_t:
+            base_t = torch.randn((batch_size,), device=x.device, dtype=torch.float32).sigmoid()
+        else:
+            base_t = torch.rand((batch_size,), device=x.device, dtype=torch.float32)
+        t = time_shift_fn(base_t, self.timeshift).to(x.dtype)
+        noise = torch.randn_like(x)
+        alpha = self.scheduler.alpha(t)
+        dalpha = self.scheduler.dalpha(t)
+        sigma = self.scheduler.sigma(t)
+        dsigma = self.scheduler.dsigma(t)
+
+        # Flow-matching interpolant and its velocity target.
+        x_t = alpha * x + noise * sigma
+        v_t = dalpha * x + dsigma * noise
+
+        # Capture the features of the alignment block with a forward hook.
+        src_feature = []
+        def forward_hook(net, input, output):
+            feature = output
+            if isinstance(feature, tuple):
+                feature = feature[0]  # mmdit blocks return a tuple
+            src_feature.append(feature)
+
+        if getattr(net, "encoder", None) is not None:
+            handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
+        else:
+            handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
+
+        out = net(x_t, t, y)
+        handle.remove()
+
+        # Disperse loss: log of the mean pairwise Gaussian potential between
+        # pooled features, pushing representations of different samples apart.
+        disperse_distance = 0.0
+        for sf in src_feature:
+            sf = torch.mean(sf, dim=1, keepdim=False)  # pool over the token dimension
+            distance = (sf[None, :, :] - sf[:, None, :])**2
+            distance = distance.sum(dim=-1)  # pairwise squared Euclidean distances
+            sf_disperse_loss = torch.exp(-distance / self.temperature)
+            # Mask out self-pairs on the diagonal; 1e-6 keeps the log finite.
+            mask = 1 - torch.eye(batch_size, device=distance.device, dtype=distance.dtype)
+            disperse_distance += (sf_disperse_loss * mask).sum() / mask.numel() + 1e-6
+        disperse_loss = disperse_distance.log()
+
+        weight = self.loss_weight_fn(alpha, sigma)
+        fm_loss = weight * (out - v_t)**2
+
+        out = dict(
+            fm_loss=fm_loss.mean(),
+            disperse_loss=disperse_loss,
+            loss=fm_loss.mean() + self.feat_loss_weight * disperse_loss,
+        )
+        return out
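
Note: below is a quick, self-contained sanity check of the disperse term in
_impl_trainstep; a minimal sketch assuming only PyTorch. The names
`dispersive_loss` and `feats` are hypothetical, introduced for illustration
and not part of the patch; `feats` stands in for the hooked block output of
shape (batch, tokens, dim).

    import torch

    def dispersive_loss(feats, temperature=1.0):
        # log-mean of exp(-||f_i - f_j||^2 / temperature) over off-diagonal
        # pairs, mirroring the loop body in _impl_trainstep for one feature.
        b = feats.shape[0]
        f = feats.mean(dim=1)                                    # pool tokens: (B, D)
        dist = ((f[None, :, :] - f[:, None, :]) ** 2).sum(-1)    # (B, B) squared distances
        potential = torch.exp(-dist / temperature)
        mask = 1 - torch.eye(b, device=f.device, dtype=f.dtype)  # drop self-pairs
        return ((potential * mask).sum() / mask.numel() + 1e-6).log()

    feats = torch.randn(8, 16, 64)        # 8 samples, 16 tokens, 64-dim features
    print(dispersive_loss(feats).item())  # grows more negative as features disperse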