diff --git a/.idea/DDT.iml b/.idea/DDT.iml
index d0876a7..8b8c395 100644
--- a/.idea/DDT.iml
+++ b/.idea/DDT.iml
@@ -5,4 +5,8 @@
+
+
+
+
\ No newline at end of file
diff --git a/src/diffusion/flow_matching/training_pyramid.py b/src/diffusion/flow_matching/training_pyramid.py
deleted file mode 100644
index be2bd94..0000000
--- a/src/diffusion/flow_matching/training_pyramid.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import torch
-from typing import Callable
-from src.diffusion.base.training import *
-from src.diffusion.base.scheduling import BaseScheduler
-
-def inverse_sigma(alpha, sigma):
- return 1/sigma**2
-def snr(alpha, sigma):
- return alpha/sigma
-def minsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, min=threshold)
-def maxsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, max=threshold)
-def constant(alpha, sigma):
- return 1
-
-class PyramidTrainer(BaseTrainer):
- def __init__(
- self,
- scheduler: BaseScheduler,
- loss_weight_fn:Callable=constant,
- lognorm_t=False,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.lognorm_t = lognorm_t
- self.scheduler = scheduler
- self.loss_weight_fn = loss_weight_fn
-
-
- def _impl_trainstep(self, net, ema_net, raw_images, x, y):
- batch_size = x.shape[0]
- if self.lognorm_t:
- t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
- else:
- t = torch.rand(batch_size).to(x.device, x.dtype)
- noise = torch.randn_like(x)
- alpha = self.scheduler.alpha(t)
- dalpha = self.scheduler.dalpha(t)
- sigma = self.scheduler.sigma(t)
- dsigma = self.scheduler.dsigma(t)
- w = self.scheduler.w(t)
-
- x_t = alpha * x + noise * sigma
- v_t = dalpha * x + dsigma * noise
-
-
- output_pyramid = []
- def feature_hook(module, input, output):
- output_pyramid.extend(output)
- handle = net.decoder.register_forward_hook(feature_hook)
- net(x_t, t, y)
- handle.remove()
-
- loss = 0.0
- out_dict = dict()
-
- cur_v_t = v_t
- for i in range(len(output_pyramid)):
- cur_out = output_pyramid[i]
- loss_i = (cur_v_t - cur_out) ** 2
- loss += loss_i.mean()
- out_dict["loss_{}".format(i)] = loss_i.mean()
- cur_v_t = torch.nn.functional.interpolate(cur_v_t, scale_factor=0.5, mode='bilinear', align_corners=False)
- out_dict["loss"] = loss
- return out_dict
-
diff --git a/src/diffusion/flow_matching/training_repa_mask.py b/src/diffusion/flow_matching/training_repa_mask.py
deleted file mode 100644
index f8c4edb..0000000
--- a/src/diffusion/flow_matching/training_repa_mask.py
+++ /dev/null
@@ -1,152 +0,0 @@
-import torch
-import copy
-import timm
-from torch.nn import Parameter
-
-from src.utils.no_grad import no_grad
-from typing import Callable, Iterator, Tuple
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from torchvision.transforms import Normalize
-from src.diffusion.base.training import *
-from src.diffusion.base.scheduling import BaseScheduler
-
-def inverse_sigma(alpha, sigma):
- return 1/sigma**2
-def snr(alpha, sigma):
- return alpha/sigma
-def minsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, min=threshold)
-def maxsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, max=threshold)
-def constant(alpha, sigma):
- return 1
-
-
-class DINOv2(nn.Module):
- def __init__(self, weight_path:str):
- super(DINOv2, self).__init__()
- self.encoder = torch.hub.load(
- '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
- weight_path,
- source="local",
- skip_validation=True
- )
- self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
- self.encoder.head = torch.nn.Identity()
- self.patch_size = self.encoder.patch_embed.patch_size
- self.precomputed_pos_embed = dict()
-
- def fetch_pos(self, h, w):
- key = (h, w)
- if key in self.precomputed_pos_embed:
- return self.precomputed_pos_embed[key]
- value = timm.layers.pos_embed.resample_abs_pos_embed(
- self.pos_embed.data, [h, w],
- )
- self.precomputed_pos_embed[key] = value
- return value
-
- def forward(self, x):
- b, c, h, w = x.shape
- x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
- x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
- b, c, h, w = x.shape
- patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
- pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
- self.encoder.pos_embed.data = pos_embed_data
- feature = self.encoder.forward_features(x)['x_norm_patchtokens']
- return feature
-
-
-class REPATrainer(BaseTrainer):
- def __init__(
- self,
- scheduler: BaseScheduler,
- loss_weight_fn:Callable=constant,
- feat_loss_weight: float=0.5,
- lognorm_t=False,
- mask_ratio=0.0,
- mask_patch_size=2,
- encoder_weight_path=None,
- align_layer=8,
- proj_denoiser_dim=256,
- proj_hidden_dim=256,
- proj_encoder_dim=256,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.lognorm_t = lognorm_t
- self.scheduler = scheduler
- self.loss_weight_fn = loss_weight_fn
- self.mask_ratio = mask_ratio
- self.mask_patch_size = mask_patch_size
- self.feat_loss_weight = feat_loss_weight
- self.align_layer = align_layer
- self.encoder = DINOv2(encoder_weight_path)
- no_grad(self.encoder)
-
- self.proj = nn.Sequential(
- nn.Sequential(
- nn.Linear(proj_denoiser_dim, proj_hidden_dim),
- nn.SiLU(),
- nn.Linear(proj_hidden_dim, proj_hidden_dim),
- nn.SiLU(),
- nn.Linear(proj_hidden_dim, proj_encoder_dim),
- )
- )
-
- def _impl_trainstep(self, net, ema_net, raw_images, x, y):
- batch_size, c, height, width = x.shape
- if self.lognorm_t:
- base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
- else:
- base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
- t = base_t
-
- noise = torch.randn_like(x)
- alpha = self.scheduler.alpha(t)
- dalpha = self.scheduler.dalpha(t)
- sigma = self.scheduler.sigma(t)
- dsigma = self.scheduler.dsigma(t)
-
- patch_mask = torch.rand((batch_size, 1, height//self.mask_patch_size, width//self.mask_patch_size), device=x.device)
- patch_mask = (patch_mask < self.mask_ratio).float()
- mask = torch.nn.functional.interpolate(patch_mask, size=(height, width), mode='nearest')
- masked_x = x*(1-mask)# + torch.randn_like(x)*(mask)
-
- x_t = alpha*masked_x + sigma*noise
- v_t = dalpha*x + dsigma*noise
-
- src_feature = []
- def forward_hook(net, input, output):
- src_feature.append(output)
- handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
-
- v_t_out, x0_out = net(x_t, t, y)
- src_feature = self.proj(src_feature[0])
- handle.remove()
-
- with torch.no_grad():
- dst_feature = self.encoder(raw_images)
-
- cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
- cos_loss = 1 - cos_sim
-
- weight = self.loss_weight_fn(alpha, sigma)
- fm_loss = (1-mask)*weight*(v_t_out - v_t)**2/(1-mask.mean())
- mask_loss = mask*weight*(x0_out - x)**2/(mask.mean())
- out = dict(
- fm_loss=fm_loss.mean(),
- cos_loss=cos_loss.mean(),
- mask_loss=mask_loss.mean(),
- loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + mask_loss.mean(),
- )
- return out
-
- def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
- self.proj.state_dict(
- destination=destination,
- prefix=prefix + "proj.",
- keep_vars=keep_vars)
-
diff --git a/src/diffusion/stateful_flow_matching/bak/training_adv.py b/src/diffusion/stateful_flow_matching/bak/training_adv.py
deleted file mode 100644
index 4792950..0000000
--- a/src/diffusion/stateful_flow_matching/bak/training_adv.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import torch
-import math
-from typing import Callable
-from src.diffusion.base.training import *
-from src.diffusion.base.scheduling import BaseScheduler
-from src.utils.no_grad import no_grad
-from torchmetrics.image.lpip import _NoTrainLpips
-
-def inverse_sigma(alpha, sigma):
- return 1/sigma**2
-def snr(alpha, sigma):
- return alpha/sigma
-def minsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, min=threshold)
-def maxsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, max=threshold)
-def constant(alpha, sigma):
- return 1
-
-class Discriminator(nn.Module):
- def __init__(self, in_channels, hidden_size):
- super().__init__()
- self.head = nn.Sequential(
- nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
- nn.GroupNorm(num_groups=32, num_channels=hidden_size),
- nn.SiLU(),
- nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
- nn.GroupNorm(num_groups=32, num_channels=hidden_size),
- nn.SiLU(),
- nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4
- nn.GroupNorm(num_groups=32, num_channels=hidden_size),
- nn.SiLU(),
- nn.AdaptiveAvgPool2d(1),
- nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
- )
-
- def forward(self, feature):
- B, L, C = feature.shape
- H = W = int(math.sqrt(L))
- feature = feature.permute(0, 2, 1)
- feature = feature.view(B, C, H, W)
- out = self.head(feature).sigmoid().clamp(0.01, 0.99)
- return out
-
-class AdvTrainer(BaseTrainer):
- def __init__(
- self,
- scheduler: BaseScheduler,
- loss_weight_fn:Callable=constant,
- lognorm_t=False,
- adv_weight=1.0,
- adv_encoder_layer=4,
- adv_in_channels=768,
- adv_hidden_size=256,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.lognorm_t = lognorm_t
- self.scheduler = scheduler
- self.loss_weight_fn = loss_weight_fn
- self.adv_weight = adv_weight
- self.adv_encoder_layer = adv_encoder_layer
-
- self.dis_head = Discriminator(
- in_channels=adv_in_channels,
- hidden_size=adv_hidden_size,
- )
-
-
- def _impl_trainstep(self, net, ema_net, raw_images, x, y):
- batch_size = x.shape[0]
- if self.lognorm_t:
- t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
- else:
- t = torch.rand(batch_size).to(x.device, x.dtype)
- noise = torch.randn_like(x)
- alpha = self.scheduler.alpha(t)
- dalpha = self.scheduler.dalpha(t)
- sigma = self.scheduler.sigma(t)
- dsigma = self.scheduler.dsigma(t)
- w = self.scheduler.w(t)
-
- x_t = alpha * x + noise * sigma
- v_t = dalpha * x + dsigma * noise
-
- adv_feature = []
- def forward_hook(net, input, output):
- adv_feature.append(output)
- handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook)
-
- out, _ = net(x_t, t, y)
- weight = self.loss_weight_fn(alpha, sigma)
- loss = weight*(out - v_t)**2
-
- pred_x0 = (x_t + out * sigma)
- pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma
- real_feature = adv_feature.pop()
- net(pred_xt, t, y, classify_layer=self.adv_encoder_layer)
- fake_feature = adv_feature.pop()
- handle.remove()
-
-
- real_score_gan = self.dis_head(real_feature.detach())
- fake_score_gan = self.dis_head(fake_feature.detach())
- fake_score = self.dis_head(fake_feature)
-
- loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
- acc_real = (real_score_gan > 0.5).float()
- acc_fake = (fake_score_gan < 0.5).float()
- loss_adv = -torch.log(fake_score)
- loss_adv_hack = torch.log(fake_score_gan)
-
- out = dict(
- adv_loss=loss_adv.mean(),
- gan_loss=loss_gan.mean(),
- fm_loss=loss.mean(),
- loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
- acc_real=acc_real.mean(),
- acc_fake=acc_fake.mean(),
- )
- return out
diff --git a/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py b/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py
deleted file mode 100644
index 2843c04..0000000
--- a/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import torch
-import math
-from typing import Callable
-from src.diffusion.base.training import *
-from src.diffusion.base.scheduling import BaseScheduler
-from src.utils.no_grad import no_grad
-from torchmetrics.image.lpip import _NoTrainLpips
-
-def inverse_sigma(alpha, sigma):
- return 1/sigma**2
-def snr(alpha, sigma):
- return alpha/sigma
-def minsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, min=threshold)
-def maxsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, max=threshold)
-def constant(alpha, sigma):
- return 1
-
-class Discriminator(nn.Module):
- def __init__(self, in_channels, hidden_size):
- super().__init__()
- self.head = nn.Sequential(
- nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
- nn.GroupNorm(num_groups=32, num_channels=hidden_size),
- nn.SiLU(),
- nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
- nn.GroupNorm(num_groups=32, num_channels=hidden_size),
- nn.SiLU(),
- nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4
- nn.GroupNorm(num_groups=32, num_channels=hidden_size),
- nn.SiLU(),
- nn.AdaptiveAvgPool2d(1),
- nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
- )
-
- def forward(self, feature):
- B, L, C = feature.shape
- H = W = int(math.sqrt(L))
- feature = feature.permute(0, 2, 1)
- feature = feature.view(B, C, H, W)
- out = self.head(feature).sigmoid().clamp(0.01, 0.99)
- return out
-
-class AdvTrainer(BaseTrainer):
- def __init__(
- self,
- scheduler: BaseScheduler,
- loss_weight_fn:Callable=constant,
- lognorm_t=False,
- adv_weight=1.0,
- lpips_weight=1.0,
- adv_encoder_layer=4,
- adv_in_channels=768,
- adv_hidden_size=256,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.lognorm_t = lognorm_t
- self.scheduler = scheduler
- self.loss_weight_fn = loss_weight_fn
- self.adv_weight = adv_weight
- self.lpips_weight = lpips_weight
- self.adv_encoder_layer = adv_encoder_layer
-
- self.dis_head = Discriminator(
- in_channels=adv_in_channels,
- hidden_size=adv_hidden_size,
- )
-
-
- def _impl_trainstep(self, net, ema_net, raw_images, x, y):
- batch_size = x.shape[0]
- if self.lognorm_t:
- t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
- else:
- t = torch.rand(batch_size).to(x.device, x.dtype)
- clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype)
-
- noise = torch.randn_like(x)
- alpha = self.scheduler.alpha(t)
- dalpha = self.scheduler.dalpha(t)
- sigma = self.scheduler.sigma(t)
- dsigma = self.scheduler.dsigma(t)
- w = self.scheduler.w(t)
-
- x_t = alpha * x + noise * sigma
- v_t = dalpha * x + dsigma * noise
-
- out, _ = net(x_t, t, y)
- pred_x0 = (x_t + out * sigma)
-
- weight = self.loss_weight_fn(alpha, sigma)
- loss = weight*(out - v_t)**2
- with torch.no_grad():
- _, real_features = net(x, clean_t, y, classify_layer=self.adv_encoder_layer)
- _, fake_features = net(pred_x0, clean_t, y, classify_layer=self.adv_encoder_layer)
-
- real_score_gan = self.dis_head(real_features[-1].detach())
- fake_score_gan = self.dis_head(fake_features[-1].detach())
- fake_score = self.dis_head(fake_features[-1])
-
- loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
- acc_real = (real_score_gan > 0.5).float()
- acc_fake = (fake_score_gan < 0.5).float()
- loss_adv = -torch.log(fake_score)
- loss_adv_hack = torch.log(fake_score_gan)
-
- lpips_loss = []
- for r, f in zip(real_features, fake_features):
- r = torch.nn.functional.normalize(r, dim=-1)
- f = torch.nn.functional.normalize(f, dim=-1)
- lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
- lpips_loss = sum(lpips_loss)
-
-
- out = dict(
- adv_loss=loss_adv.mean(),
- gan_loss=loss_gan.mean(),
- lpips_loss=lpips_loss.mean(),
- fm_loss=loss.mean(),
- loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean() + self.lpips_weight*lpips_loss.mean(),
- acc_real=acc_real.mean(),
- acc_fake=acc_fake.mean(),
- )
- return out
diff --git a/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py b/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py
deleted file mode 100644
index 849ee4b..0000000
--- a/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py
+++ /dev/null
@@ -1,159 +0,0 @@
-import random
-
-import torch
-import copy
-import timm
-from torch.nn import Parameter
-
-from src.utils.no_grad import no_grad
-from typing import Callable, Iterator, Tuple
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from torchvision.transforms import Normalize
-from src.diffusion.base.training import *
-from src.diffusion.base.scheduling import BaseScheduler
-
-def inverse_sigma(alpha, sigma):
- return 1/sigma**2
-def snr(alpha, sigma):
- return alpha/sigma
-def minsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, min=threshold)
-def maxsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, max=threshold)
-def constant(alpha, sigma):
- return 1
-
-
-class DINOv2(nn.Module):
- def __init__(self, weight_path:str):
- super(DINOv2, self).__init__()
- self.encoder = torch.hub.load(
- '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
- weight_path,
- source="local",
- skip_validation=True
- )
- self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
- self.encoder.head = torch.nn.Identity()
- self.patch_size = self.encoder.patch_embed.patch_size
- self.precomputed_pos_embed = dict()
-
- def fetch_pos(self, h, w):
- key = (h, w)
- if key in self.precomputed_pos_embed:
- return self.precomputed_pos_embed[key]
- value = timm.layers.pos_embed.resample_abs_pos_embed(
- self.pos_embed.data, [h, w],
- )
- self.precomputed_pos_embed[key] = value
- return value
-
- def forward(self, x):
- b, c, h, w = x.shape
- x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
- x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
- b, c, h, w = x.shape
- patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
- pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
- self.encoder.pos_embed.data = pos_embed_data
- feature = self.encoder.forward_features(x)['x_norm_patchtokens']
- return feature
-
-
-class MaskREPATrainer(BaseTrainer):
- def __init__(
- self,
- scheduler: BaseScheduler,
- loss_weight_fn:Callable=constant,
- feat_loss_weight: float=0.5,
- lognorm_t=False,
- encoder_weight_path=None,
- mask_groups=4,
- align_layer=8,
- proj_denoiser_dim=256,
- proj_hidden_dim=256,
- proj_encoder_dim=256,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.lognorm_t = lognorm_t
- self.scheduler = scheduler
- self.loss_weight_fn = loss_weight_fn
- self.feat_loss_weight = feat_loss_weight
- self.align_layer = align_layer
- self.mask_groups = mask_groups
- self.encoder = DINOv2(encoder_weight_path)
- no_grad(self.encoder)
-
- self.proj = nn.Sequential(
- nn.Sequential(
- nn.Linear(proj_denoiser_dim, proj_hidden_dim),
- nn.SiLU(),
- nn.Linear(proj_hidden_dim, proj_hidden_dim),
- nn.SiLU(),
- nn.Linear(proj_hidden_dim, proj_encoder_dim),
- )
- )
-
- def fetch_mask(self, length=256, groups=4, device=torch.device('cuda')):
- mask = torch.zeros(1, length, length, device=device, dtype=torch.bool)
- random_seq = torch.randperm(length, device=device)
- for i in range(groups):
- group_start = (length+groups-1)//groups*i
- group_end = (length+groups-1)//groups*(i+1)
- group_random_seq = random_seq[group_start:group_end]
- y, x = torch.meshgrid(group_random_seq, group_random_seq)
- mask[:, y, x] = True
- return mask
-
-
- def _impl_trainstep(self, net, ema_net, raw_images, x, y):
- batch_size, c, height, width = x.shape
- if self.lognorm_t:
- base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
- else:
- base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
- t = base_t
-
- noise = torch.randn_like(x)
- alpha = self.scheduler.alpha(t)
- dalpha = self.scheduler.dalpha(t)
- sigma = self.scheduler.sigma(t)
- dsigma = self.scheduler.dsigma(t)
-
- x_t = alpha * x + noise * sigma
- v_t = dalpha * x + dsigma * noise
-
- src_feature = []
- def forward_hook(net, input, output):
- src_feature.append(output)
- handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
- mask_groups = random.randint(1, self.mask_groups)
- mask = self.fetch_mask(length=256, groups=mask_groups, device=x.device)
- out, _ = net(x_t, t, y, mask=mask)
- src_feature = self.proj(src_feature[0])
- handle.remove()
-
- with torch.no_grad():
- dst_feature = self.encoder(raw_images)
-
- cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
- cos_loss = 1 - cos_sim
-
- weight = self.loss_weight_fn(alpha, sigma)
- fm_loss = weight*(out - v_t)**2
-
- out = dict(
- fm_loss=fm_loss.mean(),
- cos_loss=cos_loss.mean(),
- loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
- )
- return out
-
- def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
- self.proj.state_dict(
- destination=destination,
- prefix=prefix + "proj.",
- keep_vars=keep_vars)
-
diff --git a/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py b/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py
deleted file mode 100644
index 229680c..0000000
--- a/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py
+++ /dev/null
@@ -1,179 +0,0 @@
-import torch
-import math
-from typing import Callable
-from src.diffusion.base.training import *
-from src.diffusion.base.scheduling import BaseScheduler
-from src.utils.no_grad import no_grad
-from torchmetrics.image.lpip import _NoTrainLpips
-
-def inverse_sigma(alpha, sigma):
- return 1/sigma**2
-def snr(alpha, sigma):
- return alpha/sigma
-def minsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, min=threshold)
-def maxsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, max=threshold)
-def constant(alpha, sigma):
- return 1
-
-
-class TimestepEmbedder(nn.Module):
- """
- Embeds scalar timesteps into vector representations.
- """
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10000):
- # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
- ).to(device=t.device)
- args = t[:, None].float() * freqs[None]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t, mul=1000):
- t_freq = self.timestep_embedding(t * mul, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-
-class BatchNormWithTimeEmbedding(nn.Module):
- def __init__(self, num_features):
- super().__init__()
- # self.bn = nn.BatchNorm2d(num_features, affine=False)
- self.bn = nn.GroupNorm(16, num_features, affine=False)
- # self.bn = nn.SyncBatchNorm(num_features, affine=False)
- self.embedder = TimestepEmbedder(num_features * 2)
- # nn.init.zeros_(self.embedder.mlp[-1].weight)
- nn.init.trunc_normal_(self.embedder.mlp[-1].weight, std=0.01)
- nn.init.zeros_(self.embedder.mlp[-1].bias)
-
- def forward(self, x, t):
- embed = self.embedder(t)
- embed = embed[:, :, None, None]
- gamma, beta = embed.chunk(2, dim=1)
- gamma = 1.0 + gamma
- normed = self.bn(x)
- out = normed * gamma + beta
- return out
-
-class DisBlock(nn.Module):
- def __init__(self, in_channels, hidden_size):
- super().__init__()
- self.conv = nn.Conv2d(
- kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=4, padding=0
- )
- self.norm = BatchNormWithTimeEmbedding(hidden_size)
- self.act = nn.SiLU()
- def forward(self, x, t):
- x = self.conv(x)
- x = self.norm(x, t)
- x = self.act(x)
- return x
-
-
-class Discriminator(nn.Module):
- def __init__(self, num_blocks, in_channels, hidden_size):
- super().__init__()
- self.blocks = nn.ModuleList()
- for i in range(num_blocks):
- self.blocks.append(
- DisBlock(
- in_channels=in_channels,
- hidden_size=hidden_size,
- )
- )
- in_channels = hidden_size
- self.classifier = nn.Conv2d(
- kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=1
- )
- def forward(self, feature, t):
- B, C, H, W = feature.shape
- for block in self.blocks:
- feature = block(feature, t)
- out = self.classifier(feature).view(B, -1)
- out = out.sigmoid().clamp(0.01, 0.99)
- return out
-
-class AdvTrainer(BaseTrainer):
- def __init__(
- self,
- scheduler: BaseScheduler,
- loss_weight_fn:Callable=constant,
- lognorm_t=False,
- adv_weight=1.0,
- adv_blocks=3,
- adv_in_channels=3,
- adv_hidden_size=256,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.lognorm_t = lognorm_t
- self.scheduler = scheduler
- self.loss_weight_fn = loss_weight_fn
- self.adv_weight = adv_weight
-
- self.discriminator = Discriminator(
- num_blocks=adv_blocks,
- in_channels=adv_in_channels*2,
- hidden_size=adv_hidden_size,
- )
-
-
- def _impl_trainstep(self, net, ema_net, raw_images, x, y):
- batch_size = x.shape[0]
- if self.lognorm_t:
- t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
- else:
- t = torch.rand(batch_size).to(x.device, x.dtype)
- noise = torch.randn_like(x)
- alpha = self.scheduler.alpha(t)
- dalpha = self.scheduler.dalpha(t)
- sigma = self.scheduler.sigma(t)
- dsigma = self.scheduler.dsigma(t)
- w = self.scheduler.w(t)
-
- x_t = alpha * x + noise * sigma
- v_t = dalpha * x + dsigma * noise
-
- out, _ = net(x_t, t, y)
- pred_x0 = x_t + sigma * out
- weight = self.loss_weight_fn(alpha, sigma)
- loss = weight*(out - v_t)**2
-
- real_feature = torch.cat([x_t, x], dim=1)
- fake_feature = torch.cat([x_t, pred_x0], dim=1)
-
- real_score_gan = self.discriminator(real_feature.detach(), t)
- fake_score_gan = self.discriminator(fake_feature.detach(), t)
- fake_score = self.discriminator(fake_feature, t)
-
- loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
- acc_real = (real_score_gan > 0.5).float()
- acc_fake = (fake_score_gan < 0.5).float()
- loss_adv = -torch.log(fake_score)
- loss_adv_hack = torch.log(fake_score_gan)
-
- out = dict(
- adv_loss=loss_adv.mean(),
- gan_loss=loss_gan.mean(),
- fm_loss=loss.mean(),
- loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
- acc_real=acc_real.mean(),
- acc_fake=acc_fake.mean(),
- )
- return out
diff --git a/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py b/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py
deleted file mode 100644
index e84e81f..0000000
--- a/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py
+++ /dev/null
@@ -1,154 +0,0 @@
-import torch
-import copy
-import timm
-from torch.nn import Parameter
-
-from src.utils.no_grad import no_grad
-from typing import Callable, Iterator, Tuple
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from torchvision.transforms import Normalize
-from src.diffusion.base.training import *
-from src.diffusion.base.scheduling import BaseScheduler
-
-def inverse_sigma(alpha, sigma):
- return 1/sigma**2
-def snr(alpha, sigma):
- return alpha/sigma
-def minsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, min=threshold)
-def maxsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, max=threshold)
-def constant(alpha, sigma):
- return 1
-
-
-class DINOv2(nn.Module):
- def __init__(self, weight_path:str):
- super(DINOv2, self).__init__()
- self.encoder = torch.hub.load(
- '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
- weight_path,
- source="local",
- skip_validation=True
- )
- self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
- self.encoder.head = torch.nn.Identity()
- self.patch_size = self.encoder.patch_embed.patch_size
- self.precomputed_pos_embed = dict()
-
- def fetch_pos(self, h, w):
- key = (h, w)
- if key in self.precomputed_pos_embed:
- return self.precomputed_pos_embed[key]
- value = timm.layers.pos_embed.resample_abs_pos_embed(
- self.pos_embed.data, [h, w],
- )
- self.precomputed_pos_embed[key] = value
- return value
-
- def forward(self, x):
- b, c, h, w = x.shape
- x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
- x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
- b, c, h, w = x.shape
- patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
- pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
- self.encoder.pos_embed.data = pos_embed_data
- feature = self.encoder.forward_features(x)['x_norm_patchtokens']
- return feature
-
-
-class REPAJiTTrainer(BaseTrainer):
- def __init__(
- self,
- scheduler: BaseScheduler,
- loss_weight_fn:Callable=constant,
- feat_loss_weight: float=0.5,
- lognorm_t=False,
- jit_deltas=0.01,
- encoder_weight_path=None,
- align_layer=8,
- proj_denoiser_dim=256,
- proj_hidden_dim=256,
- proj_encoder_dim=256,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.lognorm_t = lognorm_t
- self.scheduler = scheduler
- self.loss_weight_fn = loss_weight_fn
- self.feat_loss_weight = feat_loss_weight
- self.align_layer = align_layer
- self.jit_deltas = jit_deltas
- self.encoder = DINOv2(encoder_weight_path)
- no_grad(self.encoder)
-
- self.proj = nn.Sequential(
- nn.Sequential(
- nn.Linear(proj_denoiser_dim, proj_hidden_dim),
- nn.SiLU(),
- nn.Linear(proj_hidden_dim, proj_hidden_dim),
- nn.SiLU(),
- nn.Linear(proj_hidden_dim, proj_encoder_dim),
- )
- )
-
- def _impl_trainstep(self, net, ema_net, raw_images, x, y):
- batch_size, c, height, width = x.shape
- if self.lognorm_t:
- base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
- else:
- base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
- t = base_t
-
- noise = torch.randn_like(x)
-
- alpha = self.scheduler.alpha(t)
- dalpha = self.scheduler.dalpha(t)
- sigma = self.scheduler.sigma(t)
- dsigma = self.scheduler.dsigma(t)
- x_t = alpha * x + noise * sigma
- v_t = dalpha * x + dsigma * noise
-
- t2 = base_t + (torch.rand_like(base_t)-0.5) * self.jit_deltas
- t2 = torch.clip(t2, 0, 1)
- alpha = self.scheduler.alpha(t2)
- dalpha = self.scheduler.dalpha(t2)
- sigma = self.scheduler.sigma(t2)
- dsigma = self.scheduler.dsigma(t2)
- x_t2 = alpha * x + noise * sigma
- v_t2 = dalpha * x + dsigma * noise
-
- src_feature = []
- def forward_hook(net, input, output):
- src_feature.append(output)
- handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
-
- _, s = net(x_t, t, y, only_s=True)
- out, _ = net(x_t2, t2, y, s)
- src_feature = self.proj(src_feature[0])
- handle.remove()
-
- with torch.no_grad():
- dst_feature = self.encoder(raw_images)
-
- cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
- cos_loss = 1 - cos_sim
-
- weight = self.loss_weight_fn(alpha, sigma)
- fm_loss = weight*(out - v_t2)**2
-
- out = dict(
- fm_loss=fm_loss.mean(),
- cos_loss=cos_loss.mean(),
- loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
- )
- return out
-
- def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
- self.proj.state_dict(
- destination=destination,
- prefix=prefix + "proj.",
- keep_vars=keep_vars)
-
diff --git a/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py b/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py
deleted file mode 100644
index d7e741d..0000000
--- a/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import torch
-import math
-from typing import Callable
-from src.diffusion.base.training import *
-from src.diffusion.base.scheduling import BaseScheduler
-from src.utils.no_grad import no_grad
-from torchmetrics.image.lpip import _NoTrainLpips
-
-def inverse_sigma(alpha, sigma):
- return 1/sigma**2
-def snr(alpha, sigma):
- return alpha/sigma
-def minsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, min=threshold)
-def maxsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, max=threshold)
-def constant(alpha, sigma):
- return 1
-
-
-class SelfConsistentTrainer(BaseTrainer):
- def __init__(
- self,
- scheduler: BaseScheduler,
- loss_weight_fn:Callable=constant,
- lognorm_t=False,
- lpips_weight=1.0,
- lpips_encoder_layer=4,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.lognorm_t = lognorm_t
- self.scheduler = scheduler
- self.loss_weight_fn = loss_weight_fn
- self.lpips_encoder_layer = lpips_encoder_layer
- self.lpips_weight = lpips_weight
-
-
- def _impl_trainstep(self, net, ema_net, raw_images, x, y):
- batch_size = x.shape[0]
- if self.lognorm_t:
- t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
- else:
- t = torch.rand(batch_size).to(x.device, x.dtype)
-
- noise = torch.randn_like(x)
- alpha = self.scheduler.alpha(t)
- dalpha = self.scheduler.dalpha(t)
- sigma = self.scheduler.sigma(t)
- dsigma = self.scheduler.dsigma(t)
- w = self.scheduler.w(t)
-
- x_t = alpha * x + noise * sigma
- v_t = dalpha * x + dsigma * noise
-
- real_features = []
- def forward_hook(net, input, output):
- real_features.append(output)
- handles = []
- for i in range(self.lpips_encoder_layer):
- handle = net.encoder.blocks[i].register_forward_hook(forward_hook)
- handles.append(handle)
-
- out, _ = net(x_t, t, y)
-
- for handle in handles:
- handle.remove()
-
- pred_x0 = (x_t + out * sigma)
- pred_xt = alpha * pred_x0 + noise * sigma
- weight = self.loss_weight_fn(alpha, sigma)
- loss = weight*(out - v_t)**2
-
- _, fake_features = net(pred_xt, t, y, classify_layer=self.lpips_encoder_layer)
-
- lpips_loss = []
- for r, f in zip(real_features, fake_features):
- r = torch.nn.functional.normalize(r, dim=-1)
- f = torch.nn.functional.normalize(f, dim=-1)
- lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
- lpips_loss = sum(lpips_loss)
-
-
- out = dict(
- lpips_loss=lpips_loss.mean(),
- fm_loss=loss.mean(),
- loss=loss.mean() + self.lpips_weight*lpips_loss.mean(),
- )
- return out
diff --git a/src/diffusion/stateful_flow_matching/bak/training_selflpips.py b/src/diffusion/stateful_flow_matching/bak/training_selflpips.py
deleted file mode 100644
index 580775b..0000000
--- a/src/diffusion/stateful_flow_matching/bak/training_selflpips.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import torch
-import math
-from typing import Callable
-from src.diffusion.base.training import *
-from src.diffusion.base.scheduling import BaseScheduler
-from src.utils.no_grad import no_grad
-from torchmetrics.image.lpip import _NoTrainLpips
-
-def inverse_sigma(alpha, sigma):
- return 1/sigma**2
-def snr(alpha, sigma):
- return alpha/sigma
-def minsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, min=threshold)
-def maxsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, max=threshold)
-def constant(alpha, sigma):
- return 1
-
-
-class SelfLPIPSTrainer(BaseTrainer):
- def __init__(
- self,
- scheduler: BaseScheduler,
- loss_weight_fn:Callable=constant,
- lognorm_t=False,
- lpips_weight=1.0,
- lpips_encoder_layer=4,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.lognorm_t = lognorm_t
- self.scheduler = scheduler
- self.loss_weight_fn = loss_weight_fn
- self.lpips_encoder_layer = lpips_encoder_layer
- self.lpips_weight = lpips_weight
-
-
- def _impl_trainstep(self, net, ema_net, raw_images, x, y):
- batch_size = x.shape[0]
- if self.lognorm_t:
- t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
- else:
- t = torch.rand(batch_size).to(x.device, x.dtype)
- clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype)
-
- noise = torch.randn_like(x)
- alpha = self.scheduler.alpha(t)
- dalpha = self.scheduler.dalpha(t)
- sigma = self.scheduler.sigma(t)
- dsigma = self.scheduler.dsigma(t)
- w = self.scheduler.w(t)
-
- x_t = alpha * x + noise * sigma
- v_t = dalpha * x + dsigma * noise
-
- out, _ = net(x_t, t, y)
- pred_x0 = (x_t + out * sigma)
- pred_xt = alpha * pred_x0 + noise * sigma
- weight = self.loss_weight_fn(alpha, sigma)
- loss = weight*(out - v_t)**2
- with torch.no_grad():
- _, real_features = net(x, clean_t, y, classify_layer=self.lpips_encoder_layer)
- _, fake_features = net(pred_x0, clean_t, y, classify_layer=self.lpips_encoder_layer)
-
-
- lpips_loss = []
- for r, f in zip(real_features, fake_features):
- r = torch.nn.functional.normalize(r, dim=-1)
- f = torch.nn.functional.normalize(f, dim=-1)
- lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
- lpips_loss = sum(lpips_loss)
-
-
- out = dict(
- lpips_loss=lpips_loss.mean(),
- fm_loss=loss.mean(),
- loss=loss.mean() + self.lpips_weight*lpips_loss.mean(),
- )
- return out
diff --git a/src/diffusion/stateful_flow_matching/cm_sampling.py b/src/diffusion/stateful_flow_matching/cm_sampling.py
deleted file mode 100644
index 5254db5..0000000
--- a/src/diffusion/stateful_flow_matching/cm_sampling.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import torch
-
-from src.diffusion.base.guidance import *
-from src.diffusion.base.scheduling import *
-from src.diffusion.base.sampling import *
-
-from typing import Callable
-
-
-def shift_respace_fn(t, shift=3.0):
- return t / (t + (1 - t) * shift)
-
-def ode_step_fn(x, v, dt, s, w):
- return x + v * dt
-
-
-import logging
-logger = logging.getLogger(__name__)
-
-class CMSampler(BaseSampler):
- def __init__(
- self,
- w_scheduler: BaseScheduler = None,
- timeshift=1.0,
- guidance_interval_min: float = 0.0,
- guidance_interval_max: float = 1.0,
- state_refresh_rate=1,
- last_step=None,
- step_fn=None,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.last_step = last_step
- self.timeshift = timeshift
- self.state_refresh_rate = state_refresh_rate
- self.guidance_interval_min = guidance_interval_min
- self.guidance_interval_max = guidance_interval_max
-
- if self.last_step is None or self.num_steps == 1:
- self.last_step = 1.0 / self.num_steps
-
- timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
- timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
- self.timesteps = shift_respace_fn(timesteps, self.timeshift)
-
- assert self.last_step > 0.0
- assert self.scheduler is not None
-
-
- def _impl_sampling(self, net, noise, condition, uncondition):
- """
- sampling process of Euler sampler
- -
- """
- batch_size = noise.shape[0]
- steps = self.timesteps.to(noise.device)
- cfg_condition = torch.cat([uncondition, condition], dim=0)
- x = noise
- state = None
- for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
- cfg_t = t_cur.repeat(batch_size*2)
- cfg_x = torch.cat([x, x], dim=0)
- if i % self.state_refresh_rate == 0:
- state = None
- out, state = net(cfg_x, cfg_t, cfg_condition, state)
- if t_cur > self.guidance_interval_min and t_cur < self.guidance_interval_max:
- out = self.guidance_fn(out, self.guidance)
- else:
- out = self.guidance_fn(out, 1.0)
- v = out
-
- x0 = x + v * (1-t_cur)
- alpha_next = self.scheduler.alpha(t_next)
- sigma_next = self.scheduler.sigma(t_next)
- x = alpha_next * x0 + sigma_next * torch.randn_like(x)
- # print(alpha_next, sigma_next)
- return x
\ No newline at end of file
diff --git a/src/diffusion/stateful_flow_matching/sharing_sampling.py b/src/diffusion/stateful_flow_matching/sharing_sampling.py
deleted file mode 100644
index f372028..0000000
--- a/src/diffusion/stateful_flow_matching/sharing_sampling.py
+++ /dev/null
@@ -1,149 +0,0 @@
-import torch
-
-from src.diffusion.base.guidance import *
-from src.diffusion.base.scheduling import *
-from src.diffusion.base.sampling import *
-
-from typing import Callable
-
-
-def shift_respace_fn(t, shift=3.0):
- return t / (t + (1 - t) * shift)
-
-def ode_step_fn(x, v, dt, s, w):
- return x + v * dt
-
-
-import logging
-logger = logging.getLogger(__name__)
-
-class EulerSampler(BaseSampler):
- def __init__(
- self,
- w_scheduler: BaseScheduler = None,
- timeshift=1.0,
- guidance_interval_min: float = 0.0,
- guidance_interval_max: float = 1.0,
- state_refresh_rate=1,
- step_fn: Callable = ode_step_fn,
- last_step=None,
- last_step_fn: Callable = ode_step_fn,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.step_fn = step_fn
- self.last_step = last_step
- self.last_step_fn = last_step_fn
- self.w_scheduler = w_scheduler
- self.timeshift = timeshift
- self.state_refresh_rate = state_refresh_rate
- self.guidance_interval_min = guidance_interval_min
- self.guidance_interval_max = guidance_interval_max
-
- if self.last_step is None or self.num_steps == 1:
- self.last_step = 1.0 / self.num_steps
-
- timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
- timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
- self.timesteps = shift_respace_fn(timesteps, self.timeshift)
-
- assert self.last_step > 0.0
- assert self.scheduler is not None
- assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
- if self.w_scheduler is not None:
- if self.step_fn == ode_step_fn:
- logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
-
- # init recompute
- self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
- self.recompute_timesteps = list(range(self.num_steps))
-
- def sharing_dp(self, net, noise, condition, uncondition):
- _, C, H, W = noise.shape
- B = 8
- template_noise = torch.randn((B, C, H, W), generator=torch.Generator("cuda").manual_seed(0), device=noise.device)
- template_condition = torch.randint(0, 1000, (B,), generator=torch.Generator("cuda").manual_seed(0), device=condition.device)
- template_uncondition = torch.full((B, ), 1000, device=condition.device)
- _, state_list = self._impl_sampling(net, template_noise, template_condition, template_uncondition)
- states = torch.stack(state_list)
- N, B, L, C = states.shape
- states = states.view(N, B*L, C )
- states = states.permute(1, 0, 2)
- states = torch.nn.functional.normalize(states, dim=-1)
- with torch.autocast(device_type="cuda", dtype=torch.float64):
- sim = torch.bmm(states, states.transpose(1, 2))
- sim = torch.mean(sim, dim=0).cpu()
- error_map = (1-sim).tolist()
-
- # init cum-error
- for i in range(1, self.num_steps):
- for j in range(0, i):
- error_map[i][j] = error_map[i-1][j] + error_map[i][j]
-
- # init dp and force 0 start
- C = [[0.0, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
- P = [[-1, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
- for i in range(1, self.num_steps+1):
- C[1][i] = error_map[i - 1][0]
- P[1][i] = 0
-
- # dp state
- for step in range(2, self.num_recompute_timesteps+1):
- for i in range(step, self.num_steps+1):
- min_value = 99999
- min_index = -1
- for j in range(step-1, i):
- value = C[step-1][j] + error_map[i-1][j]
- if value < min_value:
- min_value = value
- min_index = j
- C[step][i] = min_value
- P[step][i] = min_index
-
- # trace back
- timesteps = [self.num_steps,]
- for i in range(self.num_recompute_timesteps, 0, -1):
- idx = timesteps[-1]
- timesteps.append(P[i][idx])
- timesteps.reverse()
-
- print("recompute timesteps solved by DP: ", timesteps)
- return timesteps[:-1]
-
- def _impl_sampling(self, net, noise, condition, uncondition):
- """
- sampling process of Euler sampler
- -
- """
- batch_size = noise.shape[0]
- steps = self.timesteps.to(noise.device)
- cfg_condition = torch.cat([uncondition, condition], dim=0)
- x = noise
- state = None
- pooled_state_list = []
- for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
- dt = t_next - t_cur
- t_cur = t_cur.repeat(batch_size)
- cfg_x = torch.cat([x, x], dim=0)
- cfg_t = t_cur.repeat(2)
- if i in self.recompute_timesteps:
- state = None
- out, state = net(cfg_x, cfg_t, cfg_condition, state)
- if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
- out = self.guidance_fn(out, self.guidance)
- else:
- out = self.guidance_fn(out, 1.0)
- v = out
- if i < self.num_steps -1 :
- x = self.step_fn(x, v, dt, s=0.0, w=0.0)
- else:
- x = self.last_step_fn(x, v, dt, s=0.0, w=0.0)
- pooled_state_list.append(state)
- return x, pooled_state_list
-
- def __call__(self, net, noise, condition, uncondition):
- if len(self.recompute_timesteps) != self.num_recompute_timesteps:
- self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
- denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
- return denoised
\ No newline at end of file
diff --git a/src/diffusion/stateful_flow_matching/training_adv.py b/src/diffusion/stateful_flow_matching/training_adv.py
deleted file mode 100644
index 4792950..0000000
--- a/src/diffusion/stateful_flow_matching/training_adv.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import torch
-import math
-from typing import Callable
-from src.diffusion.base.training import *
-from src.diffusion.base.scheduling import BaseScheduler
-from src.utils.no_grad import no_grad
-from torchmetrics.image.lpip import _NoTrainLpips
-
-def inverse_sigma(alpha, sigma):
- return 1/sigma**2
-def snr(alpha, sigma):
- return alpha/sigma
-def minsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, min=threshold)
-def maxsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, max=threshold)
-def constant(alpha, sigma):
- return 1
-
-class Discriminator(nn.Module):
- def __init__(self, in_channels, hidden_size):
- super().__init__()
- self.head = nn.Sequential(
- nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
- nn.GroupNorm(num_groups=32, num_channels=hidden_size),
- nn.SiLU(),
- nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
- nn.GroupNorm(num_groups=32, num_channels=hidden_size),
- nn.SiLU(),
- nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4
- nn.GroupNorm(num_groups=32, num_channels=hidden_size),
- nn.SiLU(),
- nn.AdaptiveAvgPool2d(1),
- nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
- )
-
- def forward(self, feature):
- B, L, C = feature.shape
- H = W = int(math.sqrt(L))
- feature = feature.permute(0, 2, 1)
- feature = feature.view(B, C, H, W)
- out = self.head(feature).sigmoid().clamp(0.01, 0.99)
- return out
-
-class AdvTrainer(BaseTrainer):
- def __init__(
- self,
- scheduler: BaseScheduler,
- loss_weight_fn:Callable=constant,
- lognorm_t=False,
- adv_weight=1.0,
- adv_encoder_layer=4,
- adv_in_channels=768,
- adv_hidden_size=256,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.lognorm_t = lognorm_t
- self.scheduler = scheduler
- self.loss_weight_fn = loss_weight_fn
- self.adv_weight = adv_weight
- self.adv_encoder_layer = adv_encoder_layer
-
- self.dis_head = Discriminator(
- in_channels=adv_in_channels,
- hidden_size=adv_hidden_size,
- )
-
-
- def _impl_trainstep(self, net, ema_net, raw_images, x, y):
- batch_size = x.shape[0]
- if self.lognorm_t:
- t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
- else:
- t = torch.rand(batch_size).to(x.device, x.dtype)
- noise = torch.randn_like(x)
- alpha = self.scheduler.alpha(t)
- dalpha = self.scheduler.dalpha(t)
- sigma = self.scheduler.sigma(t)
- dsigma = self.scheduler.dsigma(t)
- w = self.scheduler.w(t)
-
- x_t = alpha * x + noise * sigma
- v_t = dalpha * x + dsigma * noise
-
- adv_feature = []
- def forward_hook(net, input, output):
- adv_feature.append(output)
- handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook)
-
- out, _ = net(x_t, t, y)
- weight = self.loss_weight_fn(alpha, sigma)
- loss = weight*(out - v_t)**2
-
- pred_x0 = (x_t + out * sigma)
- pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma
- real_feature = adv_feature.pop()
- net(pred_xt, t, y, classify_layer=self.adv_encoder_layer)
- fake_feature = adv_feature.pop()
- handle.remove()
-
-
- real_score_gan = self.dis_head(real_feature.detach())
- fake_score_gan = self.dis_head(fake_feature.detach())
- fake_score = self.dis_head(fake_feature)
-
- loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
- acc_real = (real_score_gan > 0.5).float()
- acc_fake = (fake_score_gan < 0.5).float()
- loss_adv = -torch.log(fake_score)
- loss_adv_hack = torch.log(fake_score_gan)
-
- out = dict(
- adv_loss=loss_adv.mean(),
- gan_loss=loss_gan.mean(),
- fm_loss=loss.mean(),
- loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
- acc_real=acc_real.mean(),
- acc_fake=acc_fake.mean(),
- )
- return out
diff --git a/src/diffusion/stateful_flow_matching/training_distill_dino.py b/src/diffusion/stateful_flow_matching/training_distill_dino.py
deleted file mode 100644
index c6a2937..0000000
--- a/src/diffusion/stateful_flow_matching/training_distill_dino.py
+++ /dev/null
@@ -1,141 +0,0 @@
-import torch
-import copy
-import timm
-from torch.nn import Parameter
-
-from src.utils.no_grad import no_grad
-from typing import Callable, Iterator, Tuple
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from torchvision.transforms import Normalize
-from src.diffusion.base.training import *
-from src.diffusion.base.scheduling import BaseScheduler
-
-def inverse_sigma(alpha, sigma):
- return 1/sigma**2
-def snr(alpha, sigma):
- return alpha/sigma
-def minsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, min=threshold)
-def maxsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, max=threshold)
-def constant(alpha, sigma):
- return 1
-
-
-class DINOv2(nn.Module):
- def __init__(self, weight_path:str):
- super(DINOv2, self).__init__()
- self.encoder = torch.hub.load(
- '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
- weight_path,
- source="local",
- skip_validation=True
- )
- self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
- self.encoder.head = torch.nn.Identity()
- self.patch_size = self.encoder.patch_embed.patch_size
- self.precomputed_pos_embed = dict()
-
- def fetch_pos(self, h, w):
- key = (h, w)
- if key in self.precomputed_pos_embed:
- return self.precomputed_pos_embed[key]
- value = timm.layers.pos_embed.resample_abs_pos_embed(
- self.pos_embed.data, [h, w],
- )
- self.precomputed_pos_embed[key] = value
- return value
-
- def forward(self, x):
- b, c, h, w = x.shape
- x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
- x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bilinear')
- b, c, h, w = x.shape
- patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
- pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
- self.encoder.pos_embed.data = pos_embed_data
- feature = self.encoder.forward_features(x)['x_norm_patchtokens']
- return feature
-
-
-class DistillDINOTrainer(BaseTrainer):
- def __init__(
- self,
- scheduler: BaseScheduler,
- loss_weight_fn:Callable=constant,
- feat_loss_weight: float=0.5,
- lognorm_t=False,
- encoder_weight_path=None,
- align_layer=8,
- proj_denoiser_dim=256,
- proj_hidden_dim=256,
- proj_encoder_dim=256,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.lognorm_t = lognorm_t
- self.scheduler = scheduler
- self.loss_weight_fn = loss_weight_fn
- self.feat_loss_weight = feat_loss_weight
- self.align_layer = align_layer
- self.encoder = DINOv2(encoder_weight_path)
- self.proj_encoder_dim = proj_encoder_dim
- no_grad(self.encoder)
-
- self.proj = nn.Sequential(
- nn.Sequential(
- nn.Linear(proj_denoiser_dim, proj_hidden_dim),
- nn.SiLU(),
- nn.Linear(proj_hidden_dim, proj_hidden_dim),
- nn.SiLU(),
- nn.Linear(proj_hidden_dim, proj_encoder_dim),
- )
- )
-
- def _impl_trainstep(self, net, ema_net, raw_images, x, y):
- batch_size, c, height, width = x.shape
- if self.lognorm_t:
- base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
- else:
- base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
- t = base_t
-
- noise = torch.randn_like(x)
- alpha = self.scheduler.alpha(t)
- sigma = self.scheduler.sigma(t)
-
- x_t = alpha * x + noise * sigma
-
- _, s = net(x_t, t, y)
- src_feature = self.proj(s)
-
- with torch.no_grad():
- dst_feature = self.encoder(raw_images)
-
- if dst_feature.shape[1] != src_feature.shape[1]:
- dst_length = dst_feature.shape[1]
- rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
- dst_height = (dst_length)**0.5 * (height/width)**0.5
- dst_width = (dst_length)**0.5 * (width/height)**0.5
- dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
- dst_feature = dst_feature.permute(0, 3, 1, 2)
- dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
- dst_feature = dst_feature.permute(0, 2, 3, 1)
- dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
-
- cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
- cos_loss = 1 - cos_sim
-
- out = dict(
- cos_loss=cos_loss.mean(),
- loss=cos_loss.mean(),
- )
- return out
-
- def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
- self.proj.state_dict(
- destination=destination,
- prefix=prefix + "proj.",
- keep_vars=keep_vars)
-
diff --git a/src/diffusion/stateful_flow_matching/training_lpips.py b/src/diffusion/stateful_flow_matching/training_lpips.py
deleted file mode 100644
index a3cd2a2..0000000
--- a/src/diffusion/stateful_flow_matching/training_lpips.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import torch
-from typing import Callable
-from src.diffusion.base.training import *
-from src.diffusion.base.scheduling import BaseScheduler
-from src.utils.no_grad import no_grad
-from torchmetrics.image.lpip import _NoTrainLpips
-
-def inverse_sigma(alpha, sigma):
- return 1/sigma**2
-def snr(alpha, sigma):
- return alpha/sigma
-def minsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, min=threshold)
-def maxsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, max=threshold)
-def constant(alpha, sigma):
- return 1
-
-class LPIPSTrainer(BaseTrainer):
- def __init__(
- self,
- scheduler: BaseScheduler,
- loss_weight_fn:Callable=constant,
- lognorm_t=False,
- lpips_weight=1.0,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.lognorm_t = lognorm_t
- self.scheduler = scheduler
- self.loss_weight_fn = loss_weight_fn
- self.lpips_weight = lpips_weight
- self.lpips = _NoTrainLpips(net="vgg")
- self.lpips = self.lpips.to(torch.bfloat16)
- # self.lpips = torch.compile(self.lpips)
- no_grad(self.lpips)
-
- def _impl_trainstep(self, net, ema_net, raw_images, x, y):
- batch_size = x.shape[0]
- if self.lognorm_t:
- t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
- else:
- t = torch.rand(batch_size).to(x.device, x.dtype)
- noise = torch.randn_like(x)
- alpha = self.scheduler.alpha(t)
- dalpha = self.scheduler.dalpha(t)
- sigma = self.scheduler.sigma(t)
- dsigma = self.scheduler.dsigma(t)
- w = self.scheduler.w(t)
-
- x_t = alpha * x + noise * sigma
- v_t = dalpha * x + dsigma * noise
- out, _ = net(x_t, t, y)
- weight = self.loss_weight_fn(alpha, sigma)
- loss = weight*(out - v_t)**2
-
- pred_x0 = (x_t + out*sigma)
- target_x0 = x
- # fixbug lpips std
- lpips = self.lpips(pred_x0*0.5, target_x0*0.5)
-
- out = dict(
- lpips_loss=lpips.mean(),
- fm_loss=loss.mean(),
- loss=loss.mean() + lpips.mean()*self.lpips_weight,
- )
- return out
-
- def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
- return
\ No newline at end of file
diff --git a/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py b/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py
deleted file mode 100644
index e0233ea..0000000
--- a/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import torch
-from typing import Callable
-from src.diffusion.base.training import *
-from src.diffusion.base.scheduling import BaseScheduler
-from src.utils.no_grad import no_grad
-from torchmetrics.image.lpip import _NoTrainLpips
-
-def inverse_sigma(alpha, sigma):
- return 1/sigma**2
-def snr(alpha, sigma):
- return alpha/sigma
-def minsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, min=threshold)
-def maxsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, max=threshold)
-def constant(alpha, sigma):
- return 1
-
-class LPIPSTrainer(BaseTrainer):
- def __init__(
- self,
- scheduler: BaseScheduler,
- loss_weight_fn:Callable=constant,
- lognorm_t=False,
- lpips_weight=1.0,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.lognorm_t = False
- self.scheduler = scheduler
- self.loss_weight_fn = loss_weight_fn
- self.lpips_weight = lpips_weight
- self.lpips = _NoTrainLpips(net="vgg")
- self.lpips = self.lpips.to(torch.bfloat16)
- # self.lpips = torch.compile(self.lpips)
- no_grad(self.lpips)
-
- def _impl_trainstep(self, net, ema_net, raw_images, x, y):
- batch_size = x.shape[0]
- if self.lognorm_t:
- t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
- else:
- t = torch.rand(batch_size).to(x.device, x.dtype)
- noise = torch.randn_like(x)
- alpha = self.scheduler.alpha(t)
- dalpha = self.scheduler.dalpha(t)
- sigma = self.scheduler.sigma(t)
- dsigma = self.scheduler.dsigma(t)
- w = self.scheduler.w(t)
-
- x_t = alpha * x + noise * sigma
- v_t = dalpha * x + dsigma * noise
- out, _ = net(x_t, t, y)
-
- fm_weight = t*(1-t)**2/0.25
- lpips_weight = t
-
- loss = (out - v_t)**2 * fm_weight[:, None, None, None]
-
- pred_x0 = (x_t + out*sigma)
- target_x0 = x
- # fixbug lpips std
- lpips = self.lpips(pred_x0*0.5, target_x0*0.5)*lpips_weight[:, None, None, None]
-
- out = dict(
- lpips_loss=lpips.mean(),
- fm_loss=loss.mean(),
- loss=loss.mean() + lpips.mean()*self.lpips_weight,
- )
- return out
-
- def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
- return
\ No newline at end of file
diff --git a/src/diffusion/stateful_flow_matching/training_repa_lpips.py b/src/diffusion/stateful_flow_matching/training_repa_lpips.py
deleted file mode 100644
index 5a11207..0000000
--- a/src/diffusion/stateful_flow_matching/training_repa_lpips.py
+++ /dev/null
@@ -1,170 +0,0 @@
-import torch
-import copy
-import timm
-from torch.nn import Parameter
-
-from src.utils.no_grad import no_grad
-from typing import Callable, Iterator, Tuple
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from torchvision.transforms import Normalize
-from src.diffusion.base.training import *
-from src.diffusion.base.scheduling import BaseScheduler
-from torchmetrics.image.lpip import _NoTrainLpips
-
-def inverse_sigma(alpha, sigma):
- return 1/sigma**2
-def snr(alpha, sigma):
- return alpha/sigma
-def minsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, min=threshold)
-def maxsnr(alpha, sigma, threshold=5):
- return torch.clip(alpha/sigma, max=threshold)
-def constant(alpha, sigma):
- return 1
-
-
-class DINOv2(nn.Module):
- def __init__(self, weight_path:str):
- super(DINOv2, self).__init__()
- self.encoder = torch.hub.load(
- '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
- weight_path,
- source="local",
- skip_validation=True
- )
- self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
- self.encoder.head = torch.nn.Identity()
- self.patch_size = self.encoder.patch_embed.patch_size
- self.precomputed_pos_embed = dict()
-
- def fetch_pos(self, h, w):
- key = (h, w)
- if key in self.precomputed_pos_embed:
- return self.precomputed_pos_embed[key]
- value = timm.layers.pos_embed.resample_abs_pos_embed(
- self.pos_embed.data, [h, w],
- )
- self.precomputed_pos_embed[key] = value
- return value
-
- def forward(self, x):
- b, c, h, w = x.shape
- x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
- x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
- b, c, h, w = x.shape
- patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
- pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
- self.encoder.pos_embed.data = pos_embed_data
- feature = self.encoder.forward_features(x)['x_norm_patchtokens']
- return feature
-
-
-class REPALPIPSTrainer(BaseTrainer):
- def __init__(
- self,
- scheduler: BaseScheduler,
- loss_weight_fn:Callable=constant,
- feat_loss_weight: float=0.5,
- lognorm_t=False,
- lpips_weight=1.0,
- encoder_weight_path=None,
- align_layer=8,
- proj_denoiser_dim=256,
- proj_hidden_dim=256,
- proj_encoder_dim=256,
- *args,
- **kwargs
- ):
- super().__init__(*args, **kwargs)
- self.lognorm_t = lognorm_t
- self.scheduler = scheduler
- self.loss_weight_fn = loss_weight_fn
- self.feat_loss_weight = feat_loss_weight
- self.align_layer = align_layer
- self.encoder = DINOv2(encoder_weight_path)
- self.proj_encoder_dim = proj_encoder_dim
- no_grad(self.encoder)
-
- self.lpips_weight = lpips_weight
- self.lpips = _NoTrainLpips(net="vgg")
- self.lpips = self.lpips.to(torch.bfloat16)
- no_grad(self.lpips)
-
- self.proj = nn.Sequential(
- nn.Sequential(
- nn.Linear(proj_denoiser_dim, proj_hidden_dim),
- nn.SiLU(),
- nn.Linear(proj_hidden_dim, proj_hidden_dim),
- nn.SiLU(),
- nn.Linear(proj_hidden_dim, proj_encoder_dim),
- )
- )
-
- def _impl_trainstep(self, net, ema_net, raw_images, x, y):
- batch_size, c, height, width = x.shape
- if self.lognorm_t:
- base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
- else:
- base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
- t = base_t
-
- noise = torch.randn_like(x)
- alpha = self.scheduler.alpha(t)
- dalpha = self.scheduler.dalpha(t)
- sigma = self.scheduler.sigma(t)
- dsigma = self.scheduler.dsigma(t)
-
- x_t = alpha * x + noise * sigma
- v_t = dalpha * x + dsigma * noise
-
- src_feature = []
- def forward_hook(net, input, output):
- src_feature.append(output)
- if getattr(net, "blocks", None) is not None:
- handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
- else:
- handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
-
- out, _ = net(x_t, t, y)
- src_feature = self.proj(src_feature[0])
- handle.remove()
-
- with torch.no_grad():
- dst_feature = self.encoder(raw_images)
-
- if dst_feature.shape[1] != src_feature.shape[1]:
- dst_length = dst_feature.shape[1]
- rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
- dst_height = (dst_length)**0.5 * (height/width)**0.5
- dst_width = (dst_length)**0.5 * (width/height)**0.5
- dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
- dst_feature = dst_feature.permute(0, 3, 1, 2)
- dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
- dst_feature = dst_feature.permute(0, 2, 3, 1)
- dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
-
- cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
- cos_loss = 1 - cos_sim
-
- weight = self.loss_weight_fn(alpha, sigma)
- fm_loss = weight*(out - v_t)**2
-
- pred_x0 = (x_t + out * sigma)
- target_x0 = x
- # fixbug lpips std
- lpips = self.lpips(pred_x0 * 0.5, target_x0 * 0.5)
-
- out = dict(
- lpips_loss=lpips.mean(),
- fm_loss=fm_loss.mean(),
- cos_loss=cos_loss.mean(),
- loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + self.lpips_weight*lpips.mean(),
- )
- return out
-
- def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
- self.proj.state_dict(
- destination=destination,
- prefix=prefix + "proj.",
- keep_vars=keep_vars)
-
diff --git a/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py
deleted file mode 100644
index 3581446..0000000
--- a/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py
+++ /dev/null
@@ -1,383 +0,0 @@
-import functools
-from typing import Tuple
-import torch
-import torch.nn as nn
-import math
-
-from torch.nn.init import zeros_
-from torch.nn.modules.module import T
-
-# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
-from torch.nn.functional import scaled_dot_product_attention
-from src.utils.model_loader import ModelLoader
-from src.utils.no_grad import no_grad
-
-def modulate(x, shift, scale):
- return x * (1 + scale) + shift
-
-class Embed(nn.Module):
- def __init__(
- self,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = None,
- bias: bool = True,
- ):
- super().__init__()
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class TimestepEmbedder(nn.Module):
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10):
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
- )
- args = t[..., None].float() * freqs[None, ...]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-class LabelEmbedder(nn.Module):
- def __init__(self, num_classes, hidden_size):
- super().__init__()
- self.embedding_table = nn.Embedding(num_classes, hidden_size)
- self.num_classes = num_classes
-
- def forward(self, labels,):
- embeddings = self.embedding_table(labels)
- return embeddings
-
-class FinalLayer(nn.Module):
- def __init__(self, hidden_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 2*hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- hidden_dim = int(2 * hidden_dim / 3)
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
- def forward(self, x):
- x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
- return x
-
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
- # assert H * H == end
- # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
- x_pos = torch.linspace(0, scale, width)
- y_pos = torch.linspace(0, scale, height)
- y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
- y_pos = y_pos.reshape(-1)
- x_pos = x_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
- x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
- y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
- freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
- freqs_cis = freqs_cis.reshape(height*width, -1)
- return freqs_cis
-
-
-def apply_rotary_emb(
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- freqs_cis = freqs_cis[None, :, None, :]
- # xq : B N H Hc
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
-class RAttention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = True,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- norm_layer: nn.Module = RMSNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
- q = self.q_norm(q)
- k = self.k_norm(k)
- q, k = apply_rotary_emb(q, k, freqs_cis=pos)
- q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
- k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
- v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
-
- x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-
-class FlattenDiTBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-class FlattenDiTEncoder(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.patch_size = patch_size
- self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
- ])
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)]
- else:
- pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
- self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
- return (pos_rope, pos_ape)
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.s_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.s_embedder.proj.bias, 0)
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- def forward(self, x, t, y, mask=None):
- B, _, H, W = x.shape
- pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- y = self.y_embedder(y).view(B, 1, self.hidden_size)
- c = nn.functional.silu(t + y)
- s = self.s_embedder(x)
- # s = s + pos_ape.to(s.dtype)
- for i in range(self.num_blocks):
- s = self.blocks[i](s, c, pos_rope, mask)
- return None, s
-
-
-class FlattenDiTDecoder(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.patch_size = patch_size
- self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
- self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
-
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
- ])
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)]
- else:
- pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- self.precompute_pos[(height, width)] = pos
- return pos
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.x_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.x_embedder.proj.bias, 0)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- # Zero-out output layers:
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
- nn.init.constant_(self.final_layer.linear.weight, 0)
- nn.init.constant_(self.final_layer.linear.bias, 0)
-
- def forward(self, x, t, y, s, mask=None):
- B, _, H, W = x.shape
- pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- y = self.y_embedder(y).view(B, 1, self.hidden_size)
- s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
- c = torch.nn.functional.silu(t + y)
- x = torch.cat([x, s], dim=-1)
- x = self.x_embedder(x)
- for i in range(self.num_blocks):
- x = self.blocks[i](x, c, pos, None)
- x = self.final_layer(x, c)
- x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size,
- stride=self.patch_size)
- return x
-
-
-
-class FlattenDiT(nn.Module):
- def __init__(
- self,
- encoder:FlattenDiTEncoder,
- decoder:FlattenDiTDecoder,
- ):
- super().__init__()
- self.encoder = encoder
- self.decoder = decoder
- ModelLoader().load(encoder)
- self.encoder = self.encoder.to(torch.bfloat16)
- no_grad(self.encoder)
-
- def forward(self, x, t, y, s=None):
- if s is None:
- with torch.no_grad():
- _, s = self.encoder(x, t, y)
- x = self.decoder(x, t, y, s)
- return x, s
-
diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py
deleted file mode 100644
index 733ce4a..0000000
--- a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py
+++ /dev/null
@@ -1,447 +0,0 @@
-import functools
-from typing import Tuple
-import torch
-import torch.nn as nn
-import math
-
-from torch.nn.init import zeros_
-from torch.nn.modules.module import T
-
-# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
-from torch.nn.functional import scaled_dot_product_attention
-from src.utils.model_loader import ModelLoader
-from src.utils.no_grad import no_grad
-
-def modulate(x, shift, scale):
- return x * (1 + scale) + shift
-
-class Embed(nn.Module):
- def __init__(
- self,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = None,
- bias: bool = True,
- ):
- super().__init__()
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class TimestepEmbedder(nn.Module):
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10):
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
- )
- args = t[..., None].float() * freqs[None, ...]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-class LabelEmbedder(nn.Module):
- def __init__(self, num_classes, hidden_size):
- super().__init__()
- self.embedding_table = nn.Embedding(num_classes, hidden_size)
- self.num_classes = num_classes
-
- def forward(self, labels,):
- embeddings = self.embedding_table(labels)
- return embeddings
-
-class FinalLayer(nn.Module):
- def __init__(self, hidden_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 2*hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- hidden_dim = int(2 * hidden_dim / 3)
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
- def forward(self, x):
- x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
- return x
-
-class ResBlock(nn.Module):
- def __init__(self, dim:int, groups:int=8, hidden_dim:int=256):
- super().__init__()
- self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
- self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
- self.norm1 = nn.GroupNorm(groups, dim)
- self.norm2 = nn.GroupNorm(groups, dim)
- self.embed_proj = nn.Linear(hidden_dim, dim)
-
- def forward(self, x, c):
- c = self.embed_proj(c)[:, :, None, None]
- residual = x
- x = self.conv1(x)
- x = self.norm1(x)
- x = torch.nn.functional.silu(x)
- x = x * c
- x = self.conv2(x)
- x = self.norm2(x)
- x = torch.nn.functional.silu(x)
- return residual + x
-
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
- # assert H * H == end
- # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
- x_pos = torch.linspace(0, scale, width)
- y_pos = torch.linspace(0, scale, height)
- y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
- y_pos = y_pos.reshape(-1)
- x_pos = x_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
- x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
- y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
- freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
- freqs_cis = freqs_cis.reshape(height*width, -1)
- return freqs_cis
-
-
-def apply_rotary_emb(
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- freqs_cis = freqs_cis[None, :, None, :]
- # xq : B N H Hc
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
-class RAttention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = True,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- norm_layer: nn.Module = RMSNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
- q = self.q_norm(q)
- k = self.k_norm(k)
- q, k = apply_rotary_emb(q, k, freqs_cis=pos)
- q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
- k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
- v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
-
- x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-
-class FlattenDiTBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-class FlattenDiTEncoder(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.patch_size = patch_size
- self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
- ])
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)]
- else:
- pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
- self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
- return (pos_rope, pos_ape)
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.s_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.s_embedder.proj.bias, 0)
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- def forward(self, x, t, y, mask=None, classify_layer=None):
- B, _, H, W = x.shape
- pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- y = self.y_embedder(y).view(B, 1, self.hidden_size)
- c = nn.functional.silu(t + y)
- s = self.s_embedder(x)
- # s = s + pos_ape.to(s.dtype)
- classify_feats = []
- for i in range(self.num_blocks):
- s = self.blocks[i](s, c, pos_rope, mask)
- if classify_layer is not None and i < classify_layer:
- classify_feats.append(s)
- if i == classify_layer - 1:
- return _, classify_feats
- return None, s
-
-
-class FlattenDiTDecoder(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_mid_blocks=18,
- num_res_blocks=[1, 1, 1],
- num_res_channels=[64, 384, 768],
- num_classes=1000,
- learn_sigma=True,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_mid_blocks = num_mid_blocks
- self.num_res_blocks = num_res_blocks
- self.num_res_channels = num_res_channels
- self.patch_size = 2**(len(num_res_blocks))
-
- self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
- self.t_embedder = TimestepEmbedder(hidden_size)
-
- self.down_res_blocks = nn.ModuleList()
- previous_channel = self.in_channels
- for num, channels in zip(num_res_blocks, num_res_channels):
- self.down_res_blocks.append(
- nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
- )
- self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
- previous_channel = channels
- self.up_res_blocks = []
- previous_channel = self.in_channels
- for num, channels in zip(num_res_blocks, num_res_channels):
- self.up_res_blocks.append(
- nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
- )
- self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
- previous_channel = channels
- self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])
-
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
- ])
-
- self.initialize_weights()
- self.precompute_pos = dict()
- self.weight_path = weight_path
- self.load_ema = load_ema
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)]
- else:
- pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- self.precompute_pos[(height, width)] = pos
- return pos
-
- def initialize_weights(self):
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- def forward(self, x, t, y, s, mask=None):
- B, _, H, W = x.shape
- t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
- y = self.y_embedder(y).view(B, self.hidden_size)
- s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
- c = torch.nn.functional.silu(t + y)
-
- residual = []
- for i, block in enumerate(self.down_res_blocks):
- if isinstance(block, nn.Conv2d):
- residual.append(x)
- x = block(x)
- else:
- x = block(x, c)
-
- pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
- x = x.view(B, self.hidden_size, -1).transpose(1, 2)
- mid_c = torch.nn.functional.silu(t[:, None, :] + s)
- for i in range(self.num_mid_blocks):
- x = self.blocks[i](x, mid_c, pos, None)
- x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size)
-
- residual[0] = 0.0
- for i, block in enumerate(self.up_res_blocks):
- if isinstance(block, nn.ConvTranspose2d):
- x = block(x) + residual.pop()
- else:
- x = block(x, c)
- return x
-
-
-class FlattenDiT(nn.Module):
- def __init__(
- self,
- encoder:FlattenDiTEncoder,
- decoder:FlattenDiTDecoder,
- ):
- super().__init__()
- self.encoder = encoder
- self.decoder = decoder
- ModelLoader().load(encoder)
- self.encoder = self.encoder.to(torch.bfloat16)
- no_grad(self.encoder)
-
- def forward(self, x, t, y, s=None, classify_layer=None):
- if s is None:
- _, s = self.encoder(x, t, y, classify_layer=classify_layer)
- if classify_layer is not None:
- return None, s
- x = self.decoder(x, t, y, s)
- return x, s
-
-class FlattenDiT_jointtraining(nn.Module):
- def __init__(
- self,
- encoder:FlattenDiTEncoder,
- decoder:FlattenDiTDecoder,
- ):
- super().__init__()
- self.encoder = encoder
- self.decoder = decoder
-
- def forward(self, x, t, y, s=None):
- if s is None:
- _, s = self.encoder(x, t, y)
- x = self.decoder(x, t, y, s)
- return x, s
-
diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py
deleted file mode 100644
index 6e9adbc..0000000
--- a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py
+++ /dev/null
@@ -1,448 +0,0 @@
-import functools
-from typing import Tuple
-import torch
-import torch.nn as nn
-import math
-
-from torch.nn.init import zeros_
-from torch.nn.modules.module import T
-
-# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
-from torch.nn.functional import scaled_dot_product_attention
-from src.utils.model_loader import ModelLoader
-from src.utils.no_grad import no_grad
-
-def modulate(x, shift, scale):
- return x * (1 + scale) + shift
-
-class Embed(nn.Module):
- def __init__(
- self,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = None,
- bias: bool = True,
- ):
- super().__init__()
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class TimestepEmbedder(nn.Module):
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10):
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
- )
- args = t[..., None].float() * freqs[None, ...]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-class LabelEmbedder(nn.Module):
- def __init__(self, num_classes, hidden_size):
- super().__init__()
- self.embedding_table = nn.Embedding(num_classes, hidden_size)
- self.num_classes = num_classes
-
- def forward(self, labels,):
- embeddings = self.embedding_table(labels)
- return embeddings
-
-class FinalLayer(nn.Module):
- def __init__(self, hidden_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 2*hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- hidden_dim = int(2 * hidden_dim / 3)
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
- def forward(self, x):
- x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
- return x
-
-class ResBlock(nn.Module):
- def __init__(self, dim:int, groups:int=8, hidden_dim:int=256):
- super().__init__()
- self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
- self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
- self.norm1 = nn.GroupNorm(groups, dim)
- self.norm2 = nn.GroupNorm(groups, dim)
- self.embed_proj = nn.Linear(hidden_dim, dim)
-
- def forward(self, x, c):
- c = self.embed_proj(c)[:, :, None, None]
- residual = x
- x = self.conv1(x)
- x = self.norm1(x)
- x = torch.nn.functional.silu(x)
- x = x * c
- x = self.conv2(x)
- x = self.norm2(x)
- x = torch.nn.functional.silu(x)
- return residual + x
-
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
- # assert H * H == end
- # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
- x_pos = torch.linspace(0, scale, width)
- y_pos = torch.linspace(0, scale, height)
- y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
- y_pos = y_pos.reshape(-1)
- x_pos = x_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
- x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
- y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
- freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
- freqs_cis = freqs_cis.reshape(height*width, -1)
- return freqs_cis
-
-
-def apply_rotary_emb(
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- freqs_cis = freqs_cis[None, :, None, :]
- # xq : B N H Hc
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
-class RAttention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = True,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- norm_layer: nn.Module = RMSNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
- q = self.q_norm(q)
- k = self.k_norm(k)
- q, k = apply_rotary_emb(q, k, freqs_cis=pos)
- q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
- k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
- v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
-
- x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-
-class FlattenDiTBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-class FlattenDiTEncoder(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.patch_size = patch_size
- self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
- ])
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)]
- else:
- pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
- self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
- return (pos_rope, pos_ape)
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.s_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.s_embedder.proj.bias, 0)
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- def forward(self, x, t, y, mask=None, classify_layer=None):
- B, _, H, W = x.shape
- pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- y = self.y_embedder(y).view(B, 1, self.hidden_size)
- c = nn.functional.silu(t + y)
- s = self.s_embedder(x)
- # s = s + pos_ape.to(s.dtype)
- classify_feats = []
- for i in range(self.num_blocks):
- s = self.blocks[i](s, c, pos_rope, mask)
- if classify_layer is not None and i < classify_layer:
- classify_feats.append(s)
- if i == classify_layer - 1:
- return _, classify_feats
- return None, s
-
-
-class FlattenDiTDecoder(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_mid_blocks=18,
- num_res_blocks=[1, 1, 1],
- num_res_channels=[64, 384, 768],
- num_classes=1000,
- learn_sigma=True,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_mid_blocks = num_mid_blocks
- self.num_res_blocks = num_res_blocks
- self.num_res_channels = num_res_channels
- self.patch_size = 2**(len(num_res_blocks))
-
- self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
- self.t_embedder = TimestepEmbedder(hidden_size)
-
- self.down_res_blocks = nn.ModuleList()
- previous_channel = self.in_channels
- for num, channels in zip(num_res_blocks, num_res_channels):
- self.down_res_blocks.append(
- nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
- )
- self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
- previous_channel = channels
- self.up_res_blocks = []
- previous_channel = self.in_channels
- for num, channels in zip(num_res_blocks, num_res_channels):
- self.up_res_blocks.append(
- nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
- )
- self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
- previous_channel = channels
- self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])
-
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
- ])
-
- self.initialize_weights()
- self.precompute_pos = dict()
- self.weight_path = weight_path
- self.load_ema = load_ema
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)]
- else:
- pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- self.precompute_pos[(height, width)] = pos
- return pos
-
- def initialize_weights(self):
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- def forward(self, x, t, y, s, mask=None):
- B, _, H, W = x.shape
- t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
- y = self.y_embedder(y).view(B, self.hidden_size)
- s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
- c = torch.nn.functional.silu(t + y)
-
- residual = []
- for i, block in enumerate(self.down_res_blocks):
- if isinstance(block, nn.Conv2d):
- residual.append(x)
- x = block(x)
- else:
- x = block(x, c)
-
- pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
- x = x.view(B, self.hidden_size, -1).transpose(1, 2)
- mid_c = torch.nn.functional.silu(t[:, None, :] + s)
- for i in range(self.num_mid_blocks):
- x = self.blocks[i](x, mid_c, pos, None)
- x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size)
-
- residual[0] = 0.0
- for i, block in enumerate(self.up_res_blocks):
- if isinstance(block, nn.ConvTranspose2d):
- x = block(x) + residual.pop()
- else:
- x = block(x, c)
- return x
-
-
-class FlattenDiT(nn.Module):
- def __init__(
- self,
- encoder:FlattenDiTEncoder,
- decoder:FlattenDiTDecoder,
- ):
- super().__init__()
- self.encoder = encoder
- self.decoder = decoder
- ModelLoader().load(encoder, "encoder.")
- ModelLoader().load(decoder, "decoder.")
- self.encoder = self.encoder.to(torch.bfloat16)
- no_grad(self.encoder)
-
- def forward(self, x, t, y, s=None, classify_layer=None):
- if s is None:
- _, s = self.encoder(x, t, y, classify_layer=classify_layer)
- if classify_layer is not None:
- return None, s
- x = self.decoder(x, t, y, s)
- return x, s
-
-class FlattenDiT_jointtraining(nn.Module):
- def __init__(
- self,
- encoder:FlattenDiTEncoder,
- decoder:FlattenDiTDecoder,
- ):
- super().__init__()
- self.encoder = encoder
- self.decoder = decoder
-
- def forward(self, x, t, y, s=None):
- if s is None:
- _, s = self.encoder(x, t, y)
- x = self.decoder(x, t, y, s)
- return x, s
-
diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py
deleted file mode 100644
index 537078a..0000000
--- a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py
+++ /dev/null
@@ -1,464 +0,0 @@
-import functools
-from typing import Tuple
-import torch
-import torch.nn as nn
-import math
-
-from torch.nn.init import zeros_
-from torch.nn.modules.module import T
-
-# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
-from torch.nn.functional import scaled_dot_product_attention
-from src.utils.model_loader import ModelLoader
-from src.utils.no_grad import no_grad
-
-def modulate(x, shift, scale):
- return x * (1 + scale) + shift
-
-class Embed(nn.Module):
- def __init__(
- self,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = None,
- bias: bool = True,
- ):
- super().__init__()
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class TimestepEmbedder(nn.Module):
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10):
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
- )
- args = t[..., None].float() * freqs[None, ...]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-class LabelEmbedder(nn.Module):
- def __init__(self, num_classes, hidden_size):
- super().__init__()
- self.embedding_table = nn.Embedding(num_classes, hidden_size)
- self.num_classes = num_classes
-
- def forward(self, labels,):
- embeddings = self.embedding_table(labels)
- return embeddings
-
-class FinalLayer(nn.Module):
- def __init__(self, hidden_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 2*hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- hidden_dim = int(2 * hidden_dim / 3)
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
- def forward(self, x):
- x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
- return x
-
-class ResBlock(nn.Module):
- def __init__(self, dim:int, groups:int=8, hidden_dim:int=256):
- super().__init__()
- self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
- self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
- self.norm1 = nn.GroupNorm(groups, dim)
- self.norm2 = nn.GroupNorm(groups, dim)
-
- def forward(self, x):
- residual = x
- x = self.conv1(x)
- x = self.norm1(x)
- x = torch.nn.functional.silu(x)
- x = self.conv2(x)
- x = self.norm2(x)
- x = torch.nn.functional.silu(x)
- return residual + x
-
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
- # assert H * H == end
- # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
- x_pos = torch.linspace(0, scale, width)
- y_pos = torch.linspace(0, scale, height)
- y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
- y_pos = y_pos.reshape(-1)
- x_pos = x_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
- x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
- y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
- freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
- freqs_cis = freqs_cis.reshape(height*width, -1)
- return freqs_cis
-
-
-def apply_rotary_emb(
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- freqs_cis = freqs_cis[None, :, None, :]
- # xq : B N H Hc
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
-class RAttention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = True,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- norm_layer: nn.Module = RMSNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
- q = self.q_norm(q)
- k = self.k_norm(k)
- q, k = apply_rotary_emb(q, k, freqs_cis=pos)
- q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
- k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
- v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
-
- x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-
-class FlattenDiTBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-class FlattenDiTEncoder(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.patch_size = patch_size
- self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
- ])
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)]
- else:
- pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
- self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
- return (pos_rope, pos_ape)
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.s_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.s_embedder.proj.bias, 0)
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- def forward(self, x, t, y, mask=None, classify_layer=None):
- B, _, H, W = x.shape
- pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- y = self.y_embedder(y).view(B, 1, self.hidden_size)
- c = nn.functional.silu(t + y)
- s = self.s_embedder(x)
- # s = s + pos_ape.to(s.dtype)
- classify_feats = []
- for i in range(self.num_blocks):
- s = self.blocks[i](s, c, pos_rope, mask)
- if classify_layer is not None and i < classify_layer:
- classify_feats.append(s)
- if i == classify_layer - 1:
- return _, classify_feats
- return None, s
-
-
-class FlattenDiTDecoder(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_mid_blocks=18,
- num_res_blocks=[1, 1, 1],
- num_res_channels=[64, 384, 768],
- num_classes=1000,
- learn_sigma=True,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_mid_blocks = num_mid_blocks
- self.num_res_blocks = num_res_blocks
- self.num_res_channels = num_res_channels
- self.patch_size = 2**(len(num_res_blocks))
-
- self.t_embedder = TimestepEmbedder(hidden_size)
-
- self.down_res_blocks = nn.ModuleList()
- previous_channel = self.in_channels
- for num, channels in zip(num_res_blocks, num_res_channels):
- self.down_res_blocks.append(
- nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
- )
- self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
- previous_channel = channels
- self.up_res_blocks = []
- previous_channel = self.in_channels
- for num, channels in zip(num_res_blocks, num_res_channels):
- self.up_res_blocks.append(
- nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
- )
- self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
- previous_channel = channels
- self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])
-
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
- ])
-
- self.initialize_weights()
- self.precompute_pos = dict()
- self.weight_path = weight_path
- self.load_ema = load_ema
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)]
- else:
- pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- self.precompute_pos[(height, width)] = pos
- return pos
-
- def initialize_weights(self):
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
- # Zero-out adaLN modulation layers in SiT blocks:
- for block in self.blocks:
- nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
- nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-
- for block in self.down_res_blocks:
- if isinstance(block, ResBlock):
- nn.init.constant_(block.conv1.weight, 0)
- nn.init.constant_(block.conv1.bias, 0)
- nn.init.constant_(block.norm1.weight, 0)
- nn.init.constant_(block.norm2.weight, 0)
- nn.init.constant_(block.conv2.weight, 0)
- nn.init.constant_(block.conv2.bias, 0)
-
- for block in self.up_res_blocks:
- if isinstance(block, ResBlock):
- nn.init.constant_(block.conv1.weight, 0)
- nn.init.constant_(block.conv1.bias, 0)
- nn.init.constant_(block.norm1.weight, 0)
- nn.init.constant_(block.norm2.weight, 0)
- nn.init.constant_(block.conv2.weight, 0)
- nn.init.constant_(block.conv2.bias, 0)
-
- def forward(self, x, t, y, s, mask=None):
- B, _, H, W = x.shape
- t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
- s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
-
- residual = []
- for i, block in enumerate(self.down_res_blocks):
- if isinstance(block, nn.Conv2d):
- residual.append(x)
- x = block(x)
- else:
- x = block(x)
-
- pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
- x = x.view(B, self.hidden_size, -1).transpose(1, 2)
- mid_c = torch.nn.functional.silu(t[:, None, :] + s)
- for i in range(self.num_mid_blocks):
- x = self.blocks[i](x, mid_c, pos, None)
- x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size)
-
- residual[0] = 0.0
- for i, block in enumerate(self.up_res_blocks):
- if isinstance(block, nn.ConvTranspose2d):
- x = block(x) + residual.pop()
- else:
- x = block(x)
- return x
-
-
-class FlattenDiT(nn.Module):
- def __init__(
- self,
- encoder:FlattenDiTEncoder,
- decoder:FlattenDiTDecoder,
- ):
- super().__init__()
- self.encoder = encoder
- self.decoder = decoder
- ModelLoader().load(encoder, "encoder.")
- ModelLoader().load(decoder, "decoder.")
- self.encoder = self.encoder.to(torch.bfloat16)
- no_grad(self.encoder)
-
- def forward(self, x, t, y, s=None, classify_layer=None):
- if s is None:
- _, s = self.encoder(x, t, y, classify_layer=classify_layer)
- if classify_layer is not None:
- return None, s
- x = self.decoder(x, t, y, s)
- return x, s
-
-class FlattenDiT_jointtraining(nn.Module):
- def __init__(
- self,
- encoder:FlattenDiTEncoder,
- decoder:FlattenDiTDecoder,
- ):
- super().__init__()
- self.encoder = encoder
- self.decoder = decoder
-
- def forward(self, x, t, y, s=None):
- if s is None:
- _, s = self.encoder(x, t, y)
- x = self.decoder(x, t, y, s)
- return x, s
-
diff --git a/src/models/denoiser/condit_dit.py b/src/models/denoiser/condit_dit.py
deleted file mode 100644
index 48d6b0e..0000000
--- a/src/models/denoiser/condit_dit.py
+++ /dev/null
@@ -1,274 +0,0 @@
-import torch
-import torch.nn as nn
-import math
-
-from numba.cuda.cudadrv.devicearray import lru_cache
-from torch.nn.attention.flex_attention import flex_attention, create_block_mask
-from torch.nn.functional import scaled_dot_product_attention
-
-from torch.nn.attention import SDPBackend, sdpa_kernel
-
-flex_attention = torch.compile(flex_attention)
-
-
-def modulate(x, shift, scale):
- return x * (1 + scale) + shift
-
-class Embed(nn.Module):
- def __init__(
- self,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = False,
- bias: bool = True,
- ):
- super().__init__()
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class TimestepEmbedder(nn.Module):
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10000):
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
- )
- args = t[..., None].float() * freqs[None, ...]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-class LabelEmbedder(nn.Module):
- def __init__(self, num_classes, hidden_size):
- super().__init__()
- self.embedding_table = nn.Embedding(num_classes, hidden_size)
- self.num_classes = num_classes
-
- def forward(self, labels,):
- embeddings = self.embedding_table(labels)
- return embeddings
-
-class FinalLayer(nn.Module):
- def __init__(self, hidden_size, out_channels):
- super().__init__()
- self.norm = nn.LayerNorm(hidden_size , elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 2*hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
- x = self.norm(x)
- x = modulate(x, shift, scale)
- x = self.linear(x)
- return x
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- self.fc1 = nn.Linear(dim, hidden_dim)
- self.fc2 = nn.Linear(hidden_dim, dim)
- self.act = nn.GELU(approximate="tanh")
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.fc2(x)
- return x
-
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale: float=16):
- x_pos = torch.linspace(0, scale, width)
- y_pos = torch.linspace(0, scale, height)
- y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
- x_pos = x_pos.reshape(-1)
- y_pos = y_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
- freqs_cis = torch.cat([x_freqs.sin(), x_freqs.cos(), y_freqs.sin(), y_freqs.cos()], dim=1)
- return freqs_cis
-
-
-class Attention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = True,
- qk_norm: bool = False,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- norm_layer: nn.Module = nn.LayerNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
- # import pdb; pdb.set_trace()
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
- q = self.q_norm(q).to(q.dtype)
- k = self.k_norm(k).to(k.dtype)
- q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
- k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
- v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
- # x = flex_attention(q, k, v, block_mask=mask)
- # with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
- x = scaled_dot_product_attention(q, k, v, mask)
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-
-class DiTBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.norm1 = nn.LayerNorm(hidden_size , elementwise_affine=False, eps=1e-6)
- self.attn = Attention(hidden_size, num_heads=groups, qkv_bias=True, qk_norm=False)
- self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask)
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-class ConDiT(nn.Module):
- def __init__(
- self,
- in_channels=4,
- out_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- num_cond_blocks=4,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- deep_supervision=0,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.deep_supervision = deep_supervision
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.patch_size = patch_size
- self.x_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True)
- self.s_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
- self.final_layer = FinalLayer(hidden_size, out_channels * patch_size ** 2)
- self.num_cond_blocks = num_cond_blocks
-
-
- self.weight_path = weight_path
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([
- DiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
- ])
- self.initialize_weights()
-
- @lru_cache
- def fetch_pos(self, height, width, device):
- pos = precompute_freqs_cis_2d(self.hidden_size, height//self.patch_size, width//self.patch_size).to(device)[None, ...]
- return pos
-
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.x_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.x_embedder.proj.bias, 0)
-
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.s_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.s_embedder.proj.bias, 0)
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- # Zero-out output layers:
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
- nn.init.constant_(self.final_layer.linear.weight, 0)
- nn.init.constant_(self.final_layer.linear.bias, 0)
-
- def forward(self, x, t, y, s=None):
- B, _, H, W = x.shape
- pos = self.fetch_pos(H, W, x.device)
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
-
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- y = self.y_embedder(y).view(B, 1, self.hidden_size)
-
- if s is None:
- # semantic encoder
- s = self.s_embedder(x) + pos
- c = nn.functional.silu(t + y)
- for i in range(self.num_cond_blocks):
- s = self.blocks[i](s, c, pos)
- s = nn.functional.silu(t + s)
-
- x = self.x_embedder(x)
- for i in range(self.num_cond_blocks, self.num_blocks):
- x = self.blocks[i](x, s, pos)
- x = self.final_layer(x, s)
- x = torch.nn.functional.fold(x.transpose(1, 2), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
- return x, s
diff --git a/src/models/denoiser/flatten_condit_dit_fixt.py b/src/models/denoiser/decouple_improved_dit.py
similarity index 95%
rename from src/models/denoiser/flatten_condit_dit_fixt.py
rename to src/models/denoiser/decouple_improved_dit.py
index 15557f3..f20115c 100644
--- a/src/models/denoiser/flatten_condit_dit_fixt.py
+++ b/src/models/denoiser/decouple_improved_dit.py
@@ -194,7 +194,7 @@ class RAttention(nn.Module):
-class FlattenDiTBlock(nn.Module):
+class DDTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
@@ -213,14 +213,14 @@ class FlattenDiTBlock(nn.Module):
return x
-class FlattenConDiT(nn.Module):
+class DDT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
- num_cond_blocks=4,
+ num_encoder_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
@@ -236,7 +236,7 @@ class FlattenConDiT(nn.Module):
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
- self.num_cond_blocks = num_cond_blocks
+ self.num_encoder_blocks = num_encoder_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
@@ -249,7 +249,7 @@ class FlattenConDiT(nn.Module):
self.load_ema = load_ema
self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ DDTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
@@ -280,11 +280,6 @@ class FlattenConDiT(nn.Module):
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
- # # Zero-out adaLN modulation layers in SiT blocks:
- # for block in self.blocks:
- # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
- # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
@@ -301,12 +296,12 @@ class FlattenConDiT(nn.Module):
c = nn.functional.silu(t + y)
if s is None:
s = self.s_embedder(x)
- for i in range(self.num_cond_blocks):
+ for i in range(self.num_encoder_blocks):
s = self.blocks[i](s, c, pos, mask)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
- for i in range(self.num_cond_blocks, self.num_blocks):
+ for i in range(self.num_encoder_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
diff --git a/src/models/denoiser/flatten_condit_catdit_fixt.py b/src/models/denoiser/flatten_condit_catdit_fixt.py
deleted file mode 100644
index 22a0fd5..0000000
--- a/src/models/denoiser/flatten_condit_catdit_fixt.py
+++ /dev/null
@@ -1,314 +0,0 @@
-import functools
-from typing import Tuple
-import torch
-import torch.nn as nn
-import math
-
-from torch.nn.init import zeros_
-from torch.nn.modules.module import T
-
-# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
-from torch.nn.functional import scaled_dot_product_attention
-
-
-def modulate(x, shift, scale):
- return x * (1 + scale) + shift
-
-class Embed(nn.Module):
- def __init__(
- self,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = None,
- bias: bool = True,
- ):
- super().__init__()
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class TimestepEmbedder(nn.Module):
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10):
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
- )
- args = t[..., None].float() * freqs[None, ...]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-class LabelEmbedder(nn.Module):
- def __init__(self, num_classes, hidden_size):
- super().__init__()
- self.embedding_table = nn.Embedding(num_classes, hidden_size)
- self.num_classes = num_classes
-
- def forward(self, labels,):
- embeddings = self.embedding_table(labels)
- return embeddings
-
-class FinalLayer(nn.Module):
- def __init__(self, hidden_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 2*hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- hidden_dim = int(2 * hidden_dim / 3)
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
- def forward(self, x):
- x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
- return x
-
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
- # assert H * H == end
- # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
- x_pos = torch.linspace(0, scale, width)
- y_pos = torch.linspace(0, scale, height)
- y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
- y_pos = y_pos.reshape(-1)
- x_pos = x_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
- x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
- y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
- freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
- freqs_cis = freqs_cis.reshape(height*width, -1)
- return freqs_cis
-
-
-def apply_rotary_emb(
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- freqs_cis = freqs_cis[None, :, None, :]
- # xq : B N H Hc
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
-class RAttention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = True,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- norm_layer: nn.Module = RMSNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
- q = self.q_norm(q)
- k = self.k_norm(k)
- q, k = apply_rotary_emb(q, k, freqs_cis=pos)
- q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
- k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
- v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
-
- x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-
-class FlattenDiTBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-class FlattenConDiT(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- num_cond_blocks=4,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- deep_supervision=0,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.deep_supervision = deep_supervision
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.num_cond_blocks = num_cond_blocks
- self.patch_size = patch_size
- self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True)
- self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
-
- self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
-
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
- ])
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)].to(device)
- else:
- pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- self.precompute_pos[(height, width)] = pos
- return pos
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.x_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.x_embedder.proj.bias, 0)
-
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.s_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.s_embedder.proj.bias, 0)
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- # # Zero-out adaLN modulation layers in SiT blocks:
- # for block in self.blocks:
- # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
- # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-
- # Zero-out output layers:
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
- nn.init.constant_(self.final_layer.linear.weight, 0)
- nn.init.constant_(self.final_layer.linear.bias, 0)
-
-
- def forward(self, x, t, y, s=None, mask=None):
- B, _, H, W = x.shape
- pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- y = self.y_embedder(y).view(B, 1, self.hidden_size)
- c = nn.functional.silu(t + y)
- if s is None:
- s = self.s_embedder(x)
- for i in range(self.num_cond_blocks):
- s = self.blocks[i](s, c, pos, mask)
- # s = nn.functional.silu(t + s)
- s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6)
- x = torch.cat((x, s), dim=-1)
- x = self.x_embedder(x)
- for i in range(self.num_cond_blocks, self.num_blocks):
- x = self.blocks[i](x, c, pos, None)
- x = self.final_layer(x, c)
- x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
- return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_condit_conv_fixt.py b/src/models/denoiser/flatten_condit_conv_fixt.py
deleted file mode 100644
index 219db4c..0000000
--- a/src/models/denoiser/flatten_condit_conv_fixt.py
+++ /dev/null
@@ -1,340 +0,0 @@
-import functools
-from typing import Tuple
-import torch
-import torch.nn as nn
-import math
-
-from torch.nn.init import zeros_
-from torch.nn.modules.module import T
-
-# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
-from torch.nn.functional import scaled_dot_product_attention
-
-
-def modulate(x, shift, scale):
- return x * (1 + scale) + shift
-
-class Embed(nn.Module):
- def __init__(
- self,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = None,
- bias: bool = True,
- ):
- super().__init__()
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class TimestepEmbedder(nn.Module):
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10):
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
- )
- args = t[..., None].float() * freqs[None, ...]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-class LabelEmbedder(nn.Module):
- def __init__(self, num_classes, hidden_size):
- super().__init__()
- self.embedding_table = nn.Embedding(num_classes, hidden_size)
- self.num_classes = num_classes
-
- def forward(self, labels,):
- embeddings = self.embedding_table(labels)
- return embeddings
-
-class FinalLayer(nn.Module):
- def __init__(self, hidden_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 2*hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- hidden_dim = int(2 * hidden_dim / 3)
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
- def forward(self, x):
- x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
- return x
-
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
- # assert H * H == end
- # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
- x_pos = torch.linspace(0, scale, width)
- y_pos = torch.linspace(0, scale, height)
- y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
- y_pos = y_pos.reshape(-1)
- x_pos = x_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
- x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
- y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
- freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
- freqs_cis = freqs_cis.reshape(height*width, -1)
- return freqs_cis
-
-
-def apply_rotary_emb(
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- freqs_cis = freqs_cis[None, :, None, :]
- # xq : B N H Hc
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
-class RAttention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = True,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- norm_layer: nn.Module = RMSNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
- q = self.q_norm(q)
- k = self.k_norm(k)
- q, k = apply_rotary_emb(q, k, freqs_cis=pos)
- q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
- k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
- v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
-
- x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-
-class FlattenDiTBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-class FlattenConvBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, kernel_size=3):
- super().__init__()
- self.hidden_size = hidden_size
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = nn.Conv2d(hidden_size, hidden_size, groups=groups, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- attn_x = modulate(self.norm1(x), shift_msa, scale_msa)
- attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous()
- attn_x = self.attn(attn_x)
- attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2)
- x = x + gate_msa * attn_x
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-class FlattenConDiT(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- num_cond_blocks=4,
- patch_size=2,
- kernel_size=3,
- num_classes=1000,
- learn_sigma=True,
- deep_supervision=0,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.deep_supervision = deep_supervision
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.num_cond_blocks = num_cond_blocks
- self.patch_size = patch_size
- self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
-
- self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
-
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([])
- for i in range(self.num_cond_blocks):
- self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
- for i in range(self.num_blocks-self.num_cond_blocks):
- self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups, kernel_size))
-
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)].to(device)
- else:
- pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- self.precompute_pos[(height, width)] = pos
- return pos
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.x_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.x_embedder.proj.bias, 0)
-
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.s_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.s_embedder.proj.bias, 0)
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- # # Zero-out adaLN modulation layers in SiT blocks:
- # for block in self.blocks:
- # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
- # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-
- # Zero-out output layers:
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
- nn.init.constant_(self.final_layer.linear.weight, 0)
- nn.init.constant_(self.final_layer.linear.bias, 0)
-
-
- def forward(self, x, t, y, s=None):
- B, _, H, W = x.shape
- pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- y = self.y_embedder(y).view(B, 1, self.hidden_size)
- c = nn.functional.silu(t + y)
- if s is None:
- s = self.s_embedder(x)
- for i in range(self.num_cond_blocks):
- s = self.blocks[i](s, c, pos, None)
- s = nn.functional.silu(t + s)
-
- x = self.x_embedder(x)
- for i in range(self.num_cond_blocks, self.num_blocks):
- x = self.blocks[i](x, s, pos, None)
- x = self.final_layer(x, s)
- x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
- return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_condit_convnext_fixt.py b/src/models/denoiser/flatten_condit_convnext_fixt.py
deleted file mode 100644
index cf9c214..0000000
--- a/src/models/denoiser/flatten_condit_convnext_fixt.py
+++ /dev/null
@@ -1,339 +0,0 @@
-import functools
-from typing import Tuple
-import torch
-import torch.nn as nn
-import math
-
-from torch.nn.init import zeros_
-from torch.nn.modules.module import T
-
-# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
-from torch.nn.functional import scaled_dot_product_attention
-
-
-def modulate(x, shift, scale):
- return x * (1 + scale) + shift
-
-class Embed(nn.Module):
- def __init__(
- self,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = None,
- bias: bool = True,
- ):
- super().__init__()
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class TimestepEmbedder(nn.Module):
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10):
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
- )
- args = t[..., None].float() * freqs[None, ...]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-class LabelEmbedder(nn.Module):
- def __init__(self, num_classes, hidden_size):
- super().__init__()
- self.embedding_table = nn.Embedding(num_classes, hidden_size)
- self.num_classes = num_classes
-
- def forward(self, labels,):
- embeddings = self.embedding_table(labels)
- return embeddings
-
-class FinalLayer(nn.Module):
- def __init__(self, hidden_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 2*hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- hidden_dim = int(2 * hidden_dim / 3)
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
- def forward(self, x):
- x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
- return x
-
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
- # assert H * H == end
- # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
- x_pos = torch.linspace(0, scale, width)
- y_pos = torch.linspace(0, scale, height)
- y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
- y_pos = y_pos.reshape(-1)
- x_pos = x_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
- x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
- y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
- freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
- freqs_cis = freqs_cis.reshape(height*width, -1)
- return freqs_cis
-
-
-def apply_rotary_emb(
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- freqs_cis = freqs_cis[None, :, None, :]
- # xq : B N H Hc
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
-class RAttention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = True,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- norm_layer: nn.Module = RMSNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
- q = self.q_norm(q)
- k = self.k_norm(k)
- q, k = apply_rotary_emb(q, k, freqs_cis=pos)
- q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
- k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
- v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
-
- x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-
-class FlattenDiTBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-class FlattenConvBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.hidden_size = hidden_size
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = nn.Conv2d(hidden_size, hidden_size, groups=hidden_size, kernel_size=7, stride=1, padding=3)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- attn_x = modulate(self.norm1(x), shift_msa, scale_msa)
- attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous()
- attn_x = self.attn(attn_x)
- attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2)
- x = x + gate_msa * attn_x
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-class FlattenConDiT(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- num_cond_blocks=4,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- deep_supervision=0,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.deep_supervision = deep_supervision
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.num_cond_blocks = num_cond_blocks
- self.patch_size = patch_size
- self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
-
- self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
-
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([])
- for i in range(self.num_cond_blocks):
- self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
- for i in range(self.num_blocks-self.num_cond_blocks):
- self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups))
-
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)].to(device)
- else:
- pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- self.precompute_pos[(height, width)] = pos
- return pos
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.x_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.x_embedder.proj.bias, 0)
-
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.s_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.s_embedder.proj.bias, 0)
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- # # Zero-out adaLN modulation layers in SiT blocks:
- # for block in self.blocks:
- # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
- # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-
- # Zero-out output layers:
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
- nn.init.constant_(self.final_layer.linear.weight, 0)
- nn.init.constant_(self.final_layer.linear.bias, 0)
-
-
- def forward(self, x, t, y, s=None):
- B, _, H, W = x.shape
- pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- y = self.y_embedder(y).view(B, 1, self.hidden_size)
- c = nn.functional.silu(t + y)
- if s is None:
- s = self.s_embedder(x)
- for i in range(self.num_cond_blocks):
- s = self.blocks[i](s, c, pos, None)
- s = nn.functional.silu(t + s)
-
- x = self.x_embedder(x)
- for i in range(self.num_cond_blocks, self.num_blocks):
- x = self.blocks[i](x, s, pos, None)
- x = self.final_layer(x, s)
- x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
- return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_condit_dit_norm_fixt.py b/src/models/denoiser/flatten_condit_dit_norm_fixt.py
deleted file mode 100644
index 28034e3..0000000
--- a/src/models/denoiser/flatten_condit_dit_norm_fixt.py
+++ /dev/null
@@ -1,314 +0,0 @@
-import functools
-from typing import Tuple
-import torch
-import torch.nn as nn
-import math
-
-from torch.nn.init import zeros_
-from torch.nn.modules.module import T
-
-# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
-from torch.nn.functional import scaled_dot_product_attention
-
-
-def modulate(x, shift, scale):
- return x * (1 + scale) + shift
-
-class Embed(nn.Module):
- def __init__(
- self,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = None,
- bias: bool = True,
- ):
- super().__init__()
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class TimestepEmbedder(nn.Module):
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10):
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
- )
- args = t[..., None].float() * freqs[None, ...]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-class LabelEmbedder(nn.Module):
- def __init__(self, num_classes, hidden_size):
- super().__init__()
- self.embedding_table = nn.Embedding(num_classes, hidden_size)
- self.num_classes = num_classes
-
- def forward(self, labels,):
- embeddings = self.embedding_table(labels)
- return embeddings
-
-class FinalLayer(nn.Module):
- def __init__(self, hidden_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 2*hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- hidden_dim = int(2 * hidden_dim / 3)
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
- def forward(self, x):
- x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
- return x
-
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
- # assert H * H == end
- # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
- x_pos = torch.linspace(0, scale, width)
- y_pos = torch.linspace(0, scale, height)
- y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
- y_pos = y_pos.reshape(-1)
- x_pos = x_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
- x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
- y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
- freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
- freqs_cis = freqs_cis.reshape(height*width, -1)
- return freqs_cis
-
-
-def apply_rotary_emb(
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- freqs_cis = freqs_cis[None, :, None, :]
- # xq : B N H Hc
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
-class RAttention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = True,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- norm_layer: nn.Module = RMSNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
- q = self.q_norm(q)
- k = self.k_norm(k)
- q, k = apply_rotary_emb(q, k, freqs_cis=pos)
- q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
- k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
- v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
-
- x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-
-class FlattenDiTBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-class FlattenConDiT(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- num_cond_blocks=4,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- deep_supervision=0,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.deep_supervision = deep_supervision
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.num_cond_blocks = num_cond_blocks
- self.patch_size = patch_size
- self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
-
- self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
-
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
- ])
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)].to(device)
- else:
- pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- self.precompute_pos[(height, width)] = pos
- return pos
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.x_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.x_embedder.proj.bias, 0)
-
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.s_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.s_embedder.proj.bias, 0)
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- # # Zero-out adaLN modulation layers in SiT blocks:
- # for block in self.blocks:
- # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
- # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-
- # Zero-out output layers:
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
- nn.init.constant_(self.final_layer.linear.weight, 0)
- nn.init.constant_(self.final_layer.linear.bias, 0)
-
-
- def forward(self, x, t, y, s=None, mask=None):
- B, _, H, W = x.shape
- pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- y = self.y_embedder(y).view(B, 1, self.hidden_size)
- c = nn.functional.silu(t + y)
- if s is None:
- s = self.s_embedder(x)
- for i in range(self.num_cond_blocks):
- s = self.blocks[i](s, c, pos, mask)
- s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6)
- s = nn.functional.silu(t + s)
-
- x = self.x_embedder(x)
- for i in range(self.num_cond_blocks, self.num_blocks):
- x = self.blocks[i](x, s, pos, None)
- x = self.final_layer(x, s)
- x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
- return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py b/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py
deleted file mode 100644
index 9a5e4fd..0000000
--- a/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py
+++ /dev/null
@@ -1,429 +0,0 @@
-import functools
-from typing import Tuple
-import torch
-import torch.nn as nn
-import math
-
-from torch.nn.init import zeros_
-from torch.nn.modules.module import T
-
-# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
-from torch.nn.functional import scaled_dot_product_attention
-from src.utils.model_loader import ModelLoader
-from src.utils.no_grad import no_grad
-
-def modulate(x, shift, scale):
- return x * (1 + scale) + shift
-
-class Embed(nn.Module):
- def __init__(
- self,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = None,
- bias: bool = True,
- ):
- super().__init__()
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class TimestepEmbedder(nn.Module):
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10):
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
- )
- args = t[..., None].float() * freqs[None, ...]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-class LabelEmbedder(nn.Module):
- def __init__(self, num_classes, hidden_size):
- super().__init__()
- self.embedding_table = nn.Embedding(num_classes, hidden_size)
- self.num_classes = num_classes
-
- def forward(self, labels,):
- embeddings = self.embedding_table(labels)
- return embeddings
-
-class FinalLayer(nn.Module):
- def __init__(self, hidden_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 2*hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- hidden_dim = int(2 * hidden_dim / 3)
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
- def forward(self, x):
- x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
- return x
-
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
- # assert H * H == end
- # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
- x_pos = torch.linspace(0, scale, width)
- y_pos = torch.linspace(0, scale, height)
- y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
- y_pos = y_pos.reshape(-1)
- x_pos = x_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
- x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
- y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
- freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
- freqs_cis = freqs_cis.reshape(height*width, -1)
- return freqs_cis
-
-
-def apply_rotary_emb(
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- freqs_cis = freqs_cis[None, :, None, :]
- # xq : B N H Hc
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
-class RAttention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = True,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- norm_layer: nn.Module = RMSNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
- q = self.q_norm(q)
- k = self.k_norm(k)
- q, k = apply_rotary_emb(q, k, freqs_cis=pos)
- q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
- k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
- v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
-
- x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-
-class FlattenDiTBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-class FlattenDiTEncoder(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.patch_size = patch_size
- self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
- ])
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)]
- else:
- pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
- self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
- return (pos_rope, pos_ape)
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.s_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.s_embedder.proj.bias, 0)
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- def forward(self, x, t, y, mask=None):
- B, _, H, W = x.shape
- pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- y = self.y_embedder(y).view(B, 1, self.hidden_size)
- c = nn.functional.silu(t + y)
- s = self.s_embedder(x)
- # s = s + pos_ape.to(s.dtype)
- for i in range(self.num_blocks):
- s = self.blocks[i](s, c, pos_rope, mask)
- return None, s
-
-
-class FlattenDiTDecoder(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.patch_size = patch_size
- self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
-
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
- ])
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)]
- else:
- pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- self.precompute_pos[(height, width)] = pos
- return pos
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.x_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.x_embedder.proj.bias, 0)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- # Zero-out output layers:
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
- nn.init.constant_(self.final_layer.linear.weight, 0)
- nn.init.constant_(self.final_layer.linear.bias, 0)
-
- def forward(self, x, t, y, s, mask=None):
- B, _, H, W = x.shape
- pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- # s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
- s = torch.nn.functional.silu(t + s)
- x = self.x_embedder(x)
- for i in range(self.num_blocks):
- x = self.blocks[i](x, s, pos, None)
- x = self.final_layer(x, s)
- x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size,
- stride=self.patch_size)
- return x
-
-
-
-class FlattenDiT(nn.Module):
- def __init__(
- self,
- encoder:FlattenDiTEncoder,
- decoder:FlattenDiTDecoder,
- joint_training=False,
- ):
- super().__init__()
- self.encoder = encoder
- self.decoder = decoder
-
- ModelLoader().load(encoder)
- if not joint_training:
- self.encoder = self.encoder.to(torch.bfloat16)
- no_grad(self.encoder)
-
- self.joint_training = joint_training
-
- def forward(self, x, t, y, s=None):
- if s is None:
- _, s = self.encoder(x, t, y)
- x = self.decoder(x, t, y, s)
- return x, s
-
-class FlattenDiTScalingEncoder(nn.Module):
- def __init__(
- self,
- encoder:FlattenDiTEncoder,
- decoder:FlattenDiTDecoder,
- ):
- super().__init__()
- self.encoder = encoder
- self.decoder = decoder
- no_grad(self.decoder)
-
- if self.encoder.weight_path:
- weight = torch.load(self.encoder.weight_path, map_location=torch.device('cpu'))
- if self.encoder.load_ema:
- prefix = "ema_denoiser."
- else:
- prefix = "denoiser."
- for k, v in self.encoder.state_dict().items():
- try:
- v.copy_(weight["state_dict"][prefix+k])
- except:
- print(f"Failed to copy {prefix+k} to denoiser weight")
-
- if self.decoder.weight_path:
- weight = torch.load(self.decoder.weight_path, map_location=torch.device('cpu'))
- if self.decoder.load_ema:
- prefix = "ema_denoiser."
- else:
- prefix = "denoiser."
- for k, v in self.decoder.state_dict().items():
- if "blocks." in k:
- blockid = int(k.split("blocks.")[-1][0])
- k = k.replace(f"blocks.{blockid}", f"blocks.{int(blockid)+8}")
- try:
- v.copy_(weight["state_dict"][prefix+k])
- except:
- print(f"Failed to copy {prefix+k} to denoiser weight")
- self.decoder = decoder.to(torch.bfloat16)
-
- def forward(self, x, t, y, s=None):
- if s is None:
- _, s = self.encoder(x, t, y)
- x = self.decoder(x, t, y, s)
- return x, s
-
diff --git a/src/models/denoiser/flatten_condit_mlp_fixt.py b/src/models/denoiser/flatten_condit_mlp_fixt.py
deleted file mode 100644
index 40735e4..0000000
--- a/src/models/denoiser/flatten_condit_mlp_fixt.py
+++ /dev/null
@@ -1,334 +0,0 @@
-import functools
-from typing import Tuple
-import torch
-import torch.nn as nn
-import math
-
-from torch.nn.init import zeros_
-from torch.nn.modules.module import T
-
-# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
-from torch.nn.functional import scaled_dot_product_attention
-
-
-def modulate(x, shift, scale):
- return x * (1 + scale) + shift
-
-class Embed(nn.Module):
- def __init__(
- self,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = None,
- bias: bool = True,
- ):
- super().__init__()
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class TimestepEmbedder(nn.Module):
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10):
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
- )
- args = t[..., None].float() * freqs[None, ...]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-class LabelEmbedder(nn.Module):
- def __init__(self, num_classes, hidden_size):
- super().__init__()
- self.embedding_table = nn.Embedding(num_classes, hidden_size)
- self.num_classes = num_classes
-
- def forward(self, labels,):
- embeddings = self.embedding_table(labels)
- return embeddings
-
-class FinalLayer(nn.Module):
- def __init__(self, hidden_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 2*hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- hidden_dim = int(2 * hidden_dim / 3)
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
- def forward(self, x):
- x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
- return x
-
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
- # assert H * H == end
- # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
- x_pos = torch.linspace(0, scale, width)
- y_pos = torch.linspace(0, scale, height)
- y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
- y_pos = y_pos.reshape(-1)
- x_pos = x_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
- x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
- y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
- freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
- freqs_cis = freqs_cis.reshape(height*width, -1)
- return freqs_cis
-
-
-def apply_rotary_emb(
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- freqs_cis = freqs_cis[None, :, None, :]
- # xq : B N H Hc
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
-class RAttention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = True,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- norm_layer: nn.Module = RMSNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
- q = self.q_norm(q)
- k = self.k_norm(k)
- q, k = apply_rotary_emb(q, k, freqs_cis=pos)
- q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
- k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
- v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
-
- x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-
-class FlattenDiTBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-class FlattenMLPBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = FeedForward(hidden_size, mlp_hidden_dim)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-class FlattenConDiT(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- num_cond_blocks=4,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- deep_supervision=0,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.deep_supervision = deep_supervision
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.num_cond_blocks = num_cond_blocks
- self.patch_size = patch_size
- self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
-
- self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
-
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([])
- for i in range(self.num_cond_blocks):
- self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
- for i in range(self.num_blocks-self.num_cond_blocks):
- self.blocks.append(FlattenMLPBlock(self.hidden_size, self.num_groups))
-
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)].to(device)
- else:
- pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- self.precompute_pos[(height, width)] = pos
- return pos
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.x_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.x_embedder.proj.bias, 0)
-
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.s_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.s_embedder.proj.bias, 0)
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- # # Zero-out adaLN modulation layers in SiT blocks:
- # for block in self.blocks:
- # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
- # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-
- # Zero-out output layers:
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
- nn.init.constant_(self.final_layer.linear.weight, 0)
- nn.init.constant_(self.final_layer.linear.bias, 0)
-
-
- def forward(self, x, t, y, s=None):
- B, _, H, W = x.shape
- pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- y = self.y_embedder(y).view(B, 1, self.hidden_size)
- c = nn.functional.silu(t + y)
- if s is None:
- s = self.s_embedder(x)
- for i in range(self.num_cond_blocks):
- s = self.blocks[i](s, c, pos, None)
- s = nn.functional.silu(t + s)
-
- x = self.x_embedder(x)
- for i in range(self.num_cond_blocks, self.num_blocks):
- x = self.blocks[i](x, s, pos, None)
- x = self.final_layer(x, s)
- x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
- return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py b/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py
deleted file mode 100644
index bcf3315..0000000
--- a/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py
+++ /dev/null
@@ -1,321 +0,0 @@
-import functools
-from typing import Tuple
-import torch
-import torch.nn as nn
-import math
-
-from torch.nn.init import zeros_
-from torch.nn.modules.module import T
-
-# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
-from torch.nn.functional import scaled_dot_product_attention
-
-
-def modulate(x, shift, scale):
- return x * (1 + scale) + shift
-
-class Embed(nn.Module):
- def __init__(
- self,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = None,
- bias: bool = True,
- ):
- super().__init__()
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class TimestepEmbedder(nn.Module):
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10):
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
- )
- args = t[..., None].float() * freqs[None, ...]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-class LabelEmbedder(nn.Module):
- def __init__(self, num_classes, hidden_size):
- super().__init__()
- self.embedding_table = nn.Embedding(num_classes, hidden_size)
- self.num_classes = num_classes
-
- def forward(self, labels,):
- embeddings = self.embedding_table(labels)
- return embeddings
-
-class FinalLayer(nn.Module):
- def __init__(self, hidden_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 2*hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- hidden_dim = int(2 * hidden_dim / 3)
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
- def forward(self, x):
- x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
- return x
-
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
- # assert H * H == end
- # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
- x_pos = torch.linspace(0, scale, width)
- y_pos = torch.linspace(0, scale, height)
- y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
- y_pos = y_pos.reshape(-1)
- x_pos = x_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
- x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
- y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
- freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
- freqs_cis = freqs_cis.reshape(height*width, -1)
- return freqs_cis
-
-
-def apply_rotary_emb(
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- freqs_cis = freqs_cis[None, :, None, :]
- # xq : B N H Hc
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
-class RAttention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = True,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- norm_layer: nn.Module = RMSNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
- q = self.q_norm(q)
- k = self.k_norm(k)
- q, k = apply_rotary_emb(q, k, freqs_cis=pos)
- q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
- k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
- v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
-
- x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-
-class FlattenDiTBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-class FlattenConDiT(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- num_cond_blocks=4,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- deep_supervision=0,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.deep_supervision = deep_supervision
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.num_cond_blocks = num_cond_blocks
- self.patch_size = patch_size
- self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.s_embedder = Embed(in_channels*patch_size**4, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
-
- self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
-
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
- ])
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)].to(device)
- else:
- pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- self.precompute_pos[(height, width)] = pos
- return pos
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.x_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.x_embedder.proj.bias, 0)
-
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.s_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.s_embedder.proj.bias, 0)
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- # # Zero-out adaLN modulation layers in SiT blocks:
- # for block in self.blocks:
- # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
- # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-
- # Zero-out output layers:
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
- nn.init.constant_(self.final_layer.linear.weight, 0)
- nn.init.constant_(self.final_layer.linear.bias, 0)
-
-
- def forward(self, x, t, y, s=None, mask=None):
- B, _, H, W = x.shape
- pos_x = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
-
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- y = self.y_embedder(y).view(B, 1, self.hidden_size)
- c = nn.functional.silu(t + y)
- if s is None:
- pos_s = self.fetch_pos(H//self.patch_size//2, W//self.patch_size//2, x.device)
- s = torch.nn.functional.unfold(x, kernel_size=self.patch_size*2, stride=self.patch_size*2).transpose(1, 2)
- s = self.s_embedder(s)
- for i in range(self.num_cond_blocks):
- s = self.blocks[i](s, c, pos_s, mask)
- s = s.view(B, H//self.patch_size//2, W//self.patch_size//2, self.hidden_size)
- s = torch.permute(s, (0, 3, 1, 2))
- s = torch.nn.functional.interpolate(s, scale_factor=2, mode='bilinear', align_corners=False)
- s = torch.permute(s, (0, 2, 3, 1))
- s = s.view(B, -1, self.hidden_size)
- s = nn.functional.silu(t + s)
-
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- x = self.x_embedder(x)
- for i in range(self.num_cond_blocks, self.num_blocks):
- x = self.blocks[i](x, s, pos_x, None)
- x = self.final_layer(x, s)
- x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
- return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_dit_fixt_xvout.py b/src/models/denoiser/flatten_dit_fixt_xvout.py
deleted file mode 100644
index 4df3393..0000000
--- a/src/models/denoiser/flatten_dit_fixt_xvout.py
+++ /dev/null
@@ -1,311 +0,0 @@
-import functools
-from typing import Tuple
-import torch
-import torch.nn as nn
-import math
-
-from torch.nn.init import zeros_
-from torch.nn.modules.module import T
-
-# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
-from torch.nn.functional import scaled_dot_product_attention
-
-
-def modulate(x, shift, scale):
- return x * (1 + scale) + shift
-
-class Embed(nn.Module):
- def __init__(
- self,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = None,
- bias: bool = True,
- ):
- super().__init__()
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class TimestepEmbedder(nn.Module):
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10):
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
- )
- args = t[..., None].float() * freqs[None, ...]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-class LabelEmbedder(nn.Module):
- def __init__(self, num_classes, hidden_size):
- super().__init__()
- self.embedding_table = nn.Embedding(num_classes, hidden_size)
- self.num_classes = num_classes
-
- def forward(self, labels,):
- embeddings = self.embedding_table(labels)
- return embeddings
-
-class FinalLayer(nn.Module):
- def __init__(self, hidden_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 2*hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return (self.weight * hidden_states).to(input_dtype)
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- hidden_dim = int(2 * hidden_dim / 3)
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
- def forward(self, x):
- x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
- return x
-
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
- # assert H * H == end
- # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
- x_pos = torch.linspace(0, scale, width)
- y_pos = torch.linspace(0, scale, height)
- y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
- y_pos = y_pos.reshape(-1)
- x_pos = x_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
- x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
- y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
- freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
- freqs_cis = freqs_cis.reshape(height*width, -1)
- return freqs_cis
-
-
-def apply_rotary_emb(
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- freqs_cis = freqs_cis[None, :, None, :]
- # xq : B N H Hc
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
-class RAttention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = True,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- norm_layer: nn.Module = RMSNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
- q = self.q_norm(q)
- k = self.k_norm(k)
- q, k = apply_rotary_emb(q, k, freqs_cis=pos)
- q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
- k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
- v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
-
- x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-
-class FlattenDiTBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-class FlattenDiT(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- deep_supervision=0,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.deep_supervision = deep_supervision
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.patch_size = patch_size
- self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
-
- self.final_layer = FinalLayer(hidden_size, 2*in_channels*patch_size**2)
-
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
- ])
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device, dtype):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)].to(device, dtype)
- else:
- pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
- self.precompute_pos[(height, width)] = pos
- return pos
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.x_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.x_embedder.proj.bias, 0)
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- # # Zero-out adaLN modulation layers in SiT blocks:
- # for block in self.blocks:
- # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
- # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-
- # Zero-out output layers:
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
- nn.init.constant_(self.final_layer.linear.weight, 0)
- nn.init.constant_(self.final_layer.linear.bias, 0)
-
- def forward(self, x, t, y, masks=None):
- if masks is None:
- masks = [None, ]*self.num_blocks
- if isinstance(masks, torch.Tensor):
- masks = masks.unbind(0)
- if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks:
- masks = masks + [None]*(self.num_blocks-len(masks))
-
- B, _, H, W = x.shape
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- x = self.x_embedder(x)
- pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
- B, L, C = x.shape
- t = self.t_embedder(t.view(-1)).view(B, -1, C)
- y = self.y_embedder(y).view(B, 1, C)
- condition = nn.functional.silu(t + y)
- for i, block in enumerate(self.blocks):
- x = block(x, condition, pos, masks[i])
- x = self.final_layer(x, condition)
- x0, v = x.chunk(2, dim=-1)
- x0 = torch.nn.functional.fold(x0.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
- v = torch.nn.functional.fold(v.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
- if self.training:
- return v, x0
- else:
- return v
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py b/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py
deleted file mode 100644
index 4e570b0..0000000
--- a/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py
+++ /dev/null
@@ -1,308 +0,0 @@
-import functools
-from typing import Tuple
-import torch
-import torch.nn as nn
-import math
-
-from torch.nn.init import zeros_
-from torch.nn.modules.module import T
-
-# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
-from torch.nn.functional import scaled_dot_product_attention
-
-
-def modulate(x, shift, scale):
- return x * (1 + scale) + shift
-
-class Embed(nn.Module):
- def __init__(
- self,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = None,
- bias: bool = True,
- ):
- super().__init__()
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class TimestepEmbedder(nn.Module):
-
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
-
- @staticmethod
- def timestep_embedding(t, dim, max_period=10):
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
- )
- args = t[..., None].float() * freqs[None, ...]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-class LabelEmbedder(nn.Module):
- def __init__(self, num_classes, hidden_size):
- super().__init__()
- self.embedding_table = nn.Embedding(num_classes, hidden_size)
- self.num_classes = num_classes
-
- def forward(self, labels,):
- embeddings = self.embedding_table(labels)
- return embeddings
-
-class FinalLayer(nn.Module):
- def __init__(self, hidden_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 2*hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- hidden_dim = int(2 * hidden_dim / 3)
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
- def forward(self, x):
- x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
- return x
-
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
- # assert H * H == end
- # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
- x_pos = torch.linspace(0, scale, width)
- y_pos = torch.linspace(0, scale, height)
- y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
- y_pos = y_pos.reshape(-1)
- x_pos = x_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
- x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
- y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
- freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
- freqs_cis = freqs_cis.reshape(height*width, -1)
- return freqs_cis
-
-
-def apply_rotary_emb(
- xq: torch.Tensor,
- xk: torch.Tensor,
- freqs_cis: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- freqs_cis = freqs_cis[None, :, None, :]
- # xq : B N H Hc
- xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
- xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
- xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
- xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
-
-class RAttention(nn.Module):
- def __init__(
- self,
- dim: int,
- num_heads: int = 8,
- qkv_bias: bool = False,
- qk_norm: bool = True,
- attn_drop: float = 0.,
- proj_drop: float = 0.,
- norm_layer: nn.Module = RMSNorm,
- ) -> None:
- super().__init__()
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-
- self.dim = dim
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
- self.scale = self.head_dim ** -0.5
-
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
-
- def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
- B, N, C = x.shape
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
- q = self.q_norm(q)
- k = self.k_norm(k)
- q, k = apply_rotary_emb(q, k, freqs_cis=pos)
- q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
- k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
- v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
-
- x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
- x = x.transpose(1, 2).reshape(B, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
-
-
-class FlattenDiTBlock(nn.Module):
- def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c, pos, mask=None):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
- x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
- x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-class FlattenConDiT(nn.Module):
- def __init__(
- self,
- in_channels=4,
- num_groups=12,
- hidden_size=1152,
- num_blocks=18,
- num_cond_blocks=4,
- patch_size=2,
- num_classes=1000,
- learn_sigma=True,
- deep_supervision=0,
- weight_path=None,
- load_ema=False,
- ):
- super().__init__()
- self.deep_supervision = deep_supervision
- self.learn_sigma = learn_sigma
- self.in_channels = in_channels
- self.out_channels = in_channels
- self.hidden_size = hidden_size
- self.num_groups = num_groups
- self.num_blocks = num_blocks
- self.num_cond_blocks = num_cond_blocks
- self.patch_size = patch_size
- self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
- self.t_embedder = TimestepEmbedder(hidden_size)
- self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
-
- self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
-
- self.weight_path = weight_path
-
- self.load_ema = load_ema
- self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
- ])
- self.initialize_weights()
- self.precompute_pos = dict()
-
- def fetch_pos(self, height, width, device):
- if (height, width) in self.precompute_pos:
- return self.precompute_pos[(height, width)].to(device)
- else:
- pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
- self.precompute_pos[(height, width)] = pos
- return pos
-
- def initialize_weights(self):
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
- w = self.x_embedder.proj.weight.data
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
- nn.init.constant_(self.x_embedder.proj.bias, 0)
-
-
- # Initialize label embedding table:
- nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
-
- # Initialize timestep embedding MLP:
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
-
- # # Zero-out adaLN modulation layers in SiT blocks:
- # for block in self.blocks:
- # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
- # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-
- # Zero-out output layers:
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
- nn.init.constant_(self.final_layer.linear.weight, 0)
- nn.init.constant_(self.final_layer.linear.bias, 0)
-
-
- def forward(self, x, t, y, s=None, mask=None):
- B, _, H, W = x.shape
- pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
- x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
- t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
- y = self.y_embedder(y).view(B, 1, self.hidden_size)
- c = nn.functional.silu(t + y)
- if s is None:
- s = self.x_embedder(x)
- for i in range(self.num_cond_blocks):
- s = self.blocks[i](s, c, pos, mask)
- s = nn.functional.silu(t + s)
-
- x = self.x_embedder(x)
- for i in range(self.num_cond_blocks, self.num_blocks):
- x = self.blocks[i](x, s, pos, None)
- x = self.final_layer(x, s)
- x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
- return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flowdcn.py b/src/models/denoiser/flowdcn.py
deleted file mode 100644
index 92e2237..0000000
--- a/src/models/denoiser/flowdcn.py
+++ /dev/null
@@ -1,160 +0,0 @@
-import torch
-import torch.nn as nn
-import math
-
-from torch.nn.init import zeros_
-from src.models.denoiser.base_model import BaseModel
-from src.ops.triton_kernels.function import DCNFunction
-
-def modulate(x, shift, scale):
- return x * (1 + scale[:, None, None]) + shift[:, None, None]
-
-class PatchEmbed(nn.Module):
- def __init__(
- self,
- patch_size: int = 16,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer = None,
- bias: bool = True,
- ):
- super().__init__()
- self.patch_size = patch_size
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- def forward(self, x):
- x = self.proj(x)
- x = self.norm(x)
- return x
-
-class RMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-class FeedForward(nn.Module):
- def __init__(
- self,
- dim: int,
- hidden_dim: int,
- ):
- super().__init__()
- hidden_dim = int(2 * hidden_dim / 3)
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
- def forward(self, x):
- b, h, w, c = x.shape
- x = x.view(b, h*w, c)
- x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
- x = x.view(b, h, w, c)
- return x
-
-
-class MultiScaleDCN(nn.Module):
- def __init__(self, in_channels, groups, channels, kernels, deformable_biass=True):
- super().__init__()
- self.in_channels = in_channels
- self.groups = groups
- self.channels = channels
- self.kernels = kernels
- self.v = nn.Linear(in_channels, groups * channels, bias=True)
- self.qk_deformables = nn.Linear(in_channels, groups * kernels * 2, bias=True)
- self.qk_scales = nn.Linear(in_channels, groups * kernels, bias=False)
- self.qk_weights = nn.Linear(in_channels, groups*kernels, bias=True)
- self.out = nn.Linear(groups * channels, in_channels)
- self.deformables_prior = nn.Parameter(torch.randn((1, 1, 1, 1, kernels, 2)), requires_grad=False)
- self.deformables_scale = nn.Parameter(torch.ones((1, 1, 1, groups, 1, 1)), requires_grad=True)
- self.max_scale = 6
- self._init_weights()
- def _init_weights(self):
- zeros_(self.qk_deformables.weight.data)
- zeros_(self.qk_scales.weight.data)
- zeros_(self.qk_deformables.bias.data)
- zeros_(self.qk_weights.weight.data)
- zeros_(self.v.bias.data)
- zeros_(self.out.bias.data)
- num_prior = int(self.kernels ** 0.5)
- dx = torch.linspace(-1, 1, num_prior, device="cuda")
- dy = torch.linspace(-1, 1, num_prior, device="cuda")
- dxy = torch.meshgrid([dx, dy], indexing="xy")
- dxy = torch.stack(dxy, dim=-1)
- dxy = dxy.view(-1, 2)
- self.deformables_prior.data[..., :num_prior*num_prior, :] = dxy
- for i in range(self.groups):
- scale = (i+1)/self.groups - 0.0001
- inv_scale = math.log((scale)/(1-scale))
- self.deformables_scale.data[..., i, :, :] = inv_scale
- def forward(self, x):
- B, H, W, _ = x.shape
- v = self.v(x).view(B, H, W, self.groups, self.channels)
- deformables = self.qk_deformables(x).view(B, H, W, self.groups, self.kernels, 2)
- scale = self.qk_scales(x).view(B, H, W, self.groups, self.kernels, 1) + self.deformables_scale
- deformables = (deformables + self.deformables_prior ) * scale.sigmoid()*self.max_scale
- weights = self.qk_weights(x).view(B, H, W, self.groups, self.kernels)
- out = DCNFunction.apply(v, deformables, weights)
- out = out.view(B, H, W, -1)
- out = self.out(out)
- return out
-
-class FlowDCNBlock(nn.Module):
- def __init__(self, hidden_size, groups, kernels=9, mlp_ratio=4.0, deformable_biass=True):
- super().__init__()
- self.norm1 = RMSNorm(hidden_size, eps=1e-6)
- self.attn = MultiScaleDCN(hidden_size, groups=groups, channels=hidden_size//groups, kernels=kernels, deformable_biass=deformable_biass)
- self.norm2 = RMSNorm(hidden_size, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
- self.adaLN_modulation = nn.Sequential(
- nn.SiLU(),
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
- x = x + gate_msa[:, None, None] * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
- x = x + gate_mlp[:, None, None] * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
- return x
-
-
-
-
-class FlowDCN(BaseModel):
- def __init__(self, deformable_biass=True, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.blocks = nn.ModuleList([
- FlowDCNBlock(self.hidden_size, self.num_groups, kernels=9, deformable_biass=deformable_biass) for _ in range(self.num_blocks)
- ])
- self.x_embedder = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, bias=True)
- self.initialize_weights()
-
- def forward(self, x, t, y):
- batch_size, _, height, width = x.shape[0]
- x = self.x_embedder(x) # (N, D, h, w)
- x = x.permute(0, 2, 3, 1).reshape(batch_size, height*width//self.patch_size**2, -1)
- t = self.t_embedder(t) # (N, D)
- y = self.y_embedder(y, self.training) # (N, D)
- c = t + y # (N, D)
- B, L, C = x.shape
- x = x.view(B, height//self.patch_size, width//self.patch_size, C)
- for block in self.blocks:
- x = block(x, c) # (N, T, D)
- x = x.view(B, L, C)
- x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
- x = torch.nn.functional.fold(x.transpose(1, 2), (height, width), kernel_size=self.patch_size, stride=self.patch_size)
- if self.learn_sigma:
- x, _ = torch.split(x, self.out_channels // 2, dim=1)
- return x
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_dit_fixt.py b/src/models/denoiser/improved_dit.py
similarity index 96%
rename from src/models/denoiser/flatten_dit_fixt.py
rename to src/models/denoiser/improved_dit.py
index 9412d6e..99e2f5a 100644
--- a/src/models/denoiser/flatten_dit_fixt.py
+++ b/src/models/denoiser/improved_dit.py
@@ -194,7 +194,7 @@ class RAttention(nn.Module):
-class FlattenDiTBlock(nn.Module):
+class DiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
@@ -213,7 +213,7 @@ class FlattenDiTBlock(nn.Module):
return x
-class FlattenDiT(nn.Module):
+class DiT(nn.Module):
def __init__(
self,
in_channels=4,
@@ -246,7 +246,7 @@ class FlattenDiT(nn.Module):
self.load_ema = load_ema
self.blocks = nn.ModuleList([
- FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ DiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
@@ -272,11 +272,6 @@ class FlattenDiT(nn.Module):
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
- # # Zero-out adaLN modulation layers in SiT blocks:
- # for block in self.blocks:
- # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
- # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
-
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
diff --git a/src/ops/cuda_kernels/backward.cu b/src/ops/cuda_kernels/backward.cu
deleted file mode 100644
index 2e85d86..0000000
--- a/src/ops/cuda_kernels/backward.cu
+++ /dev/null
@@ -1,346 +0,0 @@
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-
-namespace cg = cooperative_groups;
-
-template
-__device__ __always_inline int toInt(scalar_t val);
-
-template<>
-__device__ __always_inline int toInt(float val){
- return static_cast(val);
-}
-template<>
-__device__ __always_inline int toInt(half val){
- return __half2int_rz(val);
-}
-
-template
-__device__ __always_inline scalar_t fromInt(int val);
-
-template<>
-__device__ __always_inline float fromInt(int val){
- return static_cast(val);
-}
-
-template<>
-__device__ __always_inline half fromInt(int val){
- return __int2half_rz(val);
-}
-
-template
-__device__ __always_inline scalar_t constVal(float val);
-
-template<>
-__device__ __always_inline float constVal(float val) {
- return (float)val;
-}
-
-template<>
-__device__ __always_inline half constVal(float val) {
- return __float2half(val); // Using float to half conversion
-}
-template<>
-__device__ __always_inline nv_bfloat16 constVal(float val){
- return __float2bfloat16(val);
-}
-
-
-
-
-
-// B, H, W, C, BLOCK_DIM must be multiple of C
-template
-__global__ void dcn_backward_pipeline_kernel(
- const int H,
- const int W,
- const int G,
- const int K,
- const int C,
- scalar_t* ptr_values,
- scalar_t* ptr_deformables,
- scalar_t* ptr_weights,
- scalar_t* ptr_grad_out,
- scalar_t* ptr_grad_values,
- scalar_t* ptr_grad_deformables,
- scalar_t* ptr_grad_weights
-) {
- auto block = cg::this_thread_block();
- auto self_thread = cg::this_thread();
- auto tile_threads = cg::tiled_partition(block);
- int local_thread_id = block.thread_rank();
- int local_tile_id = tile_threads.meta_group_rank();
- int num_local_tiles = tile_threads.meta_group_size();
- int global_tile_id = block.group_index().x*num_local_tiles + local_tile_id;
-
- extern __shared__ int shm[];
- auto GradBuffer = reinterpret_cast(shm);
- scalar_t* Buffer = reinterpret_cast(shm) + num_local_tiles*C;
- if(global_tile_id >= H*W*G) return;
-
- int bid = block.group_index().y;
- int gid = global_tile_id % G;
- int wid = global_tile_id / G % W;
- int hid = global_tile_id / G / W;
- int globale_offset = bid*H*W*G*C + global_tile_id*C;
- cg::memcpy_async(tile_threads, GradBuffer+local_tile_id*C, ptr_grad_out+globale_offset, sizeof(scalar_t)*C);
-
- int shared_offset[pipeline_stages];
- for (int s = 0; s < pipeline_stages; ++s) {
- shared_offset[s] = (s+pipeline_stages*local_thread_id)*(TILE_C*4);
- }
-
- auto pipeline = cuda::make_pipeline();
- const int num_tiles_per_thread = C/TILE_C/TILE_THREADS;
-
- for(int k=0; k(wid);
- y = ptr_deformables[offset*2 + 1] + fromInt(hid);
-// x = fromInt(wid);
-// y = fromInt(hid);
- weight = ptr_weights[offset];
- }
- tile_threads.sync();
- x = tile_threads.shfl(x, 0);
- y = tile_threads.shfl(y, 0);
- weight = tile_threads.shfl(weight, 0);
-
- int floor_x = toInt(x);
- int floor_y = toInt(y);
- int ceil_x = floor_x + 1;
- int ceil_y = floor_y + 1;
-
-
- scalar_t dodx = constVal(0.0f);
- scalar_t dody = constVal(0.0f);
- scalar_t dodw = constVal(0.0f);
-
- int start_c = tile_threads.thread_rank() * (C / TILE_THREADS);
-
- bool tl_flag = (floor_x >=0) and (floor_x =0) and (floor_y=0) and (ceil_x =0) and (floor_y=0) and (floor_x =0) and (ceil_y=0) and (ceil_x =0) and (ceil_y(ceil_x) - x) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
- dodx = dodx + -weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
- dody = dody + -weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
- dodw = dodw + (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1];
- dodx = dodx + -weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j+ 1] * GradBuffer[gbuffer_offset + j + 1];
- dody = dody + -weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1];
- {
- vec2_t vtl_di;
- vtl_di.x = weight* (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j];
- vtl_di.y = weight* (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j + 1];
- atomicAdd((vec2_t*)(ptr_grad_values + tl_global_base + compute_n * TILE_C + j), vtl_di);
- }
- }
-
-
- if(tr_flag){
- // tr
- dodw = dodw + (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
- dodx = dodx + weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
- dody = dody + -weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
- dodw = dodw + (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j+1] * GradBuffer[gbuffer_offset + j+1];
- dodx = dodx + weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+ 1];
- dody = dody + -weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+1];
- {
- vec2_t vtr_di;
- vtr_di.x = weight* (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j];
- vtr_di.y = weight* (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j+1];
- atomicAdd((vec2_t*)(ptr_grad_values + tr_global_base + compute_n * TILE_C + j), vtr_di);
- }
- }
-
- if(bl_flag){
- // bl
- dodw = dodw + (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
- dodx = dodx + -weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
- dody = dody + weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
- dodw = dodw + (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
- dodx = dodx + -weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
- dody = dody + weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
- {
- vec2_t vbl_di;
- vbl_di.x = weight* (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j];
- vbl_di.y = weight* (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j+1];
- atomicAdd((vec2_t*)(ptr_grad_values + bl_global_base + compute_n * TILE_C + j), vbl_di);
- }
- }
-
-
- if(br_flag){
- // tr
- dodw = dodw + (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
- dodx = dodx + weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
- dody = dody + weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
- dodw = dodw + (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
- dodx = dodx + weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
- dody = dody + weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
- {
- vec2_t vbr_di;
- vbr_di.x = weight* (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j];
- vbr_di.y = weight* (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j+1];
- atomicAdd((vec2_t*)(ptr_grad_values + br_global_base + compute_n * TILE_C + j), vbr_di);
- }
- }
- }
- pipeline.consumer_release();
- }
- for (int i = TILE_THREADS>>1; i > 0; i/=2) {
- dodx = dodx + tile_threads.shfl_down(dodx, i);
- dody = dody + tile_threads.shfl_down(dody, i);
- dodw = dodw + tile_threads.shfl_down(dodw, i);
- }
- if (tile_threads.thread_rank() == 0) {
- cuda::memcpy_async(ptr_grad_deformables + offset * 2, &dodx, sizeof(scalar_t), pipeline);
- cuda::memcpy_async(ptr_grad_deformables + offset * 2 + 1, &dody, sizeof(scalar_t), pipeline);
- cuda::memcpy_async(ptr_grad_weights + offset, &dodw, sizeof(scalar_t), pipeline);
- }
- }
-}
-
-
-using namespace torch;
-template
-void backward(const int B,
- const int H,
- const int W,
- const int G,
- const int K,
- const int C,
- torch::Tensor values,
- torch::Tensor deformables,
- torch::Tensor weights,
- torch::Tensor grad_out,
- torch::Tensor grad_values,
- torch::Tensor grad_deformables,
- torch::Tensor grad_weights
-) {
- int num_local_tiles =(THREADS/TILE_THREADS);
- int num_global_tiles = (H*W*G+num_local_tiles-1)/num_local_tiles;
- dim3 launch_threads_per_block(THREADS);
- dim3 launch_blocks(num_global_tiles, B);
-
- int deformable_shm_size = 0;
- int grad_out_shm_size = num_local_tiles*C;
- int pipeline_shm_size = (pipeline_stages*TILE_C*4*THREADS);
-
- int shm_size = deformable_shm_size+grad_out_shm_size+pipeline_shm_size;
-// printf("shm_size: %d\n", shm_size/512);
-// printf("pipeline_size: %d\n", pipeline_shm_size/512);
-// printf("grad_out_size: %d\n", grad_out_shm_size/512);
-
-
- switch (values.type().scalarType()) {
- case at::ScalarType::Half:
- return dcn_backward_pipeline_kernel<<>>(
- H, W, G, K, C,
- reinterpret_cast(values.data_ptr()),
- reinterpret_cast(deformables.data_ptr()),
- reinterpret_cast(weights.data_ptr()),
- reinterpret_cast(grad_out.data_ptr()),
- reinterpret_cast(grad_values.data_ptr()),
- reinterpret_cast(grad_deformables.data_ptr()),
- reinterpret_cast(grad_weights.data_ptr())
- );
-// case at::ScalarType::BFloat16:
-// return dcn_backward_pipeline_kernel<<>>(
-// H, W, G, K, C,
-// reinterpret_cast(values.data_ptr()),
-// reinterpret_cast(deformables.data_ptr()),
-// reinterpret_cast(weights.data_ptr()),
-// reinterpret_cast(grad_out.data_ptr()),
-// reinterpret_cast(grad_values.data_ptr()),
-// reinterpret_cast(grad_deformables.data_ptr()),
-// reinterpret_cast(grad_weights.data_ptr())
-// );
- default:
- printf("running error");
- }
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("backward_p1_c2_tile16_thread128", &backward<1, 2, 16, 128>, "");
- m.def("backward_p2_c2_tile16_thread128", &backward<2, 2, 16, 128>, "");
- m.def("backward_p1_c4_tile16_thread128", &backward<1, 4, 16, 128>, "");
- m.def("backward_p1_c2_tile16_thread256", &backward<1, 2, 16, 256>, "");
- m.def("backward_p2_c2_tile16_thread256", &backward<2, 2, 16, 256>, "");
- m.def("backward_p1_c4_tile16_thread256", &backward<1, 4, 16, 256>, "");
- m.def("backward_p1_c2_tile16_thread384", &backward<1, 2, 16, 384>, "");
- m.def("backward_p2_c2_tile16_thread384", &backward<2, 2, 16, 384>, "");
- m.def("backward_p1_c4_tile16_thread384", &backward<1, 4, 16, 384>, "");
- m.def("backward_p1_c2_tile16_thread512", &backward<1, 2, 16, 512>, "");
- m.def("backward_p2_c2_tile16_thread512", &backward<2, 2, 16, 512>, "");
- m.def("backward_p1_c4_tile16_thread512", &backward<1, 4, 16, 512>, "");
- m.def("backward_p1_c2_tile16_thread768", &backward<1, 2, 16, 768>, "");
- m.def("backward_p2_c2_tile16_thread768", &backward<2, 2, 16, 768>, "");
- m.def("backward_p1_c4_tile16_thread768", &backward<1, 4, 16, 768>, "");
-// m.def("backward_p1_c2_tile16_thread1024", &backward<1, 2, 16, 1024>, "");
-// m.def("backward_p2_c2_tile16_thread1024", &backward<2, 2, 16, 1024>, "");
-// m.def("backward_p1_c4_tile16_thread1024", &backward<1, 4, 16, 1024>, "");
-
- m.def("backward_p1_c2_tile32_thread128", &backward<1, 2, 32, 128>, "");
- m.def("backward_p1_c2_tile32_thread256", &backward<1, 2, 32, 256>, "");
- m.def("backward_p1_c2_tile32_thread384", &backward<1, 2, 32, 384>, "");
- m.def("backward_p1_c2_tile32_thread512", &backward<1, 2, 32, 512>, "");
-}
diff --git a/src/ops/cuda_kernels/bak_forward.cu b/src/ops/cuda_kernels/bak_forward.cu
deleted file mode 100644
index 00569f8..0000000
--- a/src/ops/cuda_kernels/bak_forward.cu
+++ /dev/null
@@ -1,289 +0,0 @@
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-
-template
-__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
- #pragma unroll
- for(int i=0; i
-__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
- #pragma unroll
- for(int i=0; i
-__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){
-#pragma unroll
- for(int i=0; i
-__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){
-#pragma unroll
- for(int i=0; i
-__global__ void dcn_forward_kernel(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
- int work_id = threadIdx.x;
- int bid = blockIdx.z;
- int gid = blockIdx.y;
- int G = gridDim.y;
- int c_blockid = blockIdx.x;
- int work_load = (H*W/blockDim.x);
-
- __shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM
- // __shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM
- math_t register_bufferA[BLOCK_DIM] = {0};
- int base_c = c_blockid*BLOCK_DIM;
-
- int num_transfers = BLOCK_DIM;
-#pragma unroll
- for(int i=0; i(register_bufferA, 1, BLOCK_DIM);
- // loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM);
-#pragma unroll
- for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
- }
- // load top right
- math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
- if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
- loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
- }
-
- // load bottom left
- math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
- if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
- loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
- }
- // load bottom right
- math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
- if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
- loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
- }
-
- }
- // loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM);
- int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
-#pragma unroll
- for(int j=0; j
-__global__ void dcn_forward_kernel_16(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
- int work_id = threadIdx.x;
- int bid = blockIdx.z;
- int gid = blockIdx.y;
- int G = gridDim.y;
- int c_blockid = blockIdx.x;
- int work_load = (H*W/blockDim.x);
-
- __shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM
- __shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM
- math_t register_bufferA[BLOCK_DIM] = {0};
- int base_c = c_blockid*BLOCK_DIM;
-
- int num_transfers = BLOCK_DIM/transfer_length;
-#pragma unroll
- for(int i=0; i(register_bufferA, 1, BLOCK_DIM);
- loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM);
-#pragma unroll
- for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
- }
- // load top right
- math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
- if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
- loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
- }
-
- // load bottom left
- math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
- if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
- loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
- }
- // load bottom right
- math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
- if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
- loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
- }
-
- }
- loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM);
-
- }
-
- __syncthreads();
-
-#pragma unroll
- for(int i=0; i
-void dcn_forward(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
-
- int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
- dim3 launch_threads_per_block(THREADS);
- dim3 launch_blocks(NUM_C_BLOCK, G, B);
-
- switch (value.type().scalarType()) {
- case at::ScalarType::Half:
- return dcn_forward_kernel_16<<>>(
- H, W, C,
- value.data_ptr(),
- deformables.data_ptr(),
- weights.data_ptr(),
- out.data_ptr());
- case at::ScalarType::BFloat16:
- return dcn_forward_kernel_16<<>>(
- H, W, C,
- value.data_ptr(),
- deformables.data_ptr(),
- weights.data_ptr(),
- out.data_ptr());
- case at::ScalarType::Float:
- return dcn_forward_kernel<<>>(
- H, W, C,
- value.data_ptr(),
- deformables.data_ptr(),
- weights.data_ptr(),
- out.data_ptr());
- default:
- printf("running error");
- }
-}
-
-
-// PyBind11 bindings
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward");
-//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward");
-m.def("dcn_forward_l256_c4", &dcn_forward<256, 4, 256>, "CUDA dcn forward");
-m.def("dcn_forward_l256_c8", &dcn_forward<256, 8, 256>, "CUDA dcn forward");
-m.def("dcn_forward_l256_c16", &dcn_forward<256, 16, 256>, "CUDA dcn forward");
-// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward");
-m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward");
-m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward");
-m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward");
-// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward");
-}
diff --git a/src/ops/cuda_kernels/forward.cu b/src/ops/cuda_kernels/forward.cu
deleted file mode 100644
index ac18308..0000000
--- a/src/ops/cuda_kernels/forward.cu
+++ /dev/null
@@ -1,309 +0,0 @@
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-
-template
-__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
- #pragma unroll
- for(int i=0; i
-__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
- #pragma unroll
- for(int i=0; i
-__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){
-#pragma unroll
- for(int i=0; i
-__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){
-#pragma unroll
- for(int i=0; i
-__global__ void dcn_forward_kernel_register(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
- int work_id = threadIdx.x;
- int bid = blockIdx.z;
- int gid = blockIdx.y;
- int G = gridDim.y;
- int c_blockid = blockIdx.x;
- int work_load = (H*W/blockDim.x);
-
- extern __shared__ int shm[];
- math_t* math_buffer = reinterpret_cast(shm);
-
- math_t register_bufferA[BLOCK_DIM] = {0};
- int base_c = c_blockid*BLOCK_DIM;
-
-#pragma unroll
- for(int i=0; i(register_bufferA, 1, BLOCK_DIM);
-#pragma unroll
- for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
- }
- // load top right
- math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
- if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
- loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
- }
-
- // load bottom left
- math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
- if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
- loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
- }
- // load bottom right
- math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
- if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
- loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
- }
-
- }
- int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
-#pragma unroll
- for(int j=0; j
-__global__ void dcn_forward_kernel_pipeline(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
- int work_id = threadIdx.x;
- int bid = blockIdx.z;
- int gid = blockIdx.y;
- int G = gridDim.y;
- int c_blockid = blockIdx.x;
- int work_load = (H*W/blockDim.x);
-
- extern __shared__ int shm[];
- math_t* math_buffer = reinterpret_cast(shm);
- scalar_t* io_buffer = reinterpret_cast(shm) + H*W*BLOCK_DIM*sizeof(math_t)/sizeof(scalar_t);
- math_t register_bufferA[BLOCK_DIM] = {0};
- int base_c = c_blockid*BLOCK_DIM;
-
- int num_transfers = BLOCK_DIM/transfer_length;
-#pragma unroll
- for(int i=0; i(register_bufferA, 1, BLOCK_DIM);
- loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM);
-#pragma unroll
- for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
- }
- // load top right
- math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
- if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
- loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
- }
-
- // load bottom left
- math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
- if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
- loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
- }
- // load bottom right
- math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
- if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
- loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
- }
-
- }
- loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM);
-
- }
-
- __syncthreads();
-
-#pragma unroll
- for(int i=0; i
-void dcn_forward(const int B, const int G, const int C, const int H, const int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
-
- int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
- dim3 launch_threads_per_block(THREADS);
- dim3 launch_blocks(NUM_C_BLOCK, G, B);
- int shm_size = H*W*C_BLOCK_DIM*sizeof(at::Half);
- switch (value.type().scalarType()) {
- case at::ScalarType::Half:
- return dcn_forward_kernel_register<<>>(
- H, W, C,
- value.data_ptr(),
- deformables.data_ptr(),
- weights.data_ptr(),
- out.data_ptr());
- case at::ScalarType::Float:
- return dcn_forward_kernel_register<<>>(
- H, W, C,
- value.data_ptr(),
- deformables.data_ptr(),
- weights.data_ptr(),
- out.data_ptr());
- default:
- printf("running error");
- }
-}
-
-template
-void dcn_forward_pipeline(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
-
- int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
- dim3 launch_threads_per_block(THREADS);
- dim3 launch_blocks(NUM_C_BLOCK, G, B);
- int shm_size = 2*H*W*C_BLOCK_DIM*sizeof(at::Half);
- switch (value.type().scalarType()) {
- case at::ScalarType::Half:
- return dcn_forward_kernel_pipeline<<>>(
- H, W, C,
- value.data_ptr(),
- deformables.data_ptr(),
- weights.data_ptr(),
- out.data_ptr());
- case at::ScalarType::BFloat16:
- return dcn_forward_kernel_pipeline<<>>(
- H, W, C,
- value.data_ptr(),
- deformables.data_ptr(),
- weights.data_ptr(),
- out.data_ptr());
- default:
- printf("running error");
- }
-}
-
-// PyBind11 bindings
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward");
-//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward");
-m.def("dcn_forward_l256_c4", &dcn_forward<4, 256>, "CUDA dcn forward");
-m.def("dcn_forward_l256_c8", &dcn_forward<8, 256>, "CUDA dcn forward");
-m.def("dcn_forward_l256_c16", &dcn_forward<16, 256>, "CUDA dcn forward");
-m.def("dcn_forward_pipeline_l256_c4", &dcn_forward_pipeline<4, 256>, "CUDA dcn forward");
-m.def("dcn_forward_pipeline_l256_c8", &dcn_forward_pipeline<8, 256>, "CUDA dcn forward");
-m.def("dcn_forward_pipeline_l256_c16", &dcn_forward_pipeline<16, 256>, "CUDA dcn forward");
-// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward");
-// m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward");
-// m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward");
-// m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward");
-// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward");
-}
diff --git a/src/ops/cuda_kernels/forward.py b/src/ops/cuda_kernels/forward.py
deleted file mode 100644
index 4ea9c5e..0000000
--- a/src/ops/cuda_kernels/forward.py
+++ /dev/null
@@ -1,95 +0,0 @@
-import triton
-import triton.language as tl
-
-@triton.autotune(
- configs=[
- # triton.Config({'BLOCK_SIZE': 32,}, num_stages=1, num_warps=1),
- triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
- triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1),
- ],
- key=['B', 'H', 'W', 'G', 'C', 'K'],
-)
-@triton.jit
-def forward_kernel(
- B: tl.constexpr,
- H: tl.constexpr, # image_size_h
- W: tl.constexpr, # image_size_w
- G: tl.constexpr, # num_channels_per_group
- C: tl.constexpr, # num_groups
- K: tl.constexpr, # kernel size
- input_ptr, # input features [B, H, W, G, C]
- deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K]
- weights_ptr, # weights [B, H, W, G, K]
- out_ptr, # out [B, H, W, G, C]
- BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
-):
- pid = tl.program_id(0)
- wid = pid % W
- hid = pid // W % H
- gid = pid // (W * H) % G
- bid = pid // (W * H * G)
-
- id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
- common_offset = bid*H*W*G + hid*W*G + wid*G + gid
- batch_base = bid * H * W * G * C
-
- for block_base in tl.static_range(0, C, BLOCK_SIZE):
- buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
- block_offset = tl.arange(0, BLOCK_SIZE) + block_base
- block_mask = (block_offset < C) & id_mask
- for k in tl.static_range(K):
- deformable_offset = (common_offset * K + k) * 2
-
- x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
- y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
-
- floor_x = x.to(tl.int32)
- floor_y = y.to(tl.int32)
- ceil_x = floor_x + 1
- ceil_y = floor_y + 1
-
- # load top left
- tl_weight = (ceil_x - x) * (ceil_y - y)
- tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
- tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
-
- # load top right
- tr_weight = (x - floor_x) * (ceil_y - y)
- tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
- tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
- # load bottom left
- bl_weight = (ceil_x - x) * (y - floor_y)
- bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
- bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
- # load bottom right
- br_weight = (x - floor_x) * (y - floor_y)
- br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
- br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
-
- # load dynamic weight and mask
- weights_offset = common_offset*K + k
- weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
-
-
-
- tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
- tl_block_input = tl_block_input * tl_weight
-
- # load top right
- tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
- tr_block_input = tr_block_input * tr_weight
- # load bottom left
- bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
- bl_block_input = bl_block_input * bl_weight
- # load bottom right
- br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
- br_block_input = br_block_input * br_weight
-
- # sampled
- sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input
-
- weighted_sampled_input = sampled_input * weight
- buffer = buffer + weighted_sampled_input
- # store to out_ptr
- tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask)
-
diff --git a/src/ops/cuda_kernels/function.py b/src/ops/cuda_kernels/function.py
deleted file mode 100644
index 9d4bfad..0000000
--- a/src/ops/cuda_kernels/function.py
+++ /dev/null
@@ -1,126 +0,0 @@
-import time
-import dcn_cuda_backward
-import dcn_cuda_forward
-
-import math
-import torch
-from typing import Any
-from torch.autograd import Function
-from torch.autograd.function import once_differentiable
-from torch.cuda.amp import custom_fwd, custom_bwd
-from .forward import forward_kernel
-
-
-class DCNFunction(Function):
- BP_FUNCS = [
- dcn_cuda_backward.backward_p1_c2_tile16_thread128,
- dcn_cuda_backward.backward_p1_c4_tile16_thread128,
- dcn_cuda_backward.backward_p2_c2_tile16_thread128,
- dcn_cuda_backward.backward_p1_c2_tile16_thread256,
- dcn_cuda_backward.backward_p1_c4_tile16_thread256,
- dcn_cuda_backward.backward_p2_c2_tile16_thread256,
- dcn_cuda_backward.backward_p1_c2_tile16_thread384,
- dcn_cuda_backward.backward_p1_c4_tile16_thread384,
- dcn_cuda_backward.backward_p2_c2_tile16_thread384,
- dcn_cuda_backward.backward_p1_c2_tile16_thread512,
- dcn_cuda_backward.backward_p1_c4_tile16_thread512,
- dcn_cuda_backward.backward_p2_c2_tile16_thread512,
- dcn_cuda_backward.backward_p1_c2_tile16_thread768,
- dcn_cuda_backward.backward_p1_c4_tile16_thread768,
- dcn_cuda_backward.backward_p2_c2_tile16_thread768,
- dcn_cuda_backward.backward_p1_c2_tile32_thread128,
- dcn_cuda_backward.backward_p1_c2_tile32_thread256,
- dcn_cuda_backward.backward_p1_c2_tile32_thread384,
- dcn_cuda_backward.backward_p1_c2_tile32_thread512,
- ]
- FW_FUNCS = [
- dcn_cuda_forward.dcn_forward_l256_c4,
- dcn_cuda_forward.dcn_forward_l256_c8,
- dcn_cuda_forward.dcn_forward_l256_c16,
- ]
- BP_TABLES = dict()
- FW_TABLES = dict()
-
-
- @staticmethod
- @custom_fwd(cast_inputs=torch.float16)
- def forward(ctx: Any, values, deformables, weights) -> Any:
- B, H, W, G, C = values.shape
- func = DCNFunction.find_fw_funcs(values, deformables, weights)
- out = torch.zeros_like(values)
- func(B, G, C, H, W, values, deformables, weights, out)
- return out
-
- @staticmethod
- def find_fw_funcs(values, deformables, weights):
- B, H, W, G, C = values.shape
- B, H, W, G, K = weights.shape
- hash_value = 10000 * B + 100 * H + W + 1000 * G
- if hash_value in DCNFunction.FW_TABLES.keys():
- return DCNFunction.FW_TABLES[hash_value]
- print("missing")
- candicate_func = None
- min_t = 999.0
- outs = torch.zeros_like(values)
- for func in DCNFunction.FW_FUNCS:
- t = []
- for i in range(100):
- torch.cuda.synchronize()
- start_t = time.time()
- func(B, G, C, H, W, values, deformables, weights, outs)
- torch.cuda.synchronize()
- t.append(time.time() - start_t)
- t = t[-50:]
- t = sum(t) / len(t)
- if t < min_t:
- min_t = t
- DCNFunction.FW_TABLES[hash_value] = func
- candicate_func = func
- assert candicate_func is not None
- print(candicate_func)
- return candicate_func
- @staticmethod
- def find_bp_funcs(values, deformables, weights, grad_out):
- B, H, W, G, C = values.shape
- B, H, W, G, K = weights.shape
- hash_value = 10000 * B + 100 * H + W + 1000 * G
- if hash_value in DCNFunction.BP_TABLES.keys():
- return DCNFunction.BP_TABLES[hash_value]
- print("missing")
- candicate_func = None
- min_t = 999.0
- grad_values = torch.zeros_like(values)
- grad_deformables = torch.zeros_like(deformables)
- grad_weights = torch.zeros_like(weights)
- for func in DCNFunction.BP_FUNCS:
- t = []
- for i in range(100):
- torch.cuda.synchronize()
- start_t = time.time()
- func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights)
- torch.cuda.synchronize()
- t.append(time.time() - start_t)
- t = t[-50:]
- t = sum(t) / len(t)
- if t < min_t:
- min_t = t
- DCNFunction.BP_TABLES[hash_value] = func
- candicate_func = func
- assert candicate_func is not None
- print(candicate_func)
- return candicate_func
-
- @staticmethod
- @once_differentiable
- @custom_bwd
- def backward(ctx: Any, *grad_outputs: Any) -> Any:
- grad_out = grad_outputs[0]
- values, deformables, weights = ctx.saved_tensors
- B, H, W, G, C = values.shape
- B, H, W, G, K = weights.shape
- func = DCNFunction.find_bp_funcs(values, deformables, weights, grad_out)
- grad_values = torch.zeros_like(values)
- grad_deformables = torch.zeros_like(deformables)
- grad_weights = torch.zeros_like(weights)
- func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights)
- return grad_values, grad_deformables, grad_weights
\ No newline at end of file
diff --git a/src/ops/cuda_kernels/setup.py b/src/ops/cuda_kernels/setup.py
deleted file mode 100644
index 34079d4..0000000
--- a/src/ops/cuda_kernels/setup.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from setuptools import setup
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension
-
-setup(
- name='dcn_cuda_forward',
- ext_modules=[
- CUDAExtension('dcn_cuda_forward', ['./forward.cu',],
- extra_compile_args={'cxx': [], 'nvcc': [
- "-DCUDA_HAS_FP16=1",
- "-D__CUDA_NO_HALF_OPERATORS__",
- "-D__CUDA_NO_HALF_CONVERSIONS__",
- "-D__CUDA_NO_HALF2_OPERATORS__",
- "--use_fast_math",
- "-O3",
- ]}
- ),
- ],
- cmdclass={
- 'build_ext': BuildExtension
- }
-)
-
-setup(
- name='dcn_cuda_backward',
- ext_modules=[
- CUDAExtension('dcn_cuda_backward', ['./backward.cu',],
- extra_compile_args={'cxx': [], 'nvcc': [
- "-DCUDA_HAS_FP16=1",
- "-D__CUDA_NO_HALF_OPERATORS__",
- "-D__CUDA_NO_HALF_CONVERSIONS__",
- "-D__CUDA_NO_HALF2_OPERATORS__",
- "--use_fast_math",
- "-O3",
- ]}
- ),
- ],
- cmdclass={
- 'build_ext': BuildExtension
- }
-)
-
-
-# setup(
-# name='mycuda',
-# ext_modules=[
-# CUDAExtension('mycuda', ['./backward.cu',],
-# extra_compile_args={'cxx': [], 'nvcc': [
-# "-O3",
-# "-DCUDA_HAS_FP16=1",
-# "-D__CUDA_NO_HALF_OPERATORS__",
-# "-D__CUDA_NO_HALF_CONVERSIONS__",
-# "-D__CUDA_NO_HALF2_OPERATORS__",
-# ]}
-# ),
-# ],
-# cmdclass={
-# 'build_ext': BuildExtension
-# }
-# )
\ No newline at end of file
diff --git a/src/ops/triton_kernels/__init__.py b/src/ops/triton_kernels/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/src/ops/triton_kernels/backward.py b/src/ops/triton_kernels/backward.py
deleted file mode 100644
index e886aa2..0000000
--- a/src/ops/triton_kernels/backward.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import triton
-import triton.language as tl
-
-@triton.autotune(
- configs=[
- triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=1),
- triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
- ],
- key=['B', 'H', 'W', 'G', 'C', 'K'],
-)
-@triton.jit
-def backward_kernel(
- B: tl.constexpr,
- H: tl.constexpr, # image_size_h
- W: tl.constexpr, # image_size_w
- G: tl.constexpr, # num_groups
- C: tl.constexpr, # num_channels_per_group
- K: tl.constexpr, # kernel size
- input_ptr, # input features [B, H, W, G, C]
- deformable_ptr, # deformable offsets [B, H, W, G, K, 2]
- weights_ptr, # weights [B, H, W, G, K]
- grad_ptr, # out [B, H, W, G, C]
- grad_input_ptr, # input features [B, H, W, G, C]
- grad_deformable_ptr, # deformable offsets [B, H, W, G, K, 2]
- grad_weights_ptr, # weights [B, H, W, G, K]
- BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
-):
-
- pid = tl.program_id(0)
- wid = pid % W
- hid = pid // W % H
- gid = pid // (W * H) % G
- bid = pid // (W * H * G)
-
- id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
-
- common_offset = bid*H*W*G + hid*W*G + wid*G + gid
- batch_base = bid * H * W * G * C
- for k in tl.static_range(K):
- # load dynamic weight and mask
- weights_offset = common_offset*K + k
- weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
- dodx = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
- dody = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
- dodw = tl.zeros((1,), dtype=grad_weights_ptr.type.element_ty)
- deformable_offset = (common_offset * K + k)*2
- x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
- y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
- for block_base in tl.static_range(0, C, BLOCK_SIZE):
- block_offset = tl.arange(0, BLOCK_SIZE) + block_base
- block_mask = (block_offset < C) & id_mask
- grad = tl.load(grad_ptr+common_offset*C + block_offset, mask=block_mask, other=0.0)
- dods = weight*grad
-
- floor_x = x.to(tl.int32)
- floor_y = y.to(tl.int32)
- ceil_x = floor_x + 1
- ceil_y = floor_y + 1
-
- # load top left
- tl_weight = (ceil_x - x) * (ceil_y - y)
- tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) + block_offset
- tl_block_mask = ((floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H))
- tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0)
- tl_block_input_dot_grad = tl.sum(tl_block_input*grad, axis=0)
- dodx = dodx + -1 * tl_block_input_dot_grad * (ceil_y - y)
- dody = dody + -1 * tl_block_input_dot_grad * (ceil_x - x)
- dodw = dodw + tl_block_input_dot_grad * tl_weight
-
- dodtl = dods * tl_weight
- tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dodtl)
-
-
- # load top right
- tr_weight = (x - floor_x) * (ceil_y - y)
- tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
- tr_block_mask = ((floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0))
- tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0)
- tr_block_input_dot_grad = tl.sum(tr_block_input*grad, axis=0)
- dodx = dodx + 1 * tr_block_input_dot_grad * (ceil_y - y)
- dody = dody + -1 * tr_block_input_dot_grad * (x - floor_x)
- dodw = dodw + tr_block_input_dot_grad*tr_weight
-
- dodtr = dods * tr_weight
- tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dodtr)
-
-
- # load bottom left
- bl_weight = (ceil_x - x) * (y - floor_y)
- bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) + block_offset
- bl_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0))
- bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0)
- bl_block_input_dot_grad = tl.sum(bl_block_input*grad, axis=0)
- dodx = dodx + -1 * bl_block_input_dot_grad * (y - floor_y)
- dody = dody + 1 * bl_block_input_dot_grad * (ceil_x - x)
- dodw = dodw + bl_block_input_dot_grad*bl_weight
-
- dodbl = dods * bl_weight
- tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dodbl)
-
-
- # load bottom right
- br_weight = (x - floor_x) * (y - floor_y)
- br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
- br_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0))
- br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0)
- br_block_input_dot_grad = tl.sum(br_block_input*grad, axis=0)*br_block_mask
-
- dodx = dodx + 1 * br_block_input_dot_grad * (y - floor_y)
- dody = dody + 1 * br_block_input_dot_grad * (x - floor_x)
- dodw = dodw + br_block_input_dot_grad*br_weight
-
- dodbr = dods * br_weight
- tl.atomic_add(grad_input_ptr + br_block_offset, mask=br_block_mask & block_mask, val=dodbr)
- dodx = dodx * weight
- dody = dody * weight
- tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask)
- tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask)
- tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, mask=id_mask)
-
-
-
-
-
diff --git a/src/ops/triton_kernels/forward.py b/src/ops/triton_kernels/forward.py
deleted file mode 100644
index cf7c243..0000000
--- a/src/ops/triton_kernels/forward.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import triton
-import triton.language as tl
-
-@triton.autotune(
- configs=[
- triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
- # triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1),
- ],
- key=['B', 'H', 'W', 'G', 'C', 'K'],
-)
-@triton.jit
-def forward_kernel(
- B: tl.constexpr,
- H: tl.constexpr, # image_size_h
- W: tl.constexpr, # image_size_w
- G: tl.constexpr, # num_channels_per_group
- C: tl.constexpr, # num_groups
- K: tl.constexpr, # kernel size
- input_ptr, # input features [B, H, W, G, C]
- deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K]
- weights_ptr, # weights [B, H, W, G, K]
- out_ptr, # out [B, H, W, G, C]
- BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
-):
- pid = tl.program_id(0)
- wid = pid % W
- hid = pid // W % H
- gid = pid // (W * H) % G
- bid = pid // (W * H * G)
-
- id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
- common_offset = bid*H*W*G + hid*W*G + wid*G + gid
- batch_base = bid * H * W * G * C
-
- for block_base in tl.static_range(0, C, BLOCK_SIZE):
- buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
- block_offset = tl.arange(0, BLOCK_SIZE) + block_base
- block_mask = (block_offset < C) & id_mask
- for k in tl.static_range(K):
- deformable_offset = (common_offset * K + k) * 2
-
- x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
- y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
-
- floor_x = x.to(tl.int32)
- floor_y = y.to(tl.int32)
- ceil_x = floor_x + 1
- ceil_y = floor_y + 1
-
- # load top left
- tl_weight = (ceil_x - x) * (ceil_y - y)
- tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
- tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
-
- # load top right
- tr_weight = (x - floor_x) * (ceil_y - y)
- tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
- tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
- # load bottom left
- bl_weight = (ceil_x - x) * (y - floor_y)
- bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
- bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
- # load bottom right
- br_weight = (x - floor_x) * (y - floor_y)
- br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
- br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
-
- # load dynamic weight and mask
- weights_offset = common_offset*K + k
- weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
-
-
-
- tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
- tl_block_input = tl_block_input * tl_weight
-
- # load top right
- tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
- tr_block_input = tr_block_input * tr_weight
- # load bottom left
- bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
- bl_block_input = bl_block_input * bl_weight
- # load bottom right
- br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
- br_block_input = br_block_input * br_weight
-
- # sampled
- sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input
-
- weighted_sampled_input = sampled_input * weight
- buffer = buffer + weighted_sampled_input
- # store to out_ptr
- tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask)
-
diff --git a/src/ops/triton_kernels/function.py b/src/ops/triton_kernels/function.py
deleted file mode 100644
index 84987a1..0000000
--- a/src/ops/triton_kernels/function.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import torch
-import triton
-from typing import Any
-from torch.autograd import Function
-from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd
-from .forward import forward_kernel
-from .backward import backward_kernel
-
-
-
-class DCNFunction(Function):
-
- @staticmethod
- @custom_fwd
- def forward(ctx: Any, inputs, deformables, weights) -> Any:
- B, H, W, G, C = inputs.shape
- _, _, _, _, K, _ = deformables.shape
- out = torch.zeros_like(inputs)
- grid = lambda META: (B * H * W * G,)
-
- forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out)
- ctx.save_for_backward(inputs, deformables, weights)
- return out
-
- @staticmethod
- @custom_bwd
- def backward(ctx: Any, *grad_outputs: Any) -> Any:
- grad_output = grad_outputs[0].contiguous()
-
- inputs, deformables, weights = ctx.saved_tensors
- B, H, W, G, C = inputs.shape
- _, _, _, _, K, _ = deformables.shape
-
- grad_inputs = torch.zeros_like(inputs)
- grad_deformables = torch.zeros_like(deformables)
- grad_weights = torch.zeros_like(weights)
- grid = lambda META: (B * H * W * G,)
- backward_kernel[grid](
- B, H, W, G, C, K,
- inputs,
- deformables,
- weights,
- grad_output,
- grad_inputs,
- grad_deformables,
- grad_weights,
- )
- return (grad_inputs, grad_deformables, grad_weights)
\ No newline at end of file