disperse loss

This commit is contained in:
wangshuai6
2025-06-14 12:14:01 +08:00
parent 598f7b40a2
commit 8038c16bee

View File

@@ -79,10 +79,17 @@ class DisperseTrainer(BaseTrainer):
out = net(x_t, t, y) out = net(x_t, t, y)
handle.remove() handle.remove()
disperse_loss = 0.0 disperse_loss = 0.0
world_size = torch.distributed.get_world_size()
for sf in src_feature: for sf in src_feature:
sf = sf.view(batch_size, -1) gathered_sf = [torch.zeros_like(sf) for _ in range(world_size)]
distance = torch.nn.functional.pdist(sf, p=2)**2 torch.distributed.all_gather(gathered_sf, sf)
sf_disperse_distance = torch.exp(-distance/self.temperature) + 1e-5 gathered_sf = torch.cat(gathered_sf, dim=0)
sf = gathered_sf.view(batch_size * world_size, -1)
sf = sf.view(batch_size * world_size, -1)
# normalize sf
sf = sf / torch.norm(sf, dim=1, keepdim=True)
distance = torch.nn.functional.pdist(sf, p=2) ** 2
sf_disperse_distance = torch.exp(-distance / self.temperature) + 1e-5
disperse_loss += sf_disperse_distance.mean().log() disperse_loss += sf_disperse_distance.mean().log()