diff --git a/src/diffusion/flow_matching/training_disperse.py b/src/diffusion/flow_matching/training_disperse.py index a40c4ae..03a8b6d 100644 --- a/src/diffusion/flow_matching/training_disperse.py +++ b/src/diffusion/flow_matching/training_disperse.py @@ -79,10 +79,17 @@ class DisperseTrainer(BaseTrainer): out = net(x_t, t, y) handle.remove() disperse_loss = 0.0 + world_size = torch.distributed.get_world_size() for sf in src_feature: - sf = sf.view(batch_size, -1) - distance = torch.nn.functional.pdist(sf, p=2)**2 - sf_disperse_distance = torch.exp(-distance/self.temperature) + 1e-5 + gathered_sf = [torch.zeros_like(sf) for _ in range(world_size)] + torch.distributed.all_gather(gathered_sf, sf) + gathered_sf = torch.cat(gathered_sf, dim=0) + sf = gathered_sf.view(batch_size * world_size, -1) + sf = sf.view(batch_size * world_size, -1) + # normalize sf + sf = sf / torch.norm(sf, dim=1, keepdim=True) + distance = torch.nn.functional.pdist(sf, p=2) ** 2 + sf_disperse_distance = torch.exp(-distance / self.temperature) + 1e-5 disperse_loss += sf_disperse_distance.mean().log()