Update improved_dit.py

Author: wang shuai
Date: 2025-07-03 20:21:04 +08:00
Committed by: GitHub
Parent: 8038c16bee
Commit: 2c16b8f423

@@ -251,11 +251,11 @@ class DiT(nn.Module):
         self.initialize_weights()
         self.precompute_pos = dict()
-    def fetch_pos(self, height, width, device, dtype):
+    def fetch_pos(self, height, width, device):
         if (height, width) in self.precompute_pos:
-            return self.precompute_pos[(height, width)].to(device, dtype)
+            return self.precompute_pos[(height, width)].to(device)
         else:
-            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
+            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
             self.precompute_pos[(height, width)] = pos
             return pos
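
A minimal standalone sketch of the caching behavior after this hunk; only the name precompute_freqs_cis_2d comes from the diff, the placeholder helper and PosCache class below are illustrative. With the dtype argument removed, a cached entry is only moved to the requested device and keeps whatever dtype the helper originally produced.

import torch

def precompute_freqs_cis_2d(dim, height, width):
    # Placeholder for the real helper in improved_dit.py: assumed to return
    # one frequency vector per grid position, shape (height * width, dim).
    return torch.randn(height * width, dim)

class PosCache:
    """Illustrative stand-in for the fetch_pos logic after this commit."""
    def __init__(self, head_dim):
        self.head_dim = head_dim
        self.precompute_pos = dict()

    def fetch_pos(self, height, width, device):
        # Cache hit: align the device only; the stored dtype is preserved.
        if (height, width) in self.precompute_pos:
            return self.precompute_pos[(height, width)].to(device)
        pos = precompute_freqs_cis_2d(self.head_dim, height, width).to(device)
        self.precompute_pos[(height, width)] = pos
        return pos

cache = PosCache(head_dim=64)
pos = cache.fetch_pos(16, 16, torch.device("cpu"))
print(pos.shape, pos.dtype)  # torch.Size([256, 64]) torch.float32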
@@ -289,7 +289,7 @@ class DiT(nn.Module):
         B, _, H, W = x.shape
         x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
         x = self.x_embedder(x)
-        pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
+        pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
         B, L, C = x.shape
         t = self.t_embedder(t.view(-1)).view(B, -1, C)
         y = self.y_embedder(y).view(B, 1, C)
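
For context on the forward-pass lines in this hunk, a small sketch of what the unfold-based patchify produces; the latent shape and patch_size=2 are assumptions for illustration, not values from the file.

import torch
import torch.nn.functional as F

B, C_in, H, W = 2, 4, 32, 32         # assumed latent shape (B, C, H, W)
patch_size = 2                        # assumed patch size

x = torch.randn(B, C_in, H, W)
# unfold -> (B, C_in * patch_size**2, L); transpose -> (B, L, C_in * patch_size**2)
tokens = F.unfold(x, kernel_size=patch_size, stride=patch_size).transpose(1, 2)

L = (H // patch_size) * (W // patch_size)
print(tokens.shape)                   # torch.Size([2, 256, 16]), i.e. (B, L, C_in * p * p)

# The token grid size is what fetch_pos is keyed on, as in the hunk above:
# pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)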