From 8038c16bee9567aa2c46499b3d06c5455ff99562 Mon Sep 17 00:00:00 2001 From: wangshuai6 Date: Sat, 14 Jun 2025 12:14:01 +0800 Subject: [PATCH] disperse loss --- src/diffusion/flow_matching/training_disperse.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/diffusion/flow_matching/training_disperse.py b/src/diffusion/flow_matching/training_disperse.py index a40c4ae..03a8b6d 100644 --- a/src/diffusion/flow_matching/training_disperse.py +++ b/src/diffusion/flow_matching/training_disperse.py @@ -79,10 +79,17 @@ class DisperseTrainer(BaseTrainer): out = net(x_t, t, y) handle.remove() disperse_loss = 0.0 + world_size = torch.distributed.get_world_size() for sf in src_feature: - sf = sf.view(batch_size, -1) - distance = torch.nn.functional.pdist(sf, p=2)**2 - sf_disperse_distance = torch.exp(-distance/self.temperature) + 1e-5 + gathered_sf = [torch.zeros_like(sf) for _ in range(world_size)] + torch.distributed.all_gather(gathered_sf, sf) + gathered_sf = torch.cat(gathered_sf, dim=0) + sf = gathered_sf.view(batch_size * world_size, -1) + sf = sf.view(batch_size * world_size, -1) + # normalize sf + sf = sf / torch.norm(sf, dim=1, keepdim=True) + distance = torch.nn.functional.pdist(sf, p=2) ** 2 + sf_disperse_distance = torch.exp(-distance / self.temperature) + 1e-5 disperse_loss += sf_disperse_distance.mean().log()