diff --git a/src/models/denoiser/improved_dit.py b/src/models/denoiser/improved_dit.py index 99e2f5a..0ca307b 100644 --- a/src/models/denoiser/improved_dit.py +++ b/src/models/denoiser/improved_dit.py @@ -251,11 +251,11 @@ class DiT(nn.Module): self.initialize_weights() self.precompute_pos = dict() - def fetch_pos(self, height, width, device, dtype): + def fetch_pos(self, height, width, device): if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)].to(device, dtype) + return self.precompute_pos[(height, width)].to(device) else: - pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype) + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) self.precompute_pos[(height, width)] = pos return pos @@ -289,7 +289,7 @@ class DiT(nn.Module): B, _, H, W = x.shape x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) x = self.x_embedder(x) - pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype) + pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device) B, L, C = x.shape t = self.t_embedder(t.view(-1)).view(B, -1, C) y = self.y_embedder(y).view(B, 1, C) @@ -298,4 +298,4 @@ class DiT(nn.Module): x = block(x, condition, pos, masks[i]) x = self.final_layer(x, condition) x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) - return x \ No newline at end of file + return x