Update improved_dit.py
@@ -251,11 +251,11 @@ class DiT(nn.Module):
         self.initialize_weights()
         self.precompute_pos = dict()

-    def fetch_pos(self, height, width, device, dtype):
+    def fetch_pos(self, height, width, device):
         if (height, width) in self.precompute_pos:
-            return self.precompute_pos[(height, width)].to(device, dtype)
+            return self.precompute_pos[(height, width)].to(device)
         else:
-            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
+            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
             self.precompute_pos[(height, width)] = pos
             return pos

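For context on the hunk above: fetch_pos memoizes the 2D rotary-frequency table per (height, width) token grid, so it is computed once per resolution and only moved to the target device on later calls. Below is a minimal sketch of that caching pattern; PosCache and precompute_freqs_cis_2d_stub are stand-ins invented for illustration, since the real precompute_freqs_cis_2d and the surrounding DiT module live elsewhere in improved_dit.py.

import torch

def precompute_freqs_cis_2d_stub(dim, height, width):
    # Stand-in for the repo's precompute_freqs_cis_2d: one frequency vector
    # per grid position, kept in float32 here for illustration.
    return torch.randn(height * width, dim, dtype=torch.float32)

class PosCache:
    def __init__(self, hidden_size, num_groups):
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.precompute_pos = dict()  # (height, width) -> cached frequency table

    def fetch_pos(self, height, width, device):
        # Serve the cached tensor when this grid size has been seen before.
        if (height, width) in self.precompute_pos:
            return self.precompute_pos[(height, width)].to(device)
        # Otherwise compute once, cache, and return.
        pos = precompute_freqs_cis_2d_stub(self.hidden_size // self.num_groups, height, width).to(device)
        self.precompute_pos[(height, width)] = pos
        return pos

cache = PosCache(hidden_size=1024, num_groups=16)
pos = cache.fetch_pos(16, 16, device="cpu")  # computed and stored
pos = cache.fetch_pos(16, 16, device="cpu")  # served from the cache

The cache key is the token-grid size, so each input resolution gets its own entry and is only recomputed the first time it is seen.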
@@ -289,7 +289,7 @@ class DiT(nn.Module):
         B, _, H, W = x.shape
         x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
         x = self.x_embedder(x)
-        pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
+        pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
         B, L, C = x.shape
         t = self.t_embedder(t.view(-1)).view(B, -1, C)
         y = self.y_embedder(y).view(B, 1, C)
@@ -298,4 +298,4 @@ class DiT(nn.Module):
             x = block(x, condition, pos, masks[i])
         x = self.final_layer(x, condition)
         x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
         return x
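In the forward pass shown above, the image is patchified with torch.nn.functional.unfold and un-patchified with fold; because stride equals kernel_size the patches do not overlap and the two calls are exact inverses. A small, self-contained shape check of that round trip (the sizes below are illustrative values, not the repo's defaults):

import torch
import torch.nn.functional as F

B, C, H, W, p = 2, 4, 32, 32, 2  # batch, channels, height, width, patch size
x = torch.randn(B, C, H, W)

# Patchify: (B, C, H, W) -> (B, L, C*p*p), with L = (H // p) * (W // p) tokens.
tokens = F.unfold(x, kernel_size=p, stride=p).transpose(1, 2)
assert tokens.shape == (B, (H // p) * (W // p), C * p * p)

# Un-patchify: (B, L, C*p*p) -> (B, C, H, W); exact inverse when stride == kernel_size.
x_rec = F.fold(tokens.transpose(1, 2).contiguous(), (H, W), kernel_size=p, stride=p)
assert torch.allclose(x, x_rec)

The same transpose/contiguous pattern appears in the diff: tokens stay in (B, L, C*p*p) layout for the transformer blocks and are folded back to an image-shaped tensor at the end.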