disperse loss
This commit is contained in:
@@ -79,10 +79,17 @@ class DisperseTrainer(BaseTrainer):
|
|||||||
out = net(x_t, t, y)
|
out = net(x_t, t, y)
|
||||||
handle.remove()
|
handle.remove()
|
||||||
disperse_loss = 0.0
|
disperse_loss = 0.0
|
||||||
|
world_size = torch.distributed.get_world_size()
|
||||||
for sf in src_feature:
|
for sf in src_feature:
|
||||||
sf = sf.view(batch_size, -1)
|
gathered_sf = [torch.zeros_like(sf) for _ in range(world_size)]
|
||||||
distance = torch.nn.functional.pdist(sf, p=2)**2
|
torch.distributed.all_gather(gathered_sf, sf)
|
||||||
sf_disperse_distance = torch.exp(-distance/self.temperature) + 1e-5
|
gathered_sf = torch.cat(gathered_sf, dim=0)
|
||||||
|
sf = gathered_sf.view(batch_size * world_size, -1)
|
||||||
|
sf = sf.view(batch_size * world_size, -1)
|
||||||
|
# normalize sf
|
||||||
|
sf = sf / torch.norm(sf, dim=1, keepdim=True)
|
||||||
|
distance = torch.nn.functional.pdist(sf, p=2) ** 2
|
||||||
|
sf_disperse_distance = torch.exp(-distance / self.temperature) + 1e-5
|
||||||
disperse_loss += sf_disperse_distance.mean().log()
|
disperse_loss += sf_disperse_distance.mean().log()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user