diff --git a/.idea/DDT.iml b/.idea/DDT.iml index d0876a7..8b8c395 100644 --- a/.idea/DDT.iml +++ b/.idea/DDT.iml @@ -5,4 +5,8 @@ + + \ No newline at end of file diff --git a/src/diffusion/flow_matching/training_pyramid.py b/src/diffusion/flow_matching/training_pyramid.py deleted file mode 100644 index be2bd94..0000000 --- a/src/diffusion/flow_matching/training_pyramid.py +++ /dev/null @@ -1,68 +0,0 @@ -import torch -from typing import Callable -from src.diffusion.base.training import * -from src.diffusion.base.scheduling import BaseScheduler - -def inverse_sigma(alpha, sigma): - return 1/sigma**2 -def snr(alpha, sigma): - return alpha/sigma -def minsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, min=threshold) -def maxsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, max=threshold) -def constant(alpha, sigma): - return 1 - -class PyramidTrainer(BaseTrainer): - def __init__( - self, - scheduler: BaseScheduler, - loss_weight_fn:Callable=constant, - lognorm_t=False, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.lognorm_t = lognorm_t - self.scheduler = scheduler - self.loss_weight_fn = loss_weight_fn - - - def _impl_trainstep(self, net, ema_net, raw_images, x, y): - batch_size = x.shape[0] - if self.lognorm_t: - t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() - else: - t = torch.rand(batch_size).to(x.device, x.dtype) - noise = torch.randn_like(x) - alpha = self.scheduler.alpha(t) - dalpha = self.scheduler.dalpha(t) - sigma = self.scheduler.sigma(t) - dsigma = self.scheduler.dsigma(t) - w = self.scheduler.w(t) - - x_t = alpha * x + noise * sigma - v_t = dalpha * x + dsigma * noise - - - output_pyramid = [] - def feature_hook(module, input, output): - output_pyramid.extend(output) - handle = net.decoder.register_forward_hook(feature_hook) - net(x_t, t, y) - handle.remove() - - loss = 0.0 - out_dict = dict() - - cur_v_t = v_t - for i in range(len(output_pyramid)): - cur_out = output_pyramid[i] - loss_i = (cur_v_t - cur_out) ** 2 - loss += loss_i.mean() - out_dict["loss_{}".format(i)] = loss_i.mean() - cur_v_t = torch.nn.functional.interpolate(cur_v_t, scale_factor=0.5, mode='bilinear', align_corners=False) - out_dict["loss"] = loss - return out_dict - diff --git a/src/diffusion/flow_matching/training_repa_mask.py b/src/diffusion/flow_matching/training_repa_mask.py deleted file mode 100644 index f8c4edb..0000000 --- a/src/diffusion/flow_matching/training_repa_mask.py +++ /dev/null @@ -1,152 +0,0 @@ -import torch -import copy -import timm -from torch.nn import Parameter - -from src.utils.no_grad import no_grad -from typing import Callable, Iterator, Tuple -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from torchvision.transforms import Normalize -from src.diffusion.base.training import * -from src.diffusion.base.scheduling import BaseScheduler - -def inverse_sigma(alpha, sigma): - return 1/sigma**2 -def snr(alpha, sigma): - return alpha/sigma -def minsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, min=threshold) -def maxsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, max=threshold) -def constant(alpha, sigma): - return 1 - - -class DINOv2(nn.Module): - def __init__(self, weight_path:str): - super(DINOv2, self).__init__() - self.encoder = torch.hub.load( - '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', - weight_path, - source="local", - skip_validation=True - ) - self.pos_embed = copy.deepcopy(self.encoder.pos_embed) - self.encoder.head = torch.nn.Identity() - 
self.patch_size = self.encoder.patch_embed.patch_size - self.precomputed_pos_embed = dict() - - def fetch_pos(self, h, w): - key = (h, w) - if key in self.precomputed_pos_embed: - return self.precomputed_pos_embed[key] - value = timm.layers.pos_embed.resample_abs_pos_embed( - self.pos_embed.data, [h, w], - ) - self.precomputed_pos_embed[key] = value - return value - - def forward(self, x): - b, c, h, w = x.shape - x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) - x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') - b, c, h, w = x.shape - patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] - pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) - self.encoder.pos_embed.data = pos_embed_data - feature = self.encoder.forward_features(x)['x_norm_patchtokens'] - return feature - - -class REPATrainer(BaseTrainer): - def __init__( - self, - scheduler: BaseScheduler, - loss_weight_fn:Callable=constant, - feat_loss_weight: float=0.5, - lognorm_t=False, - mask_ratio=0.0, - mask_patch_size=2, - encoder_weight_path=None, - align_layer=8, - proj_denoiser_dim=256, - proj_hidden_dim=256, - proj_encoder_dim=256, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.lognorm_t = lognorm_t - self.scheduler = scheduler - self.loss_weight_fn = loss_weight_fn - self.mask_ratio = mask_ratio - self.mask_patch_size = mask_patch_size - self.feat_loss_weight = feat_loss_weight - self.align_layer = align_layer - self.encoder = DINOv2(encoder_weight_path) - no_grad(self.encoder) - - self.proj = nn.Sequential( - nn.Sequential( - nn.Linear(proj_denoiser_dim, proj_hidden_dim), - nn.SiLU(), - nn.Linear(proj_hidden_dim, proj_hidden_dim), - nn.SiLU(), - nn.Linear(proj_hidden_dim, proj_encoder_dim), - ) - ) - - def _impl_trainstep(self, net, ema_net, raw_images, x, y): - batch_size, c, height, width = x.shape - if self.lognorm_t: - base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() - else: - base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) - t = base_t - - noise = torch.randn_like(x) - alpha = self.scheduler.alpha(t) - dalpha = self.scheduler.dalpha(t) - sigma = self.scheduler.sigma(t) - dsigma = self.scheduler.dsigma(t) - - patch_mask = torch.rand((batch_size, 1, height//self.mask_patch_size, width//self.mask_patch_size), device=x.device) - patch_mask = (patch_mask < self.mask_ratio).float() - mask = torch.nn.functional.interpolate(patch_mask, size=(height, width), mode='nearest') - masked_x = x*(1-mask)# + torch.randn_like(x)*(mask) - - x_t = alpha*masked_x + sigma*noise - v_t = dalpha*x + dsigma*noise - - src_feature = [] - def forward_hook(net, input, output): - src_feature.append(output) - handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) - - v_t_out, x0_out = net(x_t, t, y) - src_feature = self.proj(src_feature[0]) - handle.remove() - - with torch.no_grad(): - dst_feature = self.encoder(raw_images) - - cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) - cos_loss = 1 - cos_sim - - weight = self.loss_weight_fn(alpha, sigma) - fm_loss = (1-mask)*weight*(v_t_out - v_t)**2/(1-mask.mean()) - mask_loss = mask*weight*(x0_out - x)**2/(mask.mean()) - out = dict( - fm_loss=fm_loss.mean(), - cos_loss=cos_loss.mean(), - mask_loss=mask_loss.mean(), - loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + mask_loss.mean(), - ) - return out - - def state_dict(self, *args, destination=None, prefix="", keep_vars=False): - 
self.proj.state_dict( - destination=destination, - prefix=prefix + "proj.", - keep_vars=keep_vars) - diff --git a/src/diffusion/stateful_flow_matching/bak/training_adv.py b/src/diffusion/stateful_flow_matching/bak/training_adv.py deleted file mode 100644 index 4792950..0000000 --- a/src/diffusion/stateful_flow_matching/bak/training_adv.py +++ /dev/null @@ -1,122 +0,0 @@ -import torch -import math -from typing import Callable -from src.diffusion.base.training import * -from src.diffusion.base.scheduling import BaseScheduler -from src.utils.no_grad import no_grad -from torchmetrics.image.lpip import _NoTrainLpips - -def inverse_sigma(alpha, sigma): - return 1/sigma**2 -def snr(alpha, sigma): - return alpha/sigma -def minsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, min=threshold) -def maxsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, max=threshold) -def constant(alpha, sigma): - return 1 - -class Discriminator(nn.Module): - def __init__(self, in_channels, hidden_size): - super().__init__() - self.head = nn.Sequential( - nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 - nn.GroupNorm(num_groups=32, num_channels=hidden_size), - nn.SiLU(), - nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 - nn.GroupNorm(num_groups=32, num_channels=hidden_size), - nn.SiLU(), - nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 - nn.GroupNorm(num_groups=32, num_channels=hidden_size), - nn.SiLU(), - nn.AdaptiveAvgPool2d(1), - nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1 - ) - - def forward(self, feature): - B, L, C = feature.shape - H = W = int(math.sqrt(L)) - feature = feature.permute(0, 2, 1) - feature = feature.view(B, C, H, W) - out = self.head(feature).sigmoid().clamp(0.01, 0.99) - return out - -class AdvTrainer(BaseTrainer): - def __init__( - self, - scheduler: BaseScheduler, - loss_weight_fn:Callable=constant, - lognorm_t=False, - adv_weight=1.0, - adv_encoder_layer=4, - adv_in_channels=768, - adv_hidden_size=256, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.lognorm_t = lognorm_t - self.scheduler = scheduler - self.loss_weight_fn = loss_weight_fn - self.adv_weight = adv_weight - self.adv_encoder_layer = adv_encoder_layer - - self.dis_head = Discriminator( - in_channels=adv_in_channels, - hidden_size=adv_hidden_size, - ) - - - def _impl_trainstep(self, net, ema_net, raw_images, x, y): - batch_size = x.shape[0] - if self.lognorm_t: - t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() - else: - t = torch.rand(batch_size).to(x.device, x.dtype) - noise = torch.randn_like(x) - alpha = self.scheduler.alpha(t) - dalpha = self.scheduler.dalpha(t) - sigma = self.scheduler.sigma(t) - dsigma = self.scheduler.dsigma(t) - w = self.scheduler.w(t) - - x_t = alpha * x + noise * sigma - v_t = dalpha * x + dsigma * noise - - adv_feature = [] - def forward_hook(net, input, output): - adv_feature.append(output) - handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook) - - out, _ = net(x_t, t, y) - weight = self.loss_weight_fn(alpha, sigma) - loss = weight*(out - v_t)**2 - - pred_x0 = (x_t + out * sigma) - pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma - real_feature = adv_feature.pop() - net(pred_xt, t, y, classify_layer=self.adv_encoder_layer) - fake_feature = 
adv_feature.pop() - handle.remove() - - - real_score_gan = self.dis_head(real_feature.detach()) - fake_score_gan = self.dis_head(fake_feature.detach()) - fake_score = self.dis_head(fake_feature) - - loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan) - acc_real = (real_score_gan > 0.5).float() - acc_fake = (fake_score_gan < 0.5).float() - loss_adv = -torch.log(fake_score) - loss_adv_hack = torch.log(fake_score_gan) - - out = dict( - adv_loss=loss_adv.mean(), - gan_loss=loss_gan.mean(), - fm_loss=loss.mean(), - loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(), - acc_real=acc_real.mean(), - acc_fake=acc_fake.mean(), - ) - return out diff --git a/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py b/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py deleted file mode 100644 index 2843c04..0000000 --- a/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py +++ /dev/null @@ -1,127 +0,0 @@ -import torch -import math -from typing import Callable -from src.diffusion.base.training import * -from src.diffusion.base.scheduling import BaseScheduler -from src.utils.no_grad import no_grad -from torchmetrics.image.lpip import _NoTrainLpips - -def inverse_sigma(alpha, sigma): - return 1/sigma**2 -def snr(alpha, sigma): - return alpha/sigma -def minsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, min=threshold) -def maxsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, max=threshold) -def constant(alpha, sigma): - return 1 - -class Discriminator(nn.Module): - def __init__(self, in_channels, hidden_size): - super().__init__() - self.head = nn.Sequential( - nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 - nn.GroupNorm(num_groups=32, num_channels=hidden_size), - nn.SiLU(), - nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 - nn.GroupNorm(num_groups=32, num_channels=hidden_size), - nn.SiLU(), - nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 - nn.GroupNorm(num_groups=32, num_channels=hidden_size), - nn.SiLU(), - nn.AdaptiveAvgPool2d(1), - nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1 - ) - - def forward(self, feature): - B, L, C = feature.shape - H = W = int(math.sqrt(L)) - feature = feature.permute(0, 2, 1) - feature = feature.view(B, C, H, W) - out = self.head(feature).sigmoid().clamp(0.01, 0.99) - return out - -class AdvTrainer(BaseTrainer): - def __init__( - self, - scheduler: BaseScheduler, - loss_weight_fn:Callable=constant, - lognorm_t=False, - adv_weight=1.0, - lpips_weight=1.0, - adv_encoder_layer=4, - adv_in_channels=768, - adv_hidden_size=256, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.lognorm_t = lognorm_t - self.scheduler = scheduler - self.loss_weight_fn = loss_weight_fn - self.adv_weight = adv_weight - self.lpips_weight = lpips_weight - self.adv_encoder_layer = adv_encoder_layer - - self.dis_head = Discriminator( - in_channels=adv_in_channels, - hidden_size=adv_hidden_size, - ) - - - def _impl_trainstep(self, net, ema_net, raw_images, x, y): - batch_size = x.shape[0] - if self.lognorm_t: - t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() - else: - t = torch.rand(batch_size).to(x.device, x.dtype) - clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype) - - noise = torch.randn_like(x) - alpha = 
self.scheduler.alpha(t) - dalpha = self.scheduler.dalpha(t) - sigma = self.scheduler.sigma(t) - dsigma = self.scheduler.dsigma(t) - w = self.scheduler.w(t) - - x_t = alpha * x + noise * sigma - v_t = dalpha * x + dsigma * noise - - out, _ = net(x_t, t, y) - pred_x0 = (x_t + out * sigma) - - weight = self.loss_weight_fn(alpha, sigma) - loss = weight*(out - v_t)**2 - with torch.no_grad(): - _, real_features = net(x, clean_t, y, classify_layer=self.adv_encoder_layer) - _, fake_features = net(pred_x0, clean_t, y, classify_layer=self.adv_encoder_layer) - - real_score_gan = self.dis_head(real_features[-1].detach()) - fake_score_gan = self.dis_head(fake_features[-1].detach()) - fake_score = self.dis_head(fake_features[-1]) - - loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan) - acc_real = (real_score_gan > 0.5).float() - acc_fake = (fake_score_gan < 0.5).float() - loss_adv = -torch.log(fake_score) - loss_adv_hack = torch.log(fake_score_gan) - - lpips_loss = [] - for r, f in zip(real_features, fake_features): - r = torch.nn.functional.normalize(r, dim=-1) - f = torch.nn.functional.normalize(f, dim=-1) - lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean()) - lpips_loss = sum(lpips_loss) - - - out = dict( - adv_loss=loss_adv.mean(), - gan_loss=loss_gan.mean(), - lpips_loss=lpips_loss.mean(), - fm_loss=loss.mean(), - loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean() + self.lpips_weight*lpips_loss.mean(), - acc_real=acc_real.mean(), - acc_fake=acc_fake.mean(), - ) - return out diff --git a/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py b/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py deleted file mode 100644 index 849ee4b..0000000 --- a/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py +++ /dev/null @@ -1,159 +0,0 @@ -import random - -import torch -import copy -import timm -from torch.nn import Parameter - -from src.utils.no_grad import no_grad -from typing import Callable, Iterator, Tuple -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from torchvision.transforms import Normalize -from src.diffusion.base.training import * -from src.diffusion.base.scheduling import BaseScheduler - -def inverse_sigma(alpha, sigma): - return 1/sigma**2 -def snr(alpha, sigma): - return alpha/sigma -def minsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, min=threshold) -def maxsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, max=threshold) -def constant(alpha, sigma): - return 1 - - -class DINOv2(nn.Module): - def __init__(self, weight_path:str): - super(DINOv2, self).__init__() - self.encoder = torch.hub.load( - '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', - weight_path, - source="local", - skip_validation=True - ) - self.pos_embed = copy.deepcopy(self.encoder.pos_embed) - self.encoder.head = torch.nn.Identity() - self.patch_size = self.encoder.patch_embed.patch_size - self.precomputed_pos_embed = dict() - - def fetch_pos(self, h, w): - key = (h, w) - if key in self.precomputed_pos_embed: - return self.precomputed_pos_embed[key] - value = timm.layers.pos_embed.resample_abs_pos_embed( - self.pos_embed.data, [h, w], - ) - self.precomputed_pos_embed[key] = value - return value - - def forward(self, x): - b, c, h, w = x.shape - x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) - x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') - b, c, h, w = x.shape - patch_num_h, patch_num_w = 
h//self.patch_size[0], w//self.patch_size[1] - pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) - self.encoder.pos_embed.data = pos_embed_data - feature = self.encoder.forward_features(x)['x_norm_patchtokens'] - return feature - - -class MaskREPATrainer(BaseTrainer): - def __init__( - self, - scheduler: BaseScheduler, - loss_weight_fn:Callable=constant, - feat_loss_weight: float=0.5, - lognorm_t=False, - encoder_weight_path=None, - mask_groups=4, - align_layer=8, - proj_denoiser_dim=256, - proj_hidden_dim=256, - proj_encoder_dim=256, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.lognorm_t = lognorm_t - self.scheduler = scheduler - self.loss_weight_fn = loss_weight_fn - self.feat_loss_weight = feat_loss_weight - self.align_layer = align_layer - self.mask_groups = mask_groups - self.encoder = DINOv2(encoder_weight_path) - no_grad(self.encoder) - - self.proj = nn.Sequential( - nn.Sequential( - nn.Linear(proj_denoiser_dim, proj_hidden_dim), - nn.SiLU(), - nn.Linear(proj_hidden_dim, proj_hidden_dim), - nn.SiLU(), - nn.Linear(proj_hidden_dim, proj_encoder_dim), - ) - ) - - def fetch_mask(self, length=256, groups=4, device=torch.device('cuda')): - mask = torch.zeros(1, length, length, device=device, dtype=torch.bool) - random_seq = torch.randperm(length, device=device) - for i in range(groups): - group_start = (length+groups-1)//groups*i - group_end = (length+groups-1)//groups*(i+1) - group_random_seq = random_seq[group_start:group_end] - y, x = torch.meshgrid(group_random_seq, group_random_seq) - mask[:, y, x] = True - return mask - - - def _impl_trainstep(self, net, ema_net, raw_images, x, y): - batch_size, c, height, width = x.shape - if self.lognorm_t: - base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() - else: - base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) - t = base_t - - noise = torch.randn_like(x) - alpha = self.scheduler.alpha(t) - dalpha = self.scheduler.dalpha(t) - sigma = self.scheduler.sigma(t) - dsigma = self.scheduler.dsigma(t) - - x_t = alpha * x + noise * sigma - v_t = dalpha * x + dsigma * noise - - src_feature = [] - def forward_hook(net, input, output): - src_feature.append(output) - handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) - mask_groups = random.randint(1, self.mask_groups) - mask = self.fetch_mask(length=256, groups=mask_groups, device=x.device) - out, _ = net(x_t, t, y, mask=mask) - src_feature = self.proj(src_feature[0]) - handle.remove() - - with torch.no_grad(): - dst_feature = self.encoder(raw_images) - - cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) - cos_loss = 1 - cos_sim - - weight = self.loss_weight_fn(alpha, sigma) - fm_loss = weight*(out - v_t)**2 - - out = dict( - fm_loss=fm_loss.mean(), - cos_loss=cos_loss.mean(), - loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(), - ) - return out - - def state_dict(self, *args, destination=None, prefix="", keep_vars=False): - self.proj.state_dict( - destination=destination, - prefix=prefix + "proj.", - keep_vars=keep_vars) - diff --git a/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py b/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py deleted file mode 100644 index 229680c..0000000 --- a/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py +++ /dev/null @@ -1,179 +0,0 @@ -import torch -import math -from typing import Callable -from src.diffusion.base.training import * -from src.diffusion.base.scheduling import BaseScheduler 
-from src.utils.no_grad import no_grad -from torchmetrics.image.lpip import _NoTrainLpips - -def inverse_sigma(alpha, sigma): - return 1/sigma**2 -def snr(alpha, sigma): - return alpha/sigma -def minsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, min=threshold) -def maxsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, max=threshold) -def constant(alpha, sigma): - return 1 - - -class TimestepEmbedder(nn.Module): - """ - Embeds scalar timesteps into vector representations. - """ - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10000): - # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=t.device) - args = t[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t, mul=1000): - t_freq = self.timestep_embedding(t * mul, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - - -class BatchNormWithTimeEmbedding(nn.Module): - def __init__(self, num_features): - super().__init__() - # self.bn = nn.BatchNorm2d(num_features, affine=False) - self.bn = nn.GroupNorm(16, num_features, affine=False) - # self.bn = nn.SyncBatchNorm(num_features, affine=False) - self.embedder = TimestepEmbedder(num_features * 2) - # nn.init.zeros_(self.embedder.mlp[-1].weight) - nn.init.trunc_normal_(self.embedder.mlp[-1].weight, std=0.01) - nn.init.zeros_(self.embedder.mlp[-1].bias) - - def forward(self, x, t): - embed = self.embedder(t) - embed = embed[:, :, None, None] - gamma, beta = embed.chunk(2, dim=1) - gamma = 1.0 + gamma - normed = self.bn(x) - out = normed * gamma + beta - return out - -class DisBlock(nn.Module): - def __init__(self, in_channels, hidden_size): - super().__init__() - self.conv = nn.Conv2d( - kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=4, padding=0 - ) - self.norm = BatchNormWithTimeEmbedding(hidden_size) - self.act = nn.SiLU() - def forward(self, x, t): - x = self.conv(x) - x = self.norm(x, t) - x = self.act(x) - return x - - -class Discriminator(nn.Module): - def __init__(self, num_blocks, in_channels, hidden_size): - super().__init__() - self.blocks = nn.ModuleList() - for i in range(num_blocks): - self.blocks.append( - DisBlock( - in_channels=in_channels, - hidden_size=hidden_size, - ) - ) - in_channels = hidden_size - self.classifier = nn.Conv2d( - kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=1 - ) - def forward(self, feature, t): - B, C, H, W = feature.shape - for block in self.blocks: - feature = block(feature, t) - out = self.classifier(feature).view(B, -1) - out = out.sigmoid().clamp(0.01, 0.99) - return out - -class AdvTrainer(BaseTrainer): - def __init__( - self, - scheduler: BaseScheduler, - loss_weight_fn:Callable=constant, - lognorm_t=False, - adv_weight=1.0, - adv_blocks=3, - adv_in_channels=3, - adv_hidden_size=256, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.lognorm_t = lognorm_t - self.scheduler = 
scheduler - self.loss_weight_fn = loss_weight_fn - self.adv_weight = adv_weight - - self.discriminator = Discriminator( - num_blocks=adv_blocks, - in_channels=adv_in_channels*2, - hidden_size=adv_hidden_size, - ) - - - def _impl_trainstep(self, net, ema_net, raw_images, x, y): - batch_size = x.shape[0] - if self.lognorm_t: - t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() - else: - t = torch.rand(batch_size).to(x.device, x.dtype) - noise = torch.randn_like(x) - alpha = self.scheduler.alpha(t) - dalpha = self.scheduler.dalpha(t) - sigma = self.scheduler.sigma(t) - dsigma = self.scheduler.dsigma(t) - w = self.scheduler.w(t) - - x_t = alpha * x + noise * sigma - v_t = dalpha * x + dsigma * noise - - out, _ = net(x_t, t, y) - pred_x0 = x_t + sigma * out - weight = self.loss_weight_fn(alpha, sigma) - loss = weight*(out - v_t)**2 - - real_feature = torch.cat([x_t, x], dim=1) - fake_feature = torch.cat([x_t, pred_x0], dim=1) - - real_score_gan = self.discriminator(real_feature.detach(), t) - fake_score_gan = self.discriminator(fake_feature.detach(), t) - fake_score = self.discriminator(fake_feature, t) - - loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan) - acc_real = (real_score_gan > 0.5).float() - acc_fake = (fake_score_gan < 0.5).float() - loss_adv = -torch.log(fake_score) - loss_adv_hack = torch.log(fake_score_gan) - - out = dict( - adv_loss=loss_adv.mean(), - gan_loss=loss_gan.mean(), - fm_loss=loss.mean(), - loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(), - acc_real=acc_real.mean(), - acc_fake=acc_fake.mean(), - ) - return out diff --git a/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py b/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py deleted file mode 100644 index e84e81f..0000000 --- a/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py +++ /dev/null @@ -1,154 +0,0 @@ -import torch -import copy -import timm -from torch.nn import Parameter - -from src.utils.no_grad import no_grad -from typing import Callable, Iterator, Tuple -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from torchvision.transforms import Normalize -from src.diffusion.base.training import * -from src.diffusion.base.scheduling import BaseScheduler - -def inverse_sigma(alpha, sigma): - return 1/sigma**2 -def snr(alpha, sigma): - return alpha/sigma -def minsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, min=threshold) -def maxsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, max=threshold) -def constant(alpha, sigma): - return 1 - - -class DINOv2(nn.Module): - def __init__(self, weight_path:str): - super(DINOv2, self).__init__() - self.encoder = torch.hub.load( - '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', - weight_path, - source="local", - skip_validation=True - ) - self.pos_embed = copy.deepcopy(self.encoder.pos_embed) - self.encoder.head = torch.nn.Identity() - self.patch_size = self.encoder.patch_embed.patch_size - self.precomputed_pos_embed = dict() - - def fetch_pos(self, h, w): - key = (h, w) - if key in self.precomputed_pos_embed: - return self.precomputed_pos_embed[key] - value = timm.layers.pos_embed.resample_abs_pos_embed( - self.pos_embed.data, [h, w], - ) - self.precomputed_pos_embed[key] = value - return value - - def forward(self, x): - b, c, h, w = x.shape - x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) - x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') - b, c, h, w = 
x.shape - patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] - pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) - self.encoder.pos_embed.data = pos_embed_data - feature = self.encoder.forward_features(x)['x_norm_patchtokens'] - return feature - - -class REPAJiTTrainer(BaseTrainer): - def __init__( - self, - scheduler: BaseScheduler, - loss_weight_fn:Callable=constant, - feat_loss_weight: float=0.5, - lognorm_t=False, - jit_deltas=0.01, - encoder_weight_path=None, - align_layer=8, - proj_denoiser_dim=256, - proj_hidden_dim=256, - proj_encoder_dim=256, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.lognorm_t = lognorm_t - self.scheduler = scheduler - self.loss_weight_fn = loss_weight_fn - self.feat_loss_weight = feat_loss_weight - self.align_layer = align_layer - self.jit_deltas = jit_deltas - self.encoder = DINOv2(encoder_weight_path) - no_grad(self.encoder) - - self.proj = nn.Sequential( - nn.Sequential( - nn.Linear(proj_denoiser_dim, proj_hidden_dim), - nn.SiLU(), - nn.Linear(proj_hidden_dim, proj_hidden_dim), - nn.SiLU(), - nn.Linear(proj_hidden_dim, proj_encoder_dim), - ) - ) - - def _impl_trainstep(self, net, ema_net, raw_images, x, y): - batch_size, c, height, width = x.shape - if self.lognorm_t: - base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() - else: - base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) - t = base_t - - noise = torch.randn_like(x) - - alpha = self.scheduler.alpha(t) - dalpha = self.scheduler.dalpha(t) - sigma = self.scheduler.sigma(t) - dsigma = self.scheduler.dsigma(t) - x_t = alpha * x + noise * sigma - v_t = dalpha * x + dsigma * noise - - t2 = base_t + (torch.rand_like(base_t)-0.5) * self.jit_deltas - t2 = torch.clip(t2, 0, 1) - alpha = self.scheduler.alpha(t2) - dalpha = self.scheduler.dalpha(t2) - sigma = self.scheduler.sigma(t2) - dsigma = self.scheduler.dsigma(t2) - x_t2 = alpha * x + noise * sigma - v_t2 = dalpha * x + dsigma * noise - - src_feature = [] - def forward_hook(net, input, output): - src_feature.append(output) - handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) - - _, s = net(x_t, t, y, only_s=True) - out, _ = net(x_t2, t2, y, s) - src_feature = self.proj(src_feature[0]) - handle.remove() - - with torch.no_grad(): - dst_feature = self.encoder(raw_images) - - cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) - cos_loss = 1 - cos_sim - - weight = self.loss_weight_fn(alpha, sigma) - fm_loss = weight*(out - v_t2)**2 - - out = dict( - fm_loss=fm_loss.mean(), - cos_loss=cos_loss.mean(), - loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(), - ) - return out - - def state_dict(self, *args, destination=None, prefix="", keep_vars=False): - self.proj.state_dict( - destination=destination, - prefix=prefix + "proj.", - keep_vars=keep_vars) - diff --git a/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py b/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py deleted file mode 100644 index d7e741d..0000000 --- a/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py +++ /dev/null @@ -1,90 +0,0 @@ -import torch -import math -from typing import Callable -from src.diffusion.base.training import * -from src.diffusion.base.scheduling import BaseScheduler -from src.utils.no_grad import no_grad -from torchmetrics.image.lpip import _NoTrainLpips - -def inverse_sigma(alpha, sigma): - return 1/sigma**2 -def snr(alpha, sigma): - return alpha/sigma -def 
minsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, min=threshold) -def maxsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, max=threshold) -def constant(alpha, sigma): - return 1 - - -class SelfConsistentTrainer(BaseTrainer): - def __init__( - self, - scheduler: BaseScheduler, - loss_weight_fn:Callable=constant, - lognorm_t=False, - lpips_weight=1.0, - lpips_encoder_layer=4, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.lognorm_t = lognorm_t - self.scheduler = scheduler - self.loss_weight_fn = loss_weight_fn - self.lpips_encoder_layer = lpips_encoder_layer - self.lpips_weight = lpips_weight - - - def _impl_trainstep(self, net, ema_net, raw_images, x, y): - batch_size = x.shape[0] - if self.lognorm_t: - t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() - else: - t = torch.rand(batch_size).to(x.device, x.dtype) - - noise = torch.randn_like(x) - alpha = self.scheduler.alpha(t) - dalpha = self.scheduler.dalpha(t) - sigma = self.scheduler.sigma(t) - dsigma = self.scheduler.dsigma(t) - w = self.scheduler.w(t) - - x_t = alpha * x + noise * sigma - v_t = dalpha * x + dsigma * noise - - real_features = [] - def forward_hook(net, input, output): - real_features.append(output) - handles = [] - for i in range(self.lpips_encoder_layer): - handle = net.encoder.blocks[i].register_forward_hook(forward_hook) - handles.append(handle) - - out, _ = net(x_t, t, y) - - for handle in handles: - handle.remove() - - pred_x0 = (x_t + out * sigma) - pred_xt = alpha * pred_x0 + noise * sigma - weight = self.loss_weight_fn(alpha, sigma) - loss = weight*(out - v_t)**2 - - _, fake_features = net(pred_xt, t, y, classify_layer=self.lpips_encoder_layer) - - lpips_loss = [] - for r, f in zip(real_features, fake_features): - r = torch.nn.functional.normalize(r, dim=-1) - f = torch.nn.functional.normalize(f, dim=-1) - lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean()) - lpips_loss = sum(lpips_loss) - - - out = dict( - lpips_loss=lpips_loss.mean(), - fm_loss=loss.mean(), - loss=loss.mean() + self.lpips_weight*lpips_loss.mean(), - ) - return out diff --git a/src/diffusion/stateful_flow_matching/bak/training_selflpips.py b/src/diffusion/stateful_flow_matching/bak/training_selflpips.py deleted file mode 100644 index 580775b..0000000 --- a/src/diffusion/stateful_flow_matching/bak/training_selflpips.py +++ /dev/null @@ -1,81 +0,0 @@ -import torch -import math -from typing import Callable -from src.diffusion.base.training import * -from src.diffusion.base.scheduling import BaseScheduler -from src.utils.no_grad import no_grad -from torchmetrics.image.lpip import _NoTrainLpips - -def inverse_sigma(alpha, sigma): - return 1/sigma**2 -def snr(alpha, sigma): - return alpha/sigma -def minsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, min=threshold) -def maxsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, max=threshold) -def constant(alpha, sigma): - return 1 - - -class SelfLPIPSTrainer(BaseTrainer): - def __init__( - self, - scheduler: BaseScheduler, - loss_weight_fn:Callable=constant, - lognorm_t=False, - lpips_weight=1.0, - lpips_encoder_layer=4, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.lognorm_t = lognorm_t - self.scheduler = scheduler - self.loss_weight_fn = loss_weight_fn - self.lpips_encoder_layer = lpips_encoder_layer - self.lpips_weight = lpips_weight - - - def _impl_trainstep(self, net, ema_net, raw_images, x, y): - batch_size = x.shape[0] - if self.lognorm_t: - t = torch.randn(batch_size).to(x.device, 
x.dtype).sigmoid() - else: - t = torch.rand(batch_size).to(x.device, x.dtype) - clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype) - - noise = torch.randn_like(x) - alpha = self.scheduler.alpha(t) - dalpha = self.scheduler.dalpha(t) - sigma = self.scheduler.sigma(t) - dsigma = self.scheduler.dsigma(t) - w = self.scheduler.w(t) - - x_t = alpha * x + noise * sigma - v_t = dalpha * x + dsigma * noise - - out, _ = net(x_t, t, y) - pred_x0 = (x_t + out * sigma) - pred_xt = alpha * pred_x0 + noise * sigma - weight = self.loss_weight_fn(alpha, sigma) - loss = weight*(out - v_t)**2 - with torch.no_grad(): - _, real_features = net(x, clean_t, y, classify_layer=self.lpips_encoder_layer) - _, fake_features = net(pred_x0, clean_t, y, classify_layer=self.lpips_encoder_layer) - - - lpips_loss = [] - for r, f in zip(real_features, fake_features): - r = torch.nn.functional.normalize(r, dim=-1) - f = torch.nn.functional.normalize(f, dim=-1) - lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean()) - lpips_loss = sum(lpips_loss) - - - out = dict( - lpips_loss=lpips_loss.mean(), - fm_loss=loss.mean(), - loss=loss.mean() + self.lpips_weight*lpips_loss.mean(), - ) - return out diff --git a/src/diffusion/stateful_flow_matching/cm_sampling.py b/src/diffusion/stateful_flow_matching/cm_sampling.py deleted file mode 100644 index 5254db5..0000000 --- a/src/diffusion/stateful_flow_matching/cm_sampling.py +++ /dev/null @@ -1,78 +0,0 @@ -import torch - -from src.diffusion.base.guidance import * -from src.diffusion.base.scheduling import * -from src.diffusion.base.sampling import * - -from typing import Callable - - -def shift_respace_fn(t, shift=3.0): - return t / (t + (1 - t) * shift) - -def ode_step_fn(x, v, dt, s, w): - return x + v * dt - - -import logging -logger = logging.getLogger(__name__) - -class CMSampler(BaseSampler): - def __init__( - self, - w_scheduler: BaseScheduler = None, - timeshift=1.0, - guidance_interval_min: float = 0.0, - guidance_interval_max: float = 1.0, - state_refresh_rate=1, - last_step=None, - step_fn=None, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.last_step = last_step - self.timeshift = timeshift - self.state_refresh_rate = state_refresh_rate - self.guidance_interval_min = guidance_interval_min - self.guidance_interval_max = guidance_interval_max - - if self.last_step is None or self.num_steps == 1: - self.last_step = 1.0 / self.num_steps - - timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) - timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) - self.timesteps = shift_respace_fn(timesteps, self.timeshift) - - assert self.last_step > 0.0 - assert self.scheduler is not None - - - def _impl_sampling(self, net, noise, condition, uncondition): - """ - sampling process of Euler sampler - - - """ - batch_size = noise.shape[0] - steps = self.timesteps.to(noise.device) - cfg_condition = torch.cat([uncondition, condition], dim=0) - x = noise - state = None - for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): - cfg_t = t_cur.repeat(batch_size*2) - cfg_x = torch.cat([x, x], dim=0) - if i % self.state_refresh_rate == 0: - state = None - out, state = net(cfg_x, cfg_t, cfg_condition, state) - if t_cur > self.guidance_interval_min and t_cur < self.guidance_interval_max: - out = self.guidance_fn(out, self.guidance) - else: - out = self.guidance_fn(out, 1.0) - v = out - - x0 = x + v * (1-t_cur) - alpha_next = self.scheduler.alpha(t_next) - sigma_next = self.scheduler.sigma(t_next) - x = alpha_next * x0 + sigma_next * 
torch.randn_like(x) - # print(alpha_next, sigma_next) - return x \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/sharing_sampling.py b/src/diffusion/stateful_flow_matching/sharing_sampling.py deleted file mode 100644 index f372028..0000000 --- a/src/diffusion/stateful_flow_matching/sharing_sampling.py +++ /dev/null @@ -1,149 +0,0 @@ -import torch - -from src.diffusion.base.guidance import * -from src.diffusion.base.scheduling import * -from src.diffusion.base.sampling import * - -from typing import Callable - - -def shift_respace_fn(t, shift=3.0): - return t / (t + (1 - t) * shift) - -def ode_step_fn(x, v, dt, s, w): - return x + v * dt - - -import logging -logger = logging.getLogger(__name__) - -class EulerSampler(BaseSampler): - def __init__( - self, - w_scheduler: BaseScheduler = None, - timeshift=1.0, - guidance_interval_min: float = 0.0, - guidance_interval_max: float = 1.0, - state_refresh_rate=1, - step_fn: Callable = ode_step_fn, - last_step=None, - last_step_fn: Callable = ode_step_fn, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.step_fn = step_fn - self.last_step = last_step - self.last_step_fn = last_step_fn - self.w_scheduler = w_scheduler - self.timeshift = timeshift - self.state_refresh_rate = state_refresh_rate - self.guidance_interval_min = guidance_interval_min - self.guidance_interval_max = guidance_interval_max - - if self.last_step is None or self.num_steps == 1: - self.last_step = 1.0 / self.num_steps - - timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) - timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) - self.timesteps = shift_respace_fn(timesteps, self.timeshift) - - assert self.last_step > 0.0 - assert self.scheduler is not None - assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] - if self.w_scheduler is not None: - if self.step_fn == ode_step_fn: - logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") - - # init recompute - self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate) - self.recompute_timesteps = list(range(self.num_steps)) - - def sharing_dp(self, net, noise, condition, uncondition): - _, C, H, W = noise.shape - B = 8 - template_noise = torch.randn((B, C, H, W), generator=torch.Generator("cuda").manual_seed(0), device=noise.device) - template_condition = torch.randint(0, 1000, (B,), generator=torch.Generator("cuda").manual_seed(0), device=condition.device) - template_uncondition = torch.full((B, ), 1000, device=condition.device) - _, state_list = self._impl_sampling(net, template_noise, template_condition, template_uncondition) - states = torch.stack(state_list) - N, B, L, C = states.shape - states = states.view(N, B*L, C ) - states = states.permute(1, 0, 2) - states = torch.nn.functional.normalize(states, dim=-1) - with torch.autocast(device_type="cuda", dtype=torch.float64): - sim = torch.bmm(states, states.transpose(1, 2)) - sim = torch.mean(sim, dim=0).cpu() - error_map = (1-sim).tolist() - - # init cum-error - for i in range(1, self.num_steps): - for j in range(0, i): - error_map[i][j] = error_map[i-1][j] + error_map[i][j] - - # init dp and force 0 start - C = [[0.0, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)] - P = [[-1, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)] - for i in range(1, self.num_steps+1): - C[1][i] = error_map[i - 1][0] - P[1][i] = 0 - - # dp state - for step in range(2, self.num_recompute_timesteps+1): - for i in range(step, 
self.num_steps+1): - min_value = 99999 - min_index = -1 - for j in range(step-1, i): - value = C[step-1][j] + error_map[i-1][j] - if value < min_value: - min_value = value - min_index = j - C[step][i] = min_value - P[step][i] = min_index - - # trace back - timesteps = [self.num_steps,] - for i in range(self.num_recompute_timesteps, 0, -1): - idx = timesteps[-1] - timesteps.append(P[i][idx]) - timesteps.reverse() - - print("recompute timesteps solved by DP: ", timesteps) - return timesteps[:-1] - - def _impl_sampling(self, net, noise, condition, uncondition): - """ - sampling process of Euler sampler - - - """ - batch_size = noise.shape[0] - steps = self.timesteps.to(noise.device) - cfg_condition = torch.cat([uncondition, condition], dim=0) - x = noise - state = None - pooled_state_list = [] - for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): - dt = t_next - t_cur - t_cur = t_cur.repeat(batch_size) - cfg_x = torch.cat([x, x], dim=0) - cfg_t = t_cur.repeat(2) - if i in self.recompute_timesteps: - state = None - out, state = net(cfg_x, cfg_t, cfg_condition, state) - if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max: - out = self.guidance_fn(out, self.guidance) - else: - out = self.guidance_fn(out, 1.0) - v = out - if i < self.num_steps -1 : - x = self.step_fn(x, v, dt, s=0.0, w=0.0) - else: - x = self.last_step_fn(x, v, dt, s=0.0, w=0.0) - pooled_state_list.append(state) - return x, pooled_state_list - - def __call__(self, net, noise, condition, uncondition): - if len(self.recompute_timesteps) != self.num_recompute_timesteps: - self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition) - denoised, _ = self._impl_sampling(net, noise, condition, uncondition) - return denoised \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/training_adv.py b/src/diffusion/stateful_flow_matching/training_adv.py deleted file mode 100644 index 4792950..0000000 --- a/src/diffusion/stateful_flow_matching/training_adv.py +++ /dev/null @@ -1,122 +0,0 @@ -import torch -import math -from typing import Callable -from src.diffusion.base.training import * -from src.diffusion.base.scheduling import BaseScheduler -from src.utils.no_grad import no_grad -from torchmetrics.image.lpip import _NoTrainLpips - -def inverse_sigma(alpha, sigma): - return 1/sigma**2 -def snr(alpha, sigma): - return alpha/sigma -def minsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, min=threshold) -def maxsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, max=threshold) -def constant(alpha, sigma): - return 1 - -class Discriminator(nn.Module): - def __init__(self, in_channels, hidden_size): - super().__init__() - self.head = nn.Sequential( - nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 - nn.GroupNorm(num_groups=32, num_channels=hidden_size), - nn.SiLU(), - nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 - nn.GroupNorm(num_groups=32, num_channels=hidden_size), - nn.SiLU(), - nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 - nn.GroupNorm(num_groups=32, num_channels=hidden_size), - nn.SiLU(), - nn.AdaptiveAvgPool2d(1), - nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1 - ) - - def forward(self, feature): - B, L, C = feature.shape - H = W = int(math.sqrt(L)) - feature = feature.permute(0, 
2, 1) - feature = feature.view(B, C, H, W) - out = self.head(feature).sigmoid().clamp(0.01, 0.99) - return out - -class AdvTrainer(BaseTrainer): - def __init__( - self, - scheduler: BaseScheduler, - loss_weight_fn:Callable=constant, - lognorm_t=False, - adv_weight=1.0, - adv_encoder_layer=4, - adv_in_channels=768, - adv_hidden_size=256, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.lognorm_t = lognorm_t - self.scheduler = scheduler - self.loss_weight_fn = loss_weight_fn - self.adv_weight = adv_weight - self.adv_encoder_layer = adv_encoder_layer - - self.dis_head = Discriminator( - in_channels=adv_in_channels, - hidden_size=adv_hidden_size, - ) - - - def _impl_trainstep(self, net, ema_net, raw_images, x, y): - batch_size = x.shape[0] - if self.lognorm_t: - t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() - else: - t = torch.rand(batch_size).to(x.device, x.dtype) - noise = torch.randn_like(x) - alpha = self.scheduler.alpha(t) - dalpha = self.scheduler.dalpha(t) - sigma = self.scheduler.sigma(t) - dsigma = self.scheduler.dsigma(t) - w = self.scheduler.w(t) - - x_t = alpha * x + noise * sigma - v_t = dalpha * x + dsigma * noise - - adv_feature = [] - def forward_hook(net, input, output): - adv_feature.append(output) - handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook) - - out, _ = net(x_t, t, y) - weight = self.loss_weight_fn(alpha, sigma) - loss = weight*(out - v_t)**2 - - pred_x0 = (x_t + out * sigma) - pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma - real_feature = adv_feature.pop() - net(pred_xt, t, y, classify_layer=self.adv_encoder_layer) - fake_feature = adv_feature.pop() - handle.remove() - - - real_score_gan = self.dis_head(real_feature.detach()) - fake_score_gan = self.dis_head(fake_feature.detach()) - fake_score = self.dis_head(fake_feature) - - loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan) - acc_real = (real_score_gan > 0.5).float() - acc_fake = (fake_score_gan < 0.5).float() - loss_adv = -torch.log(fake_score) - loss_adv_hack = torch.log(fake_score_gan) - - out = dict( - adv_loss=loss_adv.mean(), - gan_loss=loss_gan.mean(), - fm_loss=loss.mean(), - loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(), - acc_real=acc_real.mean(), - acc_fake=acc_fake.mean(), - ) - return out diff --git a/src/diffusion/stateful_flow_matching/training_distill_dino.py b/src/diffusion/stateful_flow_matching/training_distill_dino.py deleted file mode 100644 index c6a2937..0000000 --- a/src/diffusion/stateful_flow_matching/training_distill_dino.py +++ /dev/null @@ -1,141 +0,0 @@ -import torch -import copy -import timm -from torch.nn import Parameter - -from src.utils.no_grad import no_grad -from typing import Callable, Iterator, Tuple -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from torchvision.transforms import Normalize -from src.diffusion.base.training import * -from src.diffusion.base.scheduling import BaseScheduler - -def inverse_sigma(alpha, sigma): - return 1/sigma**2 -def snr(alpha, sigma): - return alpha/sigma -def minsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, min=threshold) -def maxsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, max=threshold) -def constant(alpha, sigma): - return 1 - - -class DINOv2(nn.Module): - def __init__(self, weight_path:str): - super(DINOv2, self).__init__() - self.encoder = torch.hub.load( - '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', - 
weight_path, - source="local", - skip_validation=True - ) - self.pos_embed = copy.deepcopy(self.encoder.pos_embed) - self.encoder.head = torch.nn.Identity() - self.patch_size = self.encoder.patch_embed.patch_size - self.precomputed_pos_embed = dict() - - def fetch_pos(self, h, w): - key = (h, w) - if key in self.precomputed_pos_embed: - return self.precomputed_pos_embed[key] - value = timm.layers.pos_embed.resample_abs_pos_embed( - self.pos_embed.data, [h, w], - ) - self.precomputed_pos_embed[key] = value - return value - - def forward(self, x): - b, c, h, w = x.shape - x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) - x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bilinear') - b, c, h, w = x.shape - patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] - pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) - self.encoder.pos_embed.data = pos_embed_data - feature = self.encoder.forward_features(x)['x_norm_patchtokens'] - return feature - - -class DistillDINOTrainer(BaseTrainer): - def __init__( - self, - scheduler: BaseScheduler, - loss_weight_fn:Callable=constant, - feat_loss_weight: float=0.5, - lognorm_t=False, - encoder_weight_path=None, - align_layer=8, - proj_denoiser_dim=256, - proj_hidden_dim=256, - proj_encoder_dim=256, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.lognorm_t = lognorm_t - self.scheduler = scheduler - self.loss_weight_fn = loss_weight_fn - self.feat_loss_weight = feat_loss_weight - self.align_layer = align_layer - self.encoder = DINOv2(encoder_weight_path) - self.proj_encoder_dim = proj_encoder_dim - no_grad(self.encoder) - - self.proj = nn.Sequential( - nn.Sequential( - nn.Linear(proj_denoiser_dim, proj_hidden_dim), - nn.SiLU(), - nn.Linear(proj_hidden_dim, proj_hidden_dim), - nn.SiLU(), - nn.Linear(proj_hidden_dim, proj_encoder_dim), - ) - ) - - def _impl_trainstep(self, net, ema_net, raw_images, x, y): - batch_size, c, height, width = x.shape - if self.lognorm_t: - base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() - else: - base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) - t = base_t - - noise = torch.randn_like(x) - alpha = self.scheduler.alpha(t) - sigma = self.scheduler.sigma(t) - - x_t = alpha * x + noise * sigma - - _, s = net(x_t, t, y) - src_feature = self.proj(s) - - with torch.no_grad(): - dst_feature = self.encoder(raw_images) - - if dst_feature.shape[1] != src_feature.shape[1]: - dst_length = dst_feature.shape[1] - rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5 - dst_height = (dst_length)**0.5 * (height/width)**0.5 - dst_width = (dst_length)**0.5 * (width/height)**0.5 - dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim) - dst_feature = dst_feature.permute(0, 3, 1, 2) - dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False) - dst_feature = dst_feature.permute(0, 2, 3, 1) - dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim) - - cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) - cos_loss = 1 - cos_sim - - out = dict( - cos_loss=cos_loss.mean(), - loss=cos_loss.mean(), - ) - return out - - def state_dict(self, *args, destination=None, prefix="", keep_vars=False): - self.proj.state_dict( - destination=destination, - prefix=prefix + "proj.", - keep_vars=keep_vars) - diff --git a/src/diffusion/stateful_flow_matching/training_lpips.py 
b/src/diffusion/stateful_flow_matching/training_lpips.py deleted file mode 100644 index a3cd2a2..0000000 --- a/src/diffusion/stateful_flow_matching/training_lpips.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch -from typing import Callable -from src.diffusion.base.training import * -from src.diffusion.base.scheduling import BaseScheduler -from src.utils.no_grad import no_grad -from torchmetrics.image.lpip import _NoTrainLpips - -def inverse_sigma(alpha, sigma): - return 1/sigma**2 -def snr(alpha, sigma): - return alpha/sigma -def minsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, min=threshold) -def maxsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, max=threshold) -def constant(alpha, sigma): - return 1 - -class LPIPSTrainer(BaseTrainer): - def __init__( - self, - scheduler: BaseScheduler, - loss_weight_fn:Callable=constant, - lognorm_t=False, - lpips_weight=1.0, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.lognorm_t = lognorm_t - self.scheduler = scheduler - self.loss_weight_fn = loss_weight_fn - self.lpips_weight = lpips_weight - self.lpips = _NoTrainLpips(net="vgg") - self.lpips = self.lpips.to(torch.bfloat16) - # self.lpips = torch.compile(self.lpips) - no_grad(self.lpips) - - def _impl_trainstep(self, net, ema_net, raw_images, x, y): - batch_size = x.shape[0] - if self.lognorm_t: - t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() - else: - t = torch.rand(batch_size).to(x.device, x.dtype) - noise = torch.randn_like(x) - alpha = self.scheduler.alpha(t) - dalpha = self.scheduler.dalpha(t) - sigma = self.scheduler.sigma(t) - dsigma = self.scheduler.dsigma(t) - w = self.scheduler.w(t) - - x_t = alpha * x + noise * sigma - v_t = dalpha * x + dsigma * noise - out, _ = net(x_t, t, y) - weight = self.loss_weight_fn(alpha, sigma) - loss = weight*(out - v_t)**2 - - pred_x0 = (x_t + out*sigma) - target_x0 = x - # fixbug lpips std - lpips = self.lpips(pred_x0*0.5, target_x0*0.5) - - out = dict( - lpips_loss=lpips.mean(), - fm_loss=loss.mean(), - loss=loss.mean() + lpips.mean()*self.lpips_weight, - ) - return out - - def state_dict(self, *args, destination=None, prefix='', keep_vars=False): - return \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py b/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py deleted file mode 100644 index e0233ea..0000000 --- a/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -from typing import Callable -from src.diffusion.base.training import * -from src.diffusion.base.scheduling import BaseScheduler -from src.utils.no_grad import no_grad -from torchmetrics.image.lpip import _NoTrainLpips - -def inverse_sigma(alpha, sigma): - return 1/sigma**2 -def snr(alpha, sigma): - return alpha/sigma -def minsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, min=threshold) -def maxsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, max=threshold) -def constant(alpha, sigma): - return 1 - -class LPIPSTrainer(BaseTrainer): - def __init__( - self, - scheduler: BaseScheduler, - loss_weight_fn:Callable=constant, - lognorm_t=False, - lpips_weight=1.0, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.lognorm_t = False - self.scheduler = scheduler - self.loss_weight_fn = loss_weight_fn - self.lpips_weight = lpips_weight - self.lpips = _NoTrainLpips(net="vgg") - self.lpips = self.lpips.to(torch.bfloat16) - # self.lpips = torch.compile(self.lpips) - 
no_grad(self.lpips) - - def _impl_trainstep(self, net, ema_net, raw_images, x, y): - batch_size = x.shape[0] - if self.lognorm_t: - t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() - else: - t = torch.rand(batch_size).to(x.device, x.dtype) - noise = torch.randn_like(x) - alpha = self.scheduler.alpha(t) - dalpha = self.scheduler.dalpha(t) - sigma = self.scheduler.sigma(t) - dsigma = self.scheduler.dsigma(t) - w = self.scheduler.w(t) - - x_t = alpha * x + noise * sigma - v_t = dalpha * x + dsigma * noise - out, _ = net(x_t, t, y) - - fm_weight = t*(1-t)**2/0.25 - lpips_weight = t - - loss = (out - v_t)**2 * fm_weight[:, None, None, None] - - pred_x0 = (x_t + out*sigma) - target_x0 = x - # fixbug lpips std - lpips = self.lpips(pred_x0*0.5, target_x0*0.5)*lpips_weight[:, None, None, None] - - out = dict( - lpips_loss=lpips.mean(), - fm_loss=loss.mean(), - loss=loss.mean() + lpips.mean()*self.lpips_weight, - ) - return out - - def state_dict(self, *args, destination=None, prefix='', keep_vars=False): - return \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/training_repa_lpips.py b/src/diffusion/stateful_flow_matching/training_repa_lpips.py deleted file mode 100644 index 5a11207..0000000 --- a/src/diffusion/stateful_flow_matching/training_repa_lpips.py +++ /dev/null @@ -1,170 +0,0 @@ -import torch -import copy -import timm -from torch.nn import Parameter - -from src.utils.no_grad import no_grad -from typing import Callable, Iterator, Tuple -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from torchvision.transforms import Normalize -from src.diffusion.base.training import * -from src.diffusion.base.scheduling import BaseScheduler -from torchmetrics.image.lpip import _NoTrainLpips - -def inverse_sigma(alpha, sigma): - return 1/sigma**2 -def snr(alpha, sigma): - return alpha/sigma -def minsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, min=threshold) -def maxsnr(alpha, sigma, threshold=5): - return torch.clip(alpha/sigma, max=threshold) -def constant(alpha, sigma): - return 1 - - -class DINOv2(nn.Module): - def __init__(self, weight_path:str): - super(DINOv2, self).__init__() - self.encoder = torch.hub.load( - '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', - weight_path, - source="local", - skip_validation=True - ) - self.pos_embed = copy.deepcopy(self.encoder.pos_embed) - self.encoder.head = torch.nn.Identity() - self.patch_size = self.encoder.patch_embed.patch_size - self.precomputed_pos_embed = dict() - - def fetch_pos(self, h, w): - key = (h, w) - if key in self.precomputed_pos_embed: - return self.precomputed_pos_embed[key] - value = timm.layers.pos_embed.resample_abs_pos_embed( - self.pos_embed.data, [h, w], - ) - self.precomputed_pos_embed[key] = value - return value - - def forward(self, x): - b, c, h, w = x.shape - x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) - x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') - b, c, h, w = x.shape - patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] - pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) - self.encoder.pos_embed.data = pos_embed_data - feature = self.encoder.forward_features(x)['x_norm_patchtokens'] - return feature - - -class REPALPIPSTrainer(BaseTrainer): - def __init__( - self, - scheduler: BaseScheduler, - loss_weight_fn:Callable=constant, - feat_loss_weight: float=0.5, - lognorm_t=False, - lpips_weight=1.0, - encoder_weight_path=None, - align_layer=8, - 
proj_denoiser_dim=256, - proj_hidden_dim=256, - proj_encoder_dim=256, - *args, - **kwargs - ): - super().__init__(*args, **kwargs) - self.lognorm_t = lognorm_t - self.scheduler = scheduler - self.loss_weight_fn = loss_weight_fn - self.feat_loss_weight = feat_loss_weight - self.align_layer = align_layer - self.encoder = DINOv2(encoder_weight_path) - self.proj_encoder_dim = proj_encoder_dim - no_grad(self.encoder) - - self.lpips_weight = lpips_weight - self.lpips = _NoTrainLpips(net="vgg") - self.lpips = self.lpips.to(torch.bfloat16) - no_grad(self.lpips) - - self.proj = nn.Sequential( - nn.Sequential( - nn.Linear(proj_denoiser_dim, proj_hidden_dim), - nn.SiLU(), - nn.Linear(proj_hidden_dim, proj_hidden_dim), - nn.SiLU(), - nn.Linear(proj_hidden_dim, proj_encoder_dim), - ) - ) - - def _impl_trainstep(self, net, ema_net, raw_images, x, y): - batch_size, c, height, width = x.shape - if self.lognorm_t: - base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() - else: - base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) - t = base_t - - noise = torch.randn_like(x) - alpha = self.scheduler.alpha(t) - dalpha = self.scheduler.dalpha(t) - sigma = self.scheduler.sigma(t) - dsigma = self.scheduler.dsigma(t) - - x_t = alpha * x + noise * sigma - v_t = dalpha * x + dsigma * noise - - src_feature = [] - def forward_hook(net, input, output): - src_feature.append(output) - if getattr(net, "blocks", None) is not None: - handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) - else: - handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook) - - out, _ = net(x_t, t, y) - src_feature = self.proj(src_feature[0]) - handle.remove() - - with torch.no_grad(): - dst_feature = self.encoder(raw_images) - - if dst_feature.shape[1] != src_feature.shape[1]: - dst_length = dst_feature.shape[1] - rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5 - dst_height = (dst_length)**0.5 * (height/width)**0.5 - dst_width = (dst_length)**0.5 * (width/height)**0.5 - dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim) - dst_feature = dst_feature.permute(0, 3, 1, 2) - dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False) - dst_feature = dst_feature.permute(0, 2, 3, 1) - dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim) - - cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) - cos_loss = 1 - cos_sim - - weight = self.loss_weight_fn(alpha, sigma) - fm_loss = weight*(out - v_t)**2 - - pred_x0 = (x_t + out * sigma) - target_x0 = x - # fixbug lpips std - lpips = self.lpips(pred_x0 * 0.5, target_x0 * 0.5) - - out = dict( - lpips_loss=lpips.mean(), - fm_loss=fm_loss.mean(), - cos_loss=cos_loss.mean(), - loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + self.lpips_weight*lpips.mean(), - ) - return out - - def state_dict(self, *args, destination=None, prefix="", keep_vars=False): - self.proj.state_dict( - destination=destination, - prefix=prefix + "proj.", - keep_vars=keep_vars) - diff --git a/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py deleted file mode 100644 index 3581446..0000000 --- a/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py +++ /dev/null @@ -1,383 +0,0 @@ -import functools -from typing import Tuple -import torch -import torch.nn as nn -import math 
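The REPALPIPSTrainer deleted above combines three terms: the flow-matching loss, the LPIPS term on the recovered x0, and a REPA-style alignment loss that pulls an intermediate denoiser feature (captured by a forward hook at align_layer) towards frozen DINOv2 patch tokens through a small MLP projector. A minimal sketch of just that alignment term; the token count and feature widths below are placeholders, since the trainer takes them as constructor arguments:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    denoiser_dim, hidden_dim, encoder_dim = 1152, 2048, 768   # assumed sizes

    proj = nn.Sequential(
        nn.Linear(denoiser_dim, hidden_dim), nn.SiLU(),
        nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
        nn.Linear(hidden_dim, encoder_dim),
    )

    src = torch.randn(2, 256, denoiser_dim)   # hook output from block align_layer-1
    dst = torch.randn(2, 256, encoder_dim)    # frozen DINOv2 patch tokens
    cos_loss = (1 - F.cosine_similarity(proj(src), dst, dim=-1)).mean()

When the DINOv2 token grid does not match the denoiser's, the deleted trainer first reshapes and bilinearly resamples the DINOv2 tokens (the dst_feature interpolation branch) so the two sequences line up token by token.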
- -from torch.nn.init import zeros_ -from torch.nn.modules.module import T - -# from torch.nn.attention.flex_attention import flex_attention, create_block_mask -from torch.nn.functional import scaled_dot_product_attention -from src.utils.model_loader import ModelLoader -from src.utils.no_grad import no_grad - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - -class Embed(nn.Module): - def __init__( - self, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = None, - bias: bool = True, - ): - super().__init__() - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Linear(in_chans, embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class TimestepEmbedder(nn.Module): - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[..., None].float() * freqs[None, ...] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - -class LabelEmbedder(nn.Module): - def __init__(self, num_classes, hidden_size): - super().__init__() - self.embedding_table = nn.Embedding(num_classes, hidden_size) - self.num_classes = num_classes - - def forward(self, labels,): - embeddings = self.embedding_table(labels) - return embeddings - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 2*hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - def forward(self, x): - x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) - return x - -def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: 
float = 10000.0, scale=16.0): - # assert H * H == end - # flat_patch_pos = torch.linspace(-1, 1, end) # N = end - x_pos = torch.linspace(0, scale, width) - y_pos = torch.linspace(0, scale, height) - y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") - y_pos = y_pos.reshape(-1) - x_pos = x_pos.reshape(-1) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 - x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) - y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) - freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 - freqs_cis = freqs_cis.reshape(height*width, -1) - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - freqs_cis = freqs_cis[None, :, None, :] - # xq : B N H Hc - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -class RAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = RMSNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc - q = self.q_norm(q) - k = self.k_norm(k) - q, k = apply_rotary_emb(q, k, freqs_cis=pos) - q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc - k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc - v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() - - x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class FlattenDiTBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * self.attn(modulate(self.norm1(x), 
shift_msa, scale_msa), pos, mask=mask) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class FlattenDiTEncoder(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - patch_size=2, - num_classes=1000, - learn_sigma=True, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.patch_size = patch_size - self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) - self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) - ]) - self.initialize_weights() - self.precompute_pos = dict() - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)] - else: - pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) - self.precompute_pos[(height, width)] = (pos_rope, pos_ape) - return (pos_rope, pos_ape) - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.s_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.s_embedder.proj.bias, 0) - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - def forward(self, x, t, y, mask=None): - B, _, H, W = x.shape - pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - y = self.y_embedder(y).view(B, 1, self.hidden_size) - c = nn.functional.silu(t + y) - s = self.s_embedder(x) - # s = s + pos_ape.to(s.dtype) - for i in range(self.num_blocks): - s = self.blocks[i](s, c, pos_rope, mask) - return None, s - - -class FlattenDiTDecoder(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - patch_size=2, - num_classes=1000, - learn_sigma=True, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.patch_size = patch_size - self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size) - self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) - - self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) - ]) - self.initialize_weights() - self.precompute_pos = dict() - - def 
fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)] - else: - pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - self.precompute_pos[(height, width)] = pos - return pos - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.x_embedder.proj.bias, 0) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # Zero-out output layers: - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.linear.weight, 0) - nn.init.constant_(self.final_layer.linear.bias, 0) - - def forward(self, x, t, y, s, mask=None): - B, _, H, W = x.shape - pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device) - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - y = self.y_embedder(y).view(B, 1, self.hidden_size) - s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) - c = torch.nn.functional.silu(t + y) - x = torch.cat([x, s], dim=-1) - x = self.x_embedder(x) - for i in range(self.num_blocks): - x = self.blocks[i](x, c, pos, None) - x = self.final_layer(x, c) - x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, - stride=self.patch_size) - return x - - - -class FlattenDiT(nn.Module): - def __init__( - self, - encoder:FlattenDiTEncoder, - decoder:FlattenDiTDecoder, - ): - super().__init__() - self.encoder = encoder - self.decoder = decoder - ModelLoader().load(encoder) - self.encoder = self.encoder.to(torch.bfloat16) - no_grad(self.encoder) - - def forward(self, x, t, y, s=None): - if s is None: - with torch.no_grad(): - _, s = self.encoder(x, t, y) - x = self.decoder(x, t, y, s) - return x, s - diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py deleted file mode 100644 index 733ce4a..0000000 --- a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py +++ /dev/null @@ -1,447 +0,0 @@ -import functools -from typing import Tuple -import torch -import torch.nn as nn -import math - -from torch.nn.init import zeros_ -from torch.nn.modules.module import T - -# from torch.nn.attention.flex_attention import flex_attention, create_block_mask -from torch.nn.functional import scaled_dot_product_attention -from src.utils.model_loader import ModelLoader -from src.utils.no_grad import no_grad - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - -class Embed(nn.Module): - def __init__( - self, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = None, - bias: bool = True, - ): - super().__init__() - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Linear(in_chans, embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class TimestepEmbedder(nn.Module): - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - 
nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[..., None].float() * freqs[None, ...] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - -class LabelEmbedder(nn.Module): - def __init__(self, num_classes, hidden_size): - super().__init__() - self.embedding_table = nn.Embedding(num_classes, hidden_size) - self.num_classes = num_classes - - def forward(self, labels,): - embeddings = self.embedding_table(labels) - return embeddings - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 2*hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - def forward(self, x): - x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) - return x - -class ResBlock(nn.Module): - def __init__(self, dim:int, groups:int=8, hidden_dim:int=256): - super().__init__() - self.conv1 = nn.Conv2d(dim, dim, 3, padding=1) - self.conv2 = nn.Conv2d(dim, dim, 3, padding=1) - self.norm1 = nn.GroupNorm(groups, dim) - self.norm2 = nn.GroupNorm(groups, dim) - self.embed_proj = nn.Linear(hidden_dim, dim) - - def forward(self, x, c): - c = self.embed_proj(c)[:, :, None, None] - residual = x - x = self.conv1(x) - x = self.norm1(x) - x = torch.nn.functional.silu(x) - x = x * c - x = self.conv2(x) - x = self.norm2(x) - x = torch.nn.functional.silu(x) - return residual + x - -def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): - # assert H * H == end - # flat_patch_pos = torch.linspace(-1, 1, end) # N = end - x_pos = torch.linspace(0, scale, width) - y_pos = torch.linspace(0, scale, height) - y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") - y_pos = y_pos.reshape(-1) - x_pos = x_pos.reshape(-1) - freqs = 1.0 / (theta 
** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 - x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) - y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) - freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 - freqs_cis = freqs_cis.reshape(height*width, -1) - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - freqs_cis = freqs_cis[None, :, None, :] - # xq : B N H Hc - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -class RAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = RMSNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc - q = self.q_norm(q) - k = self.k_norm(k) - q, k = apply_rotary_emb(q, k, freqs_cis=pos) - q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc - k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc - v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() - - x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class FlattenDiTBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class FlattenDiTEncoder(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - patch_size=2, - num_classes=1000, - learn_sigma=True, - weight_path=None, - 
load_ema=False, - ): - super().__init__() - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.patch_size = patch_size - self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) - self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) - ]) - self.initialize_weights() - self.precompute_pos = dict() - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)] - else: - pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) - self.precompute_pos[(height, width)] = (pos_rope, pos_ape) - return (pos_rope, pos_ape) - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.s_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.s_embedder.proj.bias, 0) - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - def forward(self, x, t, y, mask=None, classify_layer=None): - B, _, H, W = x.shape - pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - y = self.y_embedder(y).view(B, 1, self.hidden_size) - c = nn.functional.silu(t + y) - s = self.s_embedder(x) - # s = s + pos_ape.to(s.dtype) - classify_feats = [] - for i in range(self.num_blocks): - s = self.blocks[i](s, c, pos_rope, mask) - if classify_layer is not None and i < classify_layer: - classify_feats.append(s) - if i == classify_layer - 1: - return _, classify_feats - return None, s - - -class FlattenDiTDecoder(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_mid_blocks=18, - num_res_blocks=[1, 1, 1], - num_res_channels=[64, 384, 768], - num_classes=1000, - learn_sigma=True, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_mid_blocks = num_mid_blocks - self.num_res_blocks = num_res_blocks - self.num_res_channels = num_res_channels - self.patch_size = 2**(len(num_res_blocks)) - - self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size) - self.t_embedder = TimestepEmbedder(hidden_size) - - self.down_res_blocks = nn.ModuleList() - previous_channel = self.in_channels - for num, channels in zip(num_res_blocks, num_res_channels): - self.down_res_blocks.append( - nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0), - ) - self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) - previous_channel = channels - self.up_res_blocks = [] - previous_channel = 
self.in_channels - for num, channels in zip(num_res_blocks, num_res_channels): - self.up_res_blocks.append( - nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0) - ) - self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) - previous_channel = channels - self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1]) - - self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks) - ]) - - self.initialize_weights() - self.precompute_pos = dict() - self.weight_path = weight_path - self.load_ema = load_ema - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)] - else: - pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - self.precompute_pos[(height, width)] = pos - return pos - - def initialize_weights(self): - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - def forward(self, x, t, y, s, mask=None): - B, _, H, W = x.shape - t = self.t_embedder(t.view(-1)).view(B, self.hidden_size) - y = self.y_embedder(y).view(B, self.hidden_size) - s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) - c = torch.nn.functional.silu(t + y) - - residual = [] - for i, block in enumerate(self.down_res_blocks): - if isinstance(block, nn.Conv2d): - residual.append(x) - x = block(x) - else: - x = block(x, c) - - pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - x = x.view(B, self.hidden_size, -1).transpose(1, 2) - mid_c = torch.nn.functional.silu(t[:, None, :] + s) - for i in range(self.num_mid_blocks): - x = self.blocks[i](x, mid_c, pos, None) - x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size) - - residual[0] = 0.0 - for i, block in enumerate(self.up_res_blocks): - if isinstance(block, nn.ConvTranspose2d): - x = block(x) + residual.pop() - else: - x = block(x, c) - return x - - -class FlattenDiT(nn.Module): - def __init__( - self, - encoder:FlattenDiTEncoder, - decoder:FlattenDiTDecoder, - ): - super().__init__() - self.encoder = encoder - self.decoder = decoder - ModelLoader().load(encoder) - self.encoder = self.encoder.to(torch.bfloat16) - no_grad(self.encoder) - - def forward(self, x, t, y, s=None, classify_layer=None): - if s is None: - _, s = self.encoder(x, t, y, classify_layer=classify_layer) - if classify_layer is not None: - return None, s - x = self.decoder(x, t, y, s) - return x, s - -class FlattenDiT_jointtraining(nn.Module): - def __init__( - self, - encoder:FlattenDiTEncoder, - decoder:FlattenDiTDecoder, - ): - super().__init__() - self.encoder = encoder - self.decoder = decoder - - def forward(self, x, t, y, s=None): - if s is None: - _, s = self.encoder(x, t, y) - x = self.decoder(x, t, y, s) - return x, s - diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py deleted file mode 100644 index 6e9adbc..0000000 --- a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py +++ /dev/null @@ -1,448 +0,0 @@ -import functools -from typing import Tuple -import torch -import torch.nn as nn -import math - -from torch.nn.init import zeros_ -from torch.nn.modules.module import T - -# from torch.nn.attention.flex_attention import flex_attention, create_block_mask -from torch.nn.functional 
import scaled_dot_product_attention -from src.utils.model_loader import ModelLoader -from src.utils.no_grad import no_grad - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - -class Embed(nn.Module): - def __init__( - self, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = None, - bias: bool = True, - ): - super().__init__() - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Linear(in_chans, embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class TimestepEmbedder(nn.Module): - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[..., None].float() * freqs[None, ...] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - -class LabelEmbedder(nn.Module): - def __init__(self, num_classes, hidden_size): - super().__init__() - self.embedding_table = nn.Embedding(num_classes, hidden_size) - self.num_classes = num_classes - - def forward(self, labels,): - embeddings = self.embedding_table(labels) - return embeddings - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 2*hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - def forward(self, x): - x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) - return x - -class ResBlock(nn.Module): - def __init__(self, dim:int, groups:int=8, hidden_dim:int=256): - super().__init__() - self.conv1 = nn.Conv2d(dim, dim, 3, padding=1) - self.conv2 = nn.Conv2d(dim, dim, 3, padding=1) - self.norm1 = nn.GroupNorm(groups, dim) - 
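The ResBlock in these conditional unetdecoder variants conditions on the embedding with a per-channel scale only: the embedding is projected to one value per channel and multiplies the normalized activations, FiLM-style without a shift (the _woy variant further down drops this projection entirely). A tiny sketch of the pattern, with sizes assumed:

    import torch
    import torch.nn as nn

    dim, hidden_size = 64, 1152               # assumed sizes
    embed_proj = nn.Linear(hidden_size, dim)

    x = torch.randn(2, dim, 16, 16)           # feature map after norm + silu
    c = torch.randn(2, hidden_size)           # conditioning embedding
    x = x * embed_proj(c)[:, :, None, None]   # broadcast the per-channel scale over H, W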
self.norm2 = nn.GroupNorm(groups, dim) - self.embed_proj = nn.Linear(hidden_dim, dim) - - def forward(self, x, c): - c = self.embed_proj(c)[:, :, None, None] - residual = x - x = self.conv1(x) - x = self.norm1(x) - x = torch.nn.functional.silu(x) - x = x * c - x = self.conv2(x) - x = self.norm2(x) - x = torch.nn.functional.silu(x) - return residual + x - -def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): - # assert H * H == end - # flat_patch_pos = torch.linspace(-1, 1, end) # N = end - x_pos = torch.linspace(0, scale, width) - y_pos = torch.linspace(0, scale, height) - y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") - y_pos = y_pos.reshape(-1) - x_pos = x_pos.reshape(-1) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 - x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) - y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) - freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 - freqs_cis = freqs_cis.reshape(height*width, -1) - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - freqs_cis = freqs_cis[None, :, None, :] - # xq : B N H Hc - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -class RAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = RMSNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc - q = self.q_norm(q) - k = self.k_norm(k) - q, k = apply_rotary_emb(q, k, freqs_cis=pos) - q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc - k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc - v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() - - x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class FlattenDiTBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) - self.norm2 = 
RMSNorm(hidden_size, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class FlattenDiTEncoder(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - patch_size=2, - num_classes=1000, - learn_sigma=True, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.patch_size = patch_size - self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) - self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) - ]) - self.initialize_weights() - self.precompute_pos = dict() - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)] - else: - pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) - self.precompute_pos[(height, width)] = (pos_rope, pos_ape) - return (pos_rope, pos_ape) - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.s_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.s_embedder.proj.bias, 0) - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - def forward(self, x, t, y, mask=None, classify_layer=None): - B, _, H, W = x.shape - pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - y = self.y_embedder(y).view(B, 1, self.hidden_size) - c = nn.functional.silu(t + y) - s = self.s_embedder(x) - # s = s + pos_ape.to(s.dtype) - classify_feats = [] - for i in range(self.num_blocks): - s = self.blocks[i](s, c, pos_rope, mask) - if classify_layer is not None and i < classify_layer: - classify_feats.append(s) - if i == classify_layer - 1: - return _, classify_feats - return None, s - - -class FlattenDiTDecoder(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_mid_blocks=18, - num_res_blocks=[1, 1, 1], - num_res_channels=[64, 384, 768], - num_classes=1000, - learn_sigma=True, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = 
in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_mid_blocks = num_mid_blocks - self.num_res_blocks = num_res_blocks - self.num_res_channels = num_res_channels - self.patch_size = 2**(len(num_res_blocks)) - - self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size) - self.t_embedder = TimestepEmbedder(hidden_size) - - self.down_res_blocks = nn.ModuleList() - previous_channel = self.in_channels - for num, channels in zip(num_res_blocks, num_res_channels): - self.down_res_blocks.append( - nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0), - ) - self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) - previous_channel = channels - self.up_res_blocks = [] - previous_channel = self.in_channels - for num, channels in zip(num_res_blocks, num_res_channels): - self.up_res_blocks.append( - nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0) - ) - self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) - previous_channel = channels - self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1]) - - self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks) - ]) - - self.initialize_weights() - self.precompute_pos = dict() - self.weight_path = weight_path - self.load_ema = load_ema - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)] - else: - pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - self.precompute_pos[(height, width)] = pos - return pos - - def initialize_weights(self): - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - def forward(self, x, t, y, s, mask=None): - B, _, H, W = x.shape - t = self.t_embedder(t.view(-1)).view(B, self.hidden_size) - y = self.y_embedder(y).view(B, self.hidden_size) - s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) - c = torch.nn.functional.silu(t + y) - - residual = [] - for i, block in enumerate(self.down_res_blocks): - if isinstance(block, nn.Conv2d): - residual.append(x) - x = block(x) - else: - x = block(x, c) - - pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - x = x.view(B, self.hidden_size, -1).transpose(1, 2) - mid_c = torch.nn.functional.silu(t[:, None, :] + s) - for i in range(self.num_mid_blocks): - x = self.blocks[i](x, mid_c, pos, None) - x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size) - - residual[0] = 0.0 - for i, block in enumerate(self.up_res_blocks): - if isinstance(block, nn.ConvTranspose2d): - x = block(x) + residual.pop() - else: - x = block(x, c) - return x - - -class FlattenDiT(nn.Module): - def __init__( - self, - encoder:FlattenDiTEncoder, - decoder:FlattenDiTDecoder, - ): - super().__init__() - self.encoder = encoder - self.decoder = decoder - ModelLoader().load(encoder, "encoder.") - ModelLoader().load(decoder, "decoder.") - self.encoder = self.encoder.to(torch.bfloat16) - no_grad(self.encoder) - - def forward(self, x, t, y, s=None, classify_layer=None): - if s is None: - _, s = self.encoder(x, t, y, classify_layer=classify_layer) - if classify_layer is not None: - return None, s - x = self.decoder(x, t, y, s) - return x, s - -class FlattenDiT_jointtraining(nn.Module): - def __init__( - self, - 
encoder:FlattenDiTEncoder, - decoder:FlattenDiTDecoder, - ): - super().__init__() - self.encoder = encoder - self.decoder = decoder - - def forward(self, x, t, y, s=None): - if s is None: - _, s = self.encoder(x, t, y) - x = self.decoder(x, t, y, s) - return x, s - diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py deleted file mode 100644 index 537078a..0000000 --- a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py +++ /dev/null @@ -1,464 +0,0 @@ -import functools -from typing import Tuple -import torch -import torch.nn as nn -import math - -from torch.nn.init import zeros_ -from torch.nn.modules.module import T - -# from torch.nn.attention.flex_attention import flex_attention, create_block_mask -from torch.nn.functional import scaled_dot_product_attention -from src.utils.model_loader import ModelLoader -from src.utils.no_grad import no_grad - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - -class Embed(nn.Module): - def __init__( - self, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = None, - bias: bool = True, - ): - super().__init__() - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Linear(in_chans, embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class TimestepEmbedder(nn.Module): - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[..., None].float() * freqs[None, ...] 
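# note: max_period is 10 in these flatten_* denoisers, versus 10000 in condit_dit.py below;
# t arrives here as a continuous value in [0, 1] rather than an integer step index, so the
# smaller period presumably keeps more of the sinusoid channels varying usefully over that range.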
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - -class LabelEmbedder(nn.Module): - def __init__(self, num_classes, hidden_size): - super().__init__() - self.embedding_table = nn.Embedding(num_classes, hidden_size) - self.num_classes = num_classes - - def forward(self, labels,): - embeddings = self.embedding_table(labels) - return embeddings - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 2*hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - def forward(self, x): - x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) - return x - -class ResBlock(nn.Module): - def __init__(self, dim:int, groups:int=8, hidden_dim:int=256): - super().__init__() - self.conv1 = nn.Conv2d(dim, dim, 3, padding=1) - self.conv2 = nn.Conv2d(dim, dim, 3, padding=1) - self.norm1 = nn.GroupNorm(groups, dim) - self.norm2 = nn.GroupNorm(groups, dim) - - def forward(self, x): - residual = x - x = self.conv1(x) - x = self.norm1(x) - x = torch.nn.functional.silu(x) - x = self.conv2(x) - x = self.norm2(x) - x = torch.nn.functional.silu(x) - return residual + x - -def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): - # assert H * H == end - # flat_patch_pos = torch.linspace(-1, 1, end) # N = end - x_pos = torch.linspace(0, scale, width) - y_pos = torch.linspace(0, scale, height) - y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") - y_pos = y_pos.reshape(-1) - x_pos = x_pos.reshape(-1) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 - x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) - y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) - freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 - freqs_cis = freqs_cis.reshape(height*width, -1) - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> 
Tuple[torch.Tensor, torch.Tensor]: - freqs_cis = freqs_cis[None, :, None, :] - # xq : B N H Hc - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -class RAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = RMSNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc - q = self.q_norm(q) - k = self.k_norm(k) - q, k = apply_rotary_emb(q, k, freqs_cis=pos) - q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc - k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc - v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() - - x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class FlattenDiTBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class FlattenDiTEncoder(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - patch_size=2, - num_classes=1000, - learn_sigma=True, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.patch_size = patch_size - self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) - self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([ - 
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) - ]) - self.initialize_weights() - self.precompute_pos = dict() - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)] - else: - pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) - self.precompute_pos[(height, width)] = (pos_rope, pos_ape) - return (pos_rope, pos_ape) - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.s_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.s_embedder.proj.bias, 0) - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - def forward(self, x, t, y, mask=None, classify_layer=None): - B, _, H, W = x.shape - pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - y = self.y_embedder(y).view(B, 1, self.hidden_size) - c = nn.functional.silu(t + y) - s = self.s_embedder(x) - # s = s + pos_ape.to(s.dtype) - classify_feats = [] - for i in range(self.num_blocks): - s = self.blocks[i](s, c, pos_rope, mask) - if classify_layer is not None and i < classify_layer: - classify_feats.append(s) - if i == classify_layer - 1: - return _, classify_feats - return None, s - - -class FlattenDiTDecoder(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_mid_blocks=18, - num_res_blocks=[1, 1, 1], - num_res_channels=[64, 384, 768], - num_classes=1000, - learn_sigma=True, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_mid_blocks = num_mid_blocks - self.num_res_blocks = num_res_blocks - self.num_res_channels = num_res_channels - self.patch_size = 2**(len(num_res_blocks)) - - self.t_embedder = TimestepEmbedder(hidden_size) - - self.down_res_blocks = nn.ModuleList() - previous_channel = self.in_channels - for num, channels in zip(num_res_blocks, num_res_channels): - self.down_res_blocks.append( - nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0), - ) - self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) - previous_channel = channels - self.up_res_blocks = [] - previous_channel = self.in_channels - for num, channels in zip(num_res_blocks, num_res_channels): - self.up_res_blocks.append( - nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0) - ) - self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) - previous_channel = channels - self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1]) - - self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks) - ]) - - self.initialize_weights() - self.precompute_pos = dict() - self.weight_path = weight_path - 
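Every FlattenDiTBlock in these files conditions through the same adaLN pattern: a linear layer maps the conditioning vector to six chunks (shift, scale and gate for the attention branch and for the MLP branch), modulate() applies x*(1+scale)+shift after RMSNorm, and the gates scale each residual branch. A compact sketch of that data flow, with the width assumed and the norms omitted:

    import torch
    import torch.nn as nn

    hidden = 1152                                    # assumed width
    adaLN = nn.Linear(hidden, 6 * hidden)
    c = torch.randn(2, 1, hidden)                    # silu(t_emb + y_emb)
    x = torch.randn(2, 256, hidden)                  # token stream

    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN(c).chunk(6, dim=-1)
    x_mod = x * (1 + scale_msa) + shift_msa          # modulate(norm1(x), shift, scale)
    # the attention output is then multiplied by gate_msa before the residual add,
    # and the MLP branch reuses shift_mlp / scale_mlp / gate_mlp in the same way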
self.load_ema = load_ema - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)] - else: - pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - self.precompute_pos[(height, width)] = pos - return pos - - def initialize_weights(self): - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - # Zero-out adaLN modulation layers in SiT blocks: - for block in self.blocks: - nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - for block in self.down_res_blocks: - if isinstance(block, ResBlock): - nn.init.constant_(block.conv1.weight, 0) - nn.init.constant_(block.conv1.bias, 0) - nn.init.constant_(block.norm1.weight, 0) - nn.init.constant_(block.norm2.weight, 0) - nn.init.constant_(block.conv2.weight, 0) - nn.init.constant_(block.conv2.bias, 0) - - for block in self.up_res_blocks: - if isinstance(block, ResBlock): - nn.init.constant_(block.conv1.weight, 0) - nn.init.constant_(block.conv1.bias, 0) - nn.init.constant_(block.norm1.weight, 0) - nn.init.constant_(block.norm2.weight, 0) - nn.init.constant_(block.conv2.weight, 0) - nn.init.constant_(block.conv2.bias, 0) - - def forward(self, x, t, y, s, mask=None): - B, _, H, W = x.shape - t = self.t_embedder(t.view(-1)).view(B, self.hidden_size) - s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) - - residual = [] - for i, block in enumerate(self.down_res_blocks): - if isinstance(block, nn.Conv2d): - residual.append(x) - x = block(x) - else: - x = block(x) - - pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - x = x.view(B, self.hidden_size, -1).transpose(1, 2) - mid_c = torch.nn.functional.silu(t[:, None, :] + s) - for i in range(self.num_mid_blocks): - x = self.blocks[i](x, mid_c, pos, None) - x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size) - - residual[0] = 0.0 - for i, block in enumerate(self.up_res_blocks): - if isinstance(block, nn.ConvTranspose2d): - x = block(x) + residual.pop() - else: - x = block(x) - return x - - -class FlattenDiT(nn.Module): - def __init__( - self, - encoder:FlattenDiTEncoder, - decoder:FlattenDiTDecoder, - ): - super().__init__() - self.encoder = encoder - self.decoder = decoder - ModelLoader().load(encoder, "encoder.") - ModelLoader().load(decoder, "decoder.") - self.encoder = self.encoder.to(torch.bfloat16) - no_grad(self.encoder) - - def forward(self, x, t, y, s=None, classify_layer=None): - if s is None: - _, s = self.encoder(x, t, y, classify_layer=classify_layer) - if classify_layer is not None: - return None, s - x = self.decoder(x, t, y, s) - return x, s - -class FlattenDiT_jointtraining(nn.Module): - def __init__( - self, - encoder:FlattenDiTEncoder, - decoder:FlattenDiTDecoder, - ): - super().__init__() - self.encoder = encoder - self.decoder = decoder - - def forward(self, x, t, y, s=None): - if s is None: - _, s = self.encoder(x, t, y) - x = self.decoder(x, t, y, s) - return x, s - diff --git a/src/models/denoiser/condit_dit.py b/src/models/denoiser/condit_dit.py deleted file mode 100644 index 48d6b0e..0000000 --- a/src/models/denoiser/condit_dit.py +++ /dev/null @@ -1,274 +0,0 @@ -import torch -import torch.nn as nn -import math - -from numba.cuda.cudadrv.devicearray import lru_cache -from torch.nn.attention.flex_attention import flex_attention, 
create_block_mask -from torch.nn.functional import scaled_dot_product_attention - -from torch.nn.attention import SDPBackend, sdpa_kernel - -flex_attention = torch.compile(flex_attention) - - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - -class Embed(nn.Module): - def __init__( - self, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = False, - bias: bool = True, - ): - super().__init__() - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Linear(in_chans, embed_dim, bias=bias) - self.norm = nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class TimestepEmbedder(nn.Module): - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10000): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[..., None].float() * freqs[None, ...] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - -class LabelEmbedder(nn.Module): - def __init__(self, num_classes, hidden_size): - super().__init__() - self.embedding_table = nn.Embedding(num_classes, hidden_size) - self.num_classes = num_classes - - def forward(self, labels,): - embeddings = self.embedding_table(labels) - return embeddings - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm = nn.LayerNorm(hidden_size , elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 2*hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = self.norm(x) - x = modulate(x, shift, scale) - x = self.linear(x) - return x - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - self.fc1 = nn.Linear(dim, hidden_dim) - self.fc2 = nn.Linear(hidden_dim, dim) - self.act = nn.GELU(approximate="tanh") - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.fc2(x) - return x - -def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale: float=16): - x_pos = torch.linspace(0, scale, width) - y_pos = torch.linspace(0, scale, height) - y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") - x_pos = x_pos.reshape(-1) - y_pos = y_pos.reshape(-1) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 - freqs_cis = torch.cat([x_freqs.sin(), x_freqs.cos(), y_freqs.sin(), y_freqs.cos()], dim=1) - return freqs_cis - - -class Attention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = True, - qk_norm: bool = False, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = nn.LayerNorm, - ) -> None: 
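Note that condit_dit.py's precompute_freqs_cis_2d returns real sin/cos features that are added to the token embeddings (s = self.s_embedder(x) + pos later in this file), whereas the same-named helper in the flatten_* variants returns complex phases applied as rotary embeddings inside attention. A standalone sketch of the additive version defined above (the function name and the 16x16 grid are illustrative):

import torch

def sincos_pos_2d(dim, height, width, theta=10000.0, scale=16.0):
    # Real-valued 2D sin/cos features, one half driven by x and one by y positions,
    # concatenated along the channel dim; meant to be added to token embeddings.
    xs = torch.linspace(0, scale, width)
    ys = torch.linspace(0, scale, height)
    ys, xs = torch.meshgrid(ys, xs, indexing="ij")
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: dim // 4].float() / dim))
    xf = torch.outer(xs.reshape(-1), freqs)
    yf = torch.outer(ys.reshape(-1), freqs)
    return torch.cat([xf.sin(), xf.cos(), yf.sin(), yf.cos()], dim=1)  # (H*W, dim)

pos = sincos_pos_2d(1152, 16, 16)
assert pos.shape == (16 * 16, 1152)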
- super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: - # import pdb; pdb.set_trace() - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc - q = self.q_norm(q).to(q.dtype) - k = self.k_norm(k).to(k.dtype) - q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc - k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc - v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() - # x = flex_attention(q, k, v, block_mask=mask) - # with sdpa_kernel(SDPBackend.FLASH_ATTENTION): - x = scaled_dot_product_attention(q, k, v, mask) - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class DiTBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.norm1 = nn.LayerNorm(hidden_size , elementwise_affine=False, eps=1e-6) - self.attn = Attention(hidden_size, num_heads=groups, qkv_bias=True, qk_norm=False) - self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class ConDiT(nn.Module): - def __init__( - self, - in_channels=4, - out_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - num_cond_blocks=4, - patch_size=2, - num_classes=1000, - learn_sigma=True, - deep_supervision=0, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.deep_supervision = deep_supervision - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = out_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.patch_size = patch_size - self.x_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True) - self.s_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size) - self.final_layer = FinalLayer(hidden_size, out_channels * patch_size ** 2) - self.num_cond_blocks = num_cond_blocks - - - self.weight_path = weight_path - self.load_ema = load_ema - self.blocks = nn.ModuleList([ - DiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) - ]) - self.initialize_weights() - - @lru_cache - def fetch_pos(self, height, width, device): - pos = precompute_freqs_cis_2d(self.hidden_size, 
height//self.patch_size, width//self.patch_size).to(device)[None, ...] - return pos - - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.x_embedder.proj.bias, 0) - - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.s_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.s_embedder.proj.bias, 0) - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # Zero-out output layers: - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.linear.weight, 0) - nn.init.constant_(self.final_layer.linear.bias, 0) - - def forward(self, x, t, y, s=None): - B, _, H, W = x.shape - pos = self.fetch_pos(H, W, x.device) - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - y = self.y_embedder(y).view(B, 1, self.hidden_size) - - if s is None: - # semantic encoder - s = self.s_embedder(x) + pos - c = nn.functional.silu(t + y) - for i in range(self.num_cond_blocks): - s = self.blocks[i](s, c, pos) - s = nn.functional.silu(t + s) - - x = self.x_embedder(x) - for i in range(self.num_cond_blocks, self.num_blocks): - x = self.blocks[i](x, s, pos) - x = self.final_layer(x, s) - x = torch.nn.functional.fold(x.transpose(1, 2), (H, W), kernel_size=self.patch_size, stride=self.patch_size) - return x, s diff --git a/src/models/denoiser/flatten_condit_dit_fixt.py b/src/models/denoiser/decouple_improved_dit.py similarity index 95% rename from src/models/denoiser/flatten_condit_dit_fixt.py rename to src/models/denoiser/decouple_improved_dit.py index 15557f3..f20115c 100644 --- a/src/models/denoiser/flatten_condit_dit_fixt.py +++ b/src/models/denoiser/decouple_improved_dit.py @@ -194,7 +194,7 @@ class RAttention(nn.Module): -class FlattenDiTBlock(nn.Module): +class DDTBlock(nn.Module): def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): super().__init__() self.norm1 = RMSNorm(hidden_size, eps=1e-6) @@ -213,14 +213,14 @@ class FlattenDiTBlock(nn.Module): return x -class FlattenConDiT(nn.Module): +class DDT(nn.Module): def __init__( self, in_channels=4, num_groups=12, hidden_size=1152, num_blocks=18, - num_cond_blocks=4, + num_encoder_blocks=4, patch_size=2, num_classes=1000, learn_sigma=True, @@ -236,7 +236,7 @@ class FlattenConDiT(nn.Module): self.hidden_size = hidden_size self.num_groups = num_groups self.num_blocks = num_blocks - self.num_cond_blocks = num_cond_blocks + self.num_encoder_blocks = num_encoder_blocks self.patch_size = patch_size self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) @@ -249,7 +249,7 @@ class FlattenConDiT(nn.Module): self.load_ema = load_ema self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + DDTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) ]) self.initialize_weights() self.precompute_pos = dict() @@ -280,11 +280,6 @@ class 
FlattenConDiT(nn.Module): nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - # # Zero-out adaLN modulation layers in SiT blocks: - # for block in self.blocks: - # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - # Zero-out output layers: nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) @@ -301,12 +296,12 @@ class FlattenConDiT(nn.Module): c = nn.functional.silu(t + y) if s is None: s = self.s_embedder(x) - for i in range(self.num_cond_blocks): + for i in range(self.num_encoder_blocks): s = self.blocks[i](s, c, pos, mask) s = nn.functional.silu(t + s) x = self.x_embedder(x) - for i in range(self.num_cond_blocks, self.num_blocks): + for i in range(self.num_encoder_blocks, self.num_blocks): x = self.blocks[i](x, s, pos, None) x = self.final_layer(x, s) x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) diff --git a/src/models/denoiser/flatten_condit_catdit_fixt.py b/src/models/denoiser/flatten_condit_catdit_fixt.py deleted file mode 100644 index 22a0fd5..0000000 --- a/src/models/denoiser/flatten_condit_catdit_fixt.py +++ /dev/null @@ -1,314 +0,0 @@ -import functools -from typing import Tuple -import torch -import torch.nn as nn -import math - -from torch.nn.init import zeros_ -from torch.nn.modules.module import T - -# from torch.nn.attention.flex_attention import flex_attention, create_block_mask -from torch.nn.functional import scaled_dot_product_attention - - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - -class Embed(nn.Module): - def __init__( - self, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = None, - bias: bool = True, - ): - super().__init__() - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Linear(in_chans, embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class TimestepEmbedder(nn.Module): - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[..., None].float() * freqs[None, ...] 
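The rename hunk above (FlattenConDiT to DDT, num_cond_blocks to num_encoder_blocks) only relabels the split that the forward pass already implements: the first num_encoder_blocks blocks build a condition s from the noisy tokens, the remaining blocks denoise while modulated by s, and s is returned so a caller can pass it back in and skip the encoder pass, which the s=None branch appears designed to support. A minimal structural sketch of that control flow (TinyDDT and TinyBlock are toy stand-ins, not code from the repo):

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    # Stand-in for a transformer block with adaLN conditioning on c.
    def __init__(self, dim):
        super().__init__()
        self.mix = nn.Linear(dim, dim)
    def forward(self, x, c):
        return x + self.mix(x) * c

class TinyDDT(nn.Module):
    def __init__(self, dim=64, num_blocks=6, num_encoder_blocks=2):
        super().__init__()
        self.num_encoder_blocks = num_encoder_blocks
        self.blocks = nn.ModuleList([TinyBlock(dim) for _ in range(num_blocks)])
    def forward(self, x, c, s=None):
        if s is None:
            s = x
            for blk in self.blocks[:self.num_encoder_blocks]:   # encoder: build s from (x, c)
                s = blk(s, c)
        for blk in self.blocks[self.num_encoder_blocks:]:       # decoder: denoise x under s
            x = blk(x, s)
        return x, s

x = torch.randn(2, 16, 64)   # B, N, C tokens
c = torch.ones(2, 1, 64)     # stand-in for silu(t + y)
out, s = TinyDDT()(x, c)
assert out.shape == x.shape and s.shape == x.shape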
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - -class LabelEmbedder(nn.Module): - def __init__(self, num_classes, hidden_size): - super().__init__() - self.embedding_table = nn.Embedding(num_classes, hidden_size) - self.num_classes = num_classes - - def forward(self, labels,): - embeddings = self.embedding_table(labels) - return embeddings - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 2*hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - def forward(self, x): - x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) - return x - -def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): - # assert H * H == end - # flat_patch_pos = torch.linspace(-1, 1, end) # N = end - x_pos = torch.linspace(0, scale, width) - y_pos = torch.linspace(0, scale, height) - y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") - y_pos = y_pos.reshape(-1) - x_pos = x_pos.reshape(-1) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 - x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) - y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) - freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 - freqs_cis = freqs_cis.reshape(height*width, -1) - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - freqs_cis = freqs_cis[None, :, None, :] - # xq : B N H Hc - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -class RAttention(nn.Module): - def __init__( - self, - dim: int, - 
num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = RMSNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc - q = self.q_norm(q) - k = self.k_norm(k) - q, k = apply_rotary_emb(q, k, freqs_cis=pos) - q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc - k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc - v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() - - x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class FlattenDiTBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class FlattenConDiT(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - num_cond_blocks=4, - patch_size=2, - num_classes=1000, - learn_sigma=True, - deep_supervision=0, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.deep_supervision = deep_supervision - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.num_cond_blocks = num_cond_blocks - self.patch_size = patch_size - self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True) - self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) - - self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) - - self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) - ]) - self.initialize_weights() - self.precompute_pos = dict() - - def fetch_pos(self, height, width, device): - if (height, width) 
in self.precompute_pos: - return self.precompute_pos[(height, width)].to(device) - else: - pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - self.precompute_pos[(height, width)] = pos - return pos - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.x_embedder.proj.bias, 0) - - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.s_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.s_embedder.proj.bias, 0) - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # # Zero-out adaLN modulation layers in SiT blocks: - # for block in self.blocks: - # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.linear.weight, 0) - nn.init.constant_(self.final_layer.linear.bias, 0) - - - def forward(self, x, t, y, s=None, mask=None): - B, _, H, W = x.shape - pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - y = self.y_embedder(y).view(B, 1, self.hidden_size) - c = nn.functional.silu(t + y) - if s is None: - s = self.s_embedder(x) - for i in range(self.num_cond_blocks): - s = self.blocks[i](s, c, pos, mask) - # s = nn.functional.silu(t + s) - s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6) - x = torch.cat((x, s), dim=-1) - x = self.x_embedder(x) - for i in range(self.num_cond_blocks, self.num_blocks): - x = self.blocks[i](x, c, pos, None) - x = self.final_layer(x, c) - x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) - return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_condit_conv_fixt.py b/src/models/denoiser/flatten_condit_conv_fixt.py deleted file mode 100644 index 219db4c..0000000 --- a/src/models/denoiser/flatten_condit_conv_fixt.py +++ /dev/null @@ -1,340 +0,0 @@ -import functools -from typing import Tuple -import torch -import torch.nn as nn -import math - -from torch.nn.init import zeros_ -from torch.nn.modules.module import T - -# from torch.nn.attention.flex_attention import flex_attention, create_block_mask -from torch.nn.functional import scaled_dot_product_attention - - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - -class Embed(nn.Module): - def __init__( - self, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = None, - bias: bool = True, - ): - super().__init__() - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Linear(in_chans, embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class TimestepEmbedder(nn.Module): - - def __init__(self, hidden_size, 
frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[..., None].float() * freqs[None, ...] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - -class LabelEmbedder(nn.Module): - def __init__(self, num_classes, hidden_size): - super().__init__() - self.embedding_table = nn.Embedding(num_classes, hidden_size) - self.num_classes = num_classes - - def forward(self, labels,): - embeddings = self.embedding_table(labels) - return embeddings - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 2*hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - def forward(self, x): - x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) - return x - -def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): - # assert H * H == end - # flat_patch_pos = torch.linspace(-1, 1, end) # N = end - x_pos = torch.linspace(0, scale, width) - y_pos = torch.linspace(0, scale, height) - y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") - y_pos = y_pos.reshape(-1) - x_pos = x_pos.reshape(-1) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 - x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) - y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) - freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 - freqs_cis = freqs_cis.reshape(height*width, -1) - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - 
freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - freqs_cis = freqs_cis[None, :, None, :] - # xq : B N H Hc - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -class RAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = RMSNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc - q = self.q_norm(q) - k = self.k_norm(k) - q, k = apply_rotary_emb(q, k, freqs_cis=pos) - q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc - k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc - v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() - - x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class FlattenDiTBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - -class FlattenConvBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, kernel_size=3): - super().__init__() - self.hidden_size = hidden_size - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = nn.Conv2d(hidden_size, hidden_size, groups=groups, kernel_size=kernel_size, stride=1, padding=kernel_size//2) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - 
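Both block types in this file condition on c the same way: one linear layer expands c into six per-channel vectors that shift and scale the normalized activations and gate the two residual branches. A minimal sketch of that adaLN mechanism (shapes are illustrative, and LayerNorm stands in for the RMSNorm used above):

import torch
import torch.nn as nn

def modulate(x, shift, scale):
    return x * (1 + scale) + shift

dim = 8
adaLN = nn.Linear(dim, 6 * dim)
norm = nn.LayerNorm(dim, elementwise_affine=False)

x = torch.randn(2, 16, dim)   # tokens
c = torch.randn(2, 1, dim)    # condition: silu(t + y), or s in the decoder blocks
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaLN(c).chunk(6, dim=-1)

attn_in = modulate(norm(x), shift_msa, scale_msa)   # would be fed to attention (or the conv mixer)
x = x + gate_msa * attn_in
mlp_in = modulate(norm(x), shift_mlp, scale_mlp)    # would be fed to the feed-forward
x = x + gate_mlp * mlp_in

In this file only the final layer's adaLN and output projection are zero-initialized, so the network starts out producing zeros, while the per-block modulation layers keep the default linear initialization (their zero-init is commented out).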
attn_x = modulate(self.norm1(x), shift_msa, scale_msa) - attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous() - attn_x = self.attn(attn_x) - attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2) - x = x + gate_msa * attn_x - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class FlattenConDiT(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - num_cond_blocks=4, - patch_size=2, - kernel_size=3, - num_classes=1000, - learn_sigma=True, - deep_supervision=0, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.deep_supervision = deep_supervision - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.num_cond_blocks = num_cond_blocks - self.patch_size = patch_size - self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) - - self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) - - self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([]) - for i in range(self.num_cond_blocks): - self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups)) - for i in range(self.num_blocks-self.num_cond_blocks): - self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups, kernel_size)) - - self.initialize_weights() - self.precompute_pos = dict() - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)].to(device) - else: - pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - self.precompute_pos[(height, width)] = pos - return pos - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.x_embedder.proj.bias, 0) - - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.s_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.s_embedder.proj.bias, 0) - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # # Zero-out adaLN modulation layers in SiT blocks: - # for block in self.blocks: - # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.linear.weight, 0) - nn.init.constant_(self.final_layer.linear.bias, 0) - - - def forward(self, x, t, y, s=None): - B, _, H, W = x.shape - pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - y = 
self.y_embedder(y).view(B, 1, self.hidden_size) - c = nn.functional.silu(t + y) - if s is None: - s = self.s_embedder(x) - for i in range(self.num_cond_blocks): - s = self.blocks[i](s, c, pos, None) - s = nn.functional.silu(t + s) - - x = self.x_embedder(x) - for i in range(self.num_cond_blocks, self.num_blocks): - x = self.blocks[i](x, s, pos, None) - x = self.final_layer(x, s) - x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) - return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_condit_convnext_fixt.py b/src/models/denoiser/flatten_condit_convnext_fixt.py deleted file mode 100644 index cf9c214..0000000 --- a/src/models/denoiser/flatten_condit_convnext_fixt.py +++ /dev/null @@ -1,339 +0,0 @@ -import functools -from typing import Tuple -import torch -import torch.nn as nn -import math - -from torch.nn.init import zeros_ -from torch.nn.modules.module import T - -# from torch.nn.attention.flex_attention import flex_attention, create_block_mask -from torch.nn.functional import scaled_dot_product_attention - - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - -class Embed(nn.Module): - def __init__( - self, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = None, - bias: bool = True, - ): - super().__init__() - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Linear(in_chans, embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class TimestepEmbedder(nn.Module): - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[..., None].float() * freqs[None, ...] 
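The timestep embedding here is the usual sinusoidal one; worth noting is that condit_dit.py above uses max_period=10000 while the flatten_condit_*_fixt.py variants, this file included, use max_period=10, a much narrower frequency range that plausibly reflects flow-matching times living in [0, 1]. A self-contained sketch of the construction (the function name is illustrative and an even dim is assumed):

import math
import torch

def sinusoidal_timestep_embedding(t, dim, max_period=10.0):
    # Log-spaced frequencies; cos in the first half of the channels, sin in the second,
    # as in TimestepEmbedder.timestep_embedding above.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[..., None].float() * freqs
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = sinusoidal_timestep_embedding(torch.rand(8), dim=256)
assert emb.shape == (8, 256)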
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - -class LabelEmbedder(nn.Module): - def __init__(self, num_classes, hidden_size): - super().__init__() - self.embedding_table = nn.Embedding(num_classes, hidden_size) - self.num_classes = num_classes - - def forward(self, labels,): - embeddings = self.embedding_table(labels) - return embeddings - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 2*hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - def forward(self, x): - x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) - return x - -def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): - # assert H * H == end - # flat_patch_pos = torch.linspace(-1, 1, end) # N = end - x_pos = torch.linspace(0, scale, width) - y_pos = torch.linspace(0, scale, height) - y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") - y_pos = y_pos.reshape(-1) - x_pos = x_pos.reshape(-1) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 - x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) - y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) - freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 - freqs_cis = freqs_cis.reshape(height*width, -1) - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - freqs_cis = freqs_cis[None, :, None, :] - # xq : B N H Hc - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -class RAttention(nn.Module): - def __init__( - self, - dim: int, - 
num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = RMSNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc - q = self.q_norm(q) - k = self.k_norm(k) - q, k = apply_rotary_emb(q, k, freqs_cis=pos) - q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc - k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc - v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() - - x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class FlattenDiTBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - -class FlattenConvBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.hidden_size = hidden_size - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = nn.Conv2d(hidden_size, hidden_size, groups=hidden_size, kernel_size=7, stride=1, padding=3) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - attn_x = modulate(self.norm1(x), shift_msa, scale_msa) - attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous() - attn_x = self.attn(attn_x) - attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2) - x = x + gate_msa * attn_x - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class FlattenConDiT(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - num_cond_blocks=4, - patch_size=2, - num_classes=1000, - learn_sigma=True, - 
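In this convnext-style variant the decoder blocks drop attention entirely: FlattenConvBlock mixes tokens with a depthwise 7x7 convolution, reshaping the sequence onto a hard-coded 16x16 grid (256 tokens, i.e. presumably a 32x32 latent with patch_size=2), so it only works at that resolution. A small sketch of that mixing path, with an illustrative channel width:

import torch
import torch.nn as nn

hidden = 32
dwconv = nn.Conv2d(hidden, hidden, kernel_size=7, padding=3, groups=hidden)  # depthwise 7x7

tokens = torch.randn(2, 256, hidden)                      # B, N, C with N = 16 * 16
grid = tokens.transpose(1, 2).view(-1, hidden, 16, 16)    # B, C, 16, 16, as in the block above
mixed = dwconv(grid)
tokens_out = mixed.view(-1, hidden, 256).transpose(1, 2)  # back to B, N, C
assert tokens_out.shape == tokens.shape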
deep_supervision=0, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.deep_supervision = deep_supervision - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.num_cond_blocks = num_cond_blocks - self.patch_size = patch_size - self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) - - self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) - - self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([]) - for i in range(self.num_cond_blocks): - self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups)) - for i in range(self.num_blocks-self.num_cond_blocks): - self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups)) - - self.initialize_weights() - self.precompute_pos = dict() - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)].to(device) - else: - pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - self.precompute_pos[(height, width)] = pos - return pos - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.x_embedder.proj.bias, 0) - - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.s_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.s_embedder.proj.bias, 0) - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # # Zero-out adaLN modulation layers in SiT blocks: - # for block in self.blocks: - # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.linear.weight, 0) - nn.init.constant_(self.final_layer.linear.bias, 0) - - - def forward(self, x, t, y, s=None): - B, _, H, W = x.shape - pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - y = self.y_embedder(y).view(B, 1, self.hidden_size) - c = nn.functional.silu(t + y) - if s is None: - s = self.s_embedder(x) - for i in range(self.num_cond_blocks): - s = self.blocks[i](s, c, pos, None) - s = nn.functional.silu(t + s) - - x = self.x_embedder(x) - for i in range(self.num_cond_blocks, self.num_blocks): - x = self.blocks[i](x, s, pos, None) - x = self.final_layer(x, s) - x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) - return x, s \ No newline at end of file diff --git 
a/src/models/denoiser/flatten_condit_dit_norm_fixt.py b/src/models/denoiser/flatten_condit_dit_norm_fixt.py deleted file mode 100644 index 28034e3..0000000 --- a/src/models/denoiser/flatten_condit_dit_norm_fixt.py +++ /dev/null @@ -1,314 +0,0 @@ -import functools -from typing import Tuple -import torch -import torch.nn as nn -import math - -from torch.nn.init import zeros_ -from torch.nn.modules.module import T - -# from torch.nn.attention.flex_attention import flex_attention, create_block_mask -from torch.nn.functional import scaled_dot_product_attention - - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - -class Embed(nn.Module): - def __init__( - self, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = None, - bias: bool = True, - ): - super().__init__() - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Linear(in_chans, embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class TimestepEmbedder(nn.Module): - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[..., None].float() * freqs[None, ...] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - -class LabelEmbedder(nn.Module): - def __init__(self, num_classes, hidden_size): - super().__init__() - self.embedding_table = nn.Embedding(num_classes, hidden_size) - self.num_classes = num_classes - - def forward(self, labels,): - embeddings = self.embedding_table(labels) - return embeddings - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 2*hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w3 = 
nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - def forward(self, x): - x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) - return x - -def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): - # assert H * H == end - # flat_patch_pos = torch.linspace(-1, 1, end) # N = end - x_pos = torch.linspace(0, scale, width) - y_pos = torch.linspace(0, scale, height) - y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") - y_pos = y_pos.reshape(-1) - x_pos = x_pos.reshape(-1) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 - x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) - y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) - freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 - freqs_cis = freqs_cis.reshape(height*width, -1) - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - freqs_cis = freqs_cis[None, :, None, :] - # xq : B N H Hc - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -class RAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = RMSNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc - q = self.q_norm(q) - k = self.k_norm(k) - q, k = apply_rotary_emb(q, k, freqs_cis=pos) - q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc - k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc - v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() - - x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class FlattenDiTBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - 
nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class FlattenConDiT(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - num_cond_blocks=4, - patch_size=2, - num_classes=1000, - learn_sigma=True, - deep_supervision=0, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.deep_supervision = deep_supervision - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.num_cond_blocks = num_cond_blocks - self.patch_size = patch_size - self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) - - self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) - - self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) - ]) - self.initialize_weights() - self.precompute_pos = dict() - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)].to(device) - else: - pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - self.precompute_pos[(height, width)] = pos - return pos - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.x_embedder.proj.bias, 0) - - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.s_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.s_embedder.proj.bias, 0) - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # # Zero-out adaLN modulation layers in SiT blocks: - # for block in self.blocks: - # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.linear.weight, 0) - nn.init.constant_(self.final_layer.linear.bias, 0) - - - def forward(self, x, t, y, s=None, mask=None): - B, _, H, W = x.shape - pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - y = self.y_embedder(y).view(B, 1, self.hidden_size) - c = nn.functional.silu(t + y) - if s is None: - s = self.s_embedder(x) - for i in 
range(self.num_cond_blocks): - s = self.blocks[i](s, c, pos, mask) - s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6) - s = nn.functional.silu(t + s) - - x = self.x_embedder(x) - for i in range(self.num_cond_blocks, self.num_blocks): - x = self.blocks[i](x, s, pos, None) - x = self.final_layer(x, s) - x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) - return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py b/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py deleted file mode 100644 index 9a5e4fd..0000000 --- a/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py +++ /dev/null @@ -1,429 +0,0 @@ -import functools -from typing import Tuple -import torch -import torch.nn as nn -import math - -from torch.nn.init import zeros_ -from torch.nn.modules.module import T - -# from torch.nn.attention.flex_attention import flex_attention, create_block_mask -from torch.nn.functional import scaled_dot_product_attention -from src.utils.model_loader import ModelLoader -from src.utils.no_grad import no_grad - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - -class Embed(nn.Module): - def __init__( - self, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = None, - bias: bool = True, - ): - super().__init__() - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Linear(in_chans, embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class TimestepEmbedder(nn.Module): - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[..., None].float() * freqs[None, ...] 
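For reference, the TimestepEmbedder duplicated across these deleted denoisers is the standard DiT sinusoidal embedding, except that max_period is 10 rather than the usual 10000. A minimal standalone sketch of that computation (the function name sinusoidal_timestep_embedding is chosen here only for illustration):

import math
import torch

def sinusoidal_timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10.0) -> torch.Tensor:
    # t: (B,) timesteps in [0, 1]; returns (B, dim) sinusoidal features.
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32, device=t.device) / half
    )
    args = t[..., None].float() * freqs[None, ...]                  # (B, half)
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)     # (B, 2*half)
    if dim % 2:                                                     # pad one zero column for odd dims
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

# sinusoidal_timestep_embedding(torch.rand(4), 256).shape == (4, 256)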
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - -class LabelEmbedder(nn.Module): - def __init__(self, num_classes, hidden_size): - super().__init__() - self.embedding_table = nn.Embedding(num_classes, hidden_size) - self.num_classes = num_classes - - def forward(self, labels,): - embeddings = self.embedding_table(labels) - return embeddings - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 2*hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - def forward(self, x): - x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) - return x - -def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): - # assert H * H == end - # flat_patch_pos = torch.linspace(-1, 1, end) # N = end - x_pos = torch.linspace(0, scale, width) - y_pos = torch.linspace(0, scale, height) - y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") - y_pos = y_pos.reshape(-1) - x_pos = x_pos.reshape(-1) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 - x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) - y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) - freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 - freqs_cis = freqs_cis.reshape(height*width, -1) - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - freqs_cis = freqs_cis[None, :, None, :] - # xq : B N H Hc - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -class RAttention(nn.Module): - def __init__( - self, - dim: int, - 
num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = RMSNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc - q = self.q_norm(q) - k = self.k_norm(k) - q, k = apply_rotary_emb(q, k, freqs_cis=pos) - q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc - k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc - v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() - - x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class FlattenDiTBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class FlattenDiTEncoder(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - patch_size=2, - num_classes=1000, - learn_sigma=True, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.patch_size = patch_size - self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) - self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) - ]) - self.initialize_weights() - self.precompute_pos = dict() - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)] - else: - pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) - self.precompute_pos[(height, 
width)] = (pos_rope, pos_ape) - return (pos_rope, pos_ape) - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.s_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.s_embedder.proj.bias, 0) - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - def forward(self, x, t, y, mask=None): - B, _, H, W = x.shape - pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - y = self.y_embedder(y).view(B, 1, self.hidden_size) - c = nn.functional.silu(t + y) - s = self.s_embedder(x) - # s = s + pos_ape.to(s.dtype) - for i in range(self.num_blocks): - s = self.blocks[i](s, c, pos_rope, mask) - return None, s - - -class FlattenDiTDecoder(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - patch_size=2, - num_classes=1000, - learn_sigma=True, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.patch_size = patch_size - self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) - - self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) - ]) - self.initialize_weights() - self.precompute_pos = dict() - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)] - else: - pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - self.precompute_pos[(height, width)] = pos - return pos - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.x_embedder.proj.bias, 0) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # Zero-out output layers: - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.linear.weight, 0) - nn.init.constant_(self.final_layer.linear.bias, 0) - - def forward(self, x, t, y, s, mask=None): - B, _, H, W = x.shape - pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device) - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - # s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) - s = torch.nn.functional.silu(t + s) - x = self.x_embedder(x) - for i in range(self.num_blocks): - x = self.blocks[i](x, s, pos, None) - x = 
self.final_layer(x, s) - x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, - stride=self.patch_size) - return x - - - -class FlattenDiT(nn.Module): - def __init__( - self, - encoder:FlattenDiTEncoder, - decoder:FlattenDiTDecoder, - joint_training=False, - ): - super().__init__() - self.encoder = encoder - self.decoder = decoder - - ModelLoader().load(encoder) - if not joint_training: - self.encoder = self.encoder.to(torch.bfloat16) - no_grad(self.encoder) - - self.joint_training = joint_training - - def forward(self, x, t, y, s=None): - if s is None: - _, s = self.encoder(x, t, y) - x = self.decoder(x, t, y, s) - return x, s - -class FlattenDiTScalingEncoder(nn.Module): - def __init__( - self, - encoder:FlattenDiTEncoder, - decoder:FlattenDiTDecoder, - ): - super().__init__() - self.encoder = encoder - self.decoder = decoder - no_grad(self.decoder) - - if self.encoder.weight_path: - weight = torch.load(self.encoder.weight_path, map_location=torch.device('cpu')) - if self.encoder.load_ema: - prefix = "ema_denoiser." - else: - prefix = "denoiser." - for k, v in self.encoder.state_dict().items(): - try: - v.copy_(weight["state_dict"][prefix+k]) - except: - print(f"Failed to copy {prefix+k} to denoiser weight") - - if self.decoder.weight_path: - weight = torch.load(self.decoder.weight_path, map_location=torch.device('cpu')) - if self.decoder.load_ema: - prefix = "ema_denoiser." - else: - prefix = "denoiser." - for k, v in self.decoder.state_dict().items(): - if "blocks." in k: - blockid = int(k.split("blocks.")[-1][0]) - k = k.replace(f"blocks.{blockid}", f"blocks.{int(blockid)+8}") - try: - v.copy_(weight["state_dict"][prefix+k]) - except: - print(f"Failed to copy {prefix+k} to denoiser weight") - self.decoder = decoder.to(torch.bfloat16) - - def forward(self, x, t, y, s=None): - if s is None: - _, s = self.encoder(x, t, y) - x = self.decoder(x, t, y, s) - return x, s - diff --git a/src/models/denoiser/flatten_condit_mlp_fixt.py b/src/models/denoiser/flatten_condit_mlp_fixt.py deleted file mode 100644 index 40735e4..0000000 --- a/src/models/denoiser/flatten_condit_mlp_fixt.py +++ /dev/null @@ -1,334 +0,0 @@ -import functools -from typing import Tuple -import torch -import torch.nn as nn -import math - -from torch.nn.init import zeros_ -from torch.nn.modules.module import T - -# from torch.nn.attention.flex_attention import flex_attention, create_block_mask -from torch.nn.functional import scaled_dot_product_attention - - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - -class Embed(nn.Module): - def __init__( - self, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = None, - bias: bool = True, - ): - super().__init__() - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Linear(in_chans, embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class TimestepEmbedder(nn.Module): - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - 
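A note on the FlattenDiTScalingEncoder removed above: when loading decoder weights it parses the block index with k.split("blocks.")[-1][0], i.e. a single character, so any index of 10 or more would be mis-read, and the subsequent substring replace can also match a longer index by prefix. A hedged sketch of a more robust remap (the +8 offset mirrors the deleted code; remap_block_keys is an illustrative name, not part of the repo):

import re
from typing import Dict
import torch

def remap_block_keys(state_dict: Dict[str, torch.Tensor], offset: int = 8) -> Dict[str, torch.Tensor]:
    # Shift every "blocks.<i>." index by `offset`, parsing the full index
    # instead of a single character, so "blocks.12." maps to "blocks.20.".
    pattern = re.compile(r"^(.*?blocks\.)(\d+)(\..*)$")
    remapped = {}
    for k, v in state_dict.items():
        m = pattern.match(k)
        if m:
            k = f"{m.group(1)}{int(m.group(2)) + offset}{m.group(3)}"
        remapped[k] = v
    return remapped

# remap_block_keys({"blocks.12.attn.qkv.weight": torch.zeros(1)}) yields
# {"blocks.20.attn.qkv.weight": ...}; keys without "blocks.<i>." pass through unchanged.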
args = t[..., None].float() * freqs[None, ...] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - -class LabelEmbedder(nn.Module): - def __init__(self, num_classes, hidden_size): - super().__init__() - self.embedding_table = nn.Embedding(num_classes, hidden_size) - self.num_classes = num_classes - - def forward(self, labels,): - embeddings = self.embedding_table(labels) - return embeddings - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 2*hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - def forward(self, x): - x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) - return x - -def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): - # assert H * H == end - # flat_patch_pos = torch.linspace(-1, 1, end) # N = end - x_pos = torch.linspace(0, scale, width) - y_pos = torch.linspace(0, scale, height) - y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") - y_pos = y_pos.reshape(-1) - x_pos = x_pos.reshape(-1) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 - x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) - y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) - freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 - freqs_cis = freqs_cis.reshape(height*width, -1) - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - freqs_cis = freqs_cis[None, :, None, :] - # xq : B N H Hc - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -class 
RAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = RMSNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc - q = self.q_norm(q) - k = self.k_norm(k) - q, k = apply_rotary_emb(q, k, freqs_cis=pos) - q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc - k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc - v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() - - x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class FlattenDiTBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - -class FlattenMLPBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = FeedForward(hidden_size, mlp_hidden_dim) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class FlattenConDiT(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - num_cond_blocks=4, - patch_size=2, - num_classes=1000, - learn_sigma=True, - deep_supervision=0, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.deep_supervision = deep_supervision - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = 
in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.num_cond_blocks = num_cond_blocks - self.patch_size = patch_size - self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) - - self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) - - self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([]) - for i in range(self.num_cond_blocks): - self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups)) - for i in range(self.num_blocks-self.num_cond_blocks): - self.blocks.append(FlattenMLPBlock(self.hidden_size, self.num_groups)) - - self.initialize_weights() - self.precompute_pos = dict() - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)].to(device) - else: - pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - self.precompute_pos[(height, width)] = pos - return pos - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.x_embedder.proj.bias, 0) - - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.s_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.s_embedder.proj.bias, 0) - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # # Zero-out adaLN modulation layers in SiT blocks: - # for block in self.blocks: - # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.linear.weight, 0) - nn.init.constant_(self.final_layer.linear.bias, 0) - - - def forward(self, x, t, y, s=None): - B, _, H, W = x.shape - pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - y = self.y_embedder(y).view(B, 1, self.hidden_size) - c = nn.functional.silu(t + y) - if s is None: - s = self.s_embedder(x) - for i in range(self.num_cond_blocks): - s = self.blocks[i](s, c, pos, None) - s = nn.functional.silu(t + s) - - x = self.x_embedder(x) - for i in range(self.num_cond_blocks, self.num_blocks): - x = self.blocks[i](x, s, pos, None) - x = self.final_layer(x, s) - x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) - return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py b/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py deleted file mode 100644 index bcf3315..0000000 --- a/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py +++ /dev/null @@ -1,321 +0,0 @@ 
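One convention shared by all of these denoisers, worth spelling out before the next file: patchify/unpatchify is implemented with torch.nn.functional.unfold/fold plus a plain nn.Linear (the Embed module), rather than a Conv2d patch embed. A minimal round-trip sketch with illustrative sizes:

import torch
import torch.nn.functional as F

B, C, H, W, p = 2, 4, 32, 32, 2
x = torch.randn(B, C, H, W)

# Patchify: (B, C, H, W) -> (B, N, C*p*p) with N = (H/p)*(W/p) non-overlapping patches.
tokens = F.unfold(x, kernel_size=p, stride=p).transpose(1, 2)   # (2, 256, 16)

# A linear layer on the last dim now plays the role of the usual Conv2d patch embed;
# the deleted Embed module is exactly nn.Linear(C*p*p, hidden_size) plus an optional norm.

# Unpatchify: fold inverts unfold exactly when kernel_size == stride.
x_rec = F.fold(tokens.transpose(1, 2).contiguous(), (H, W), kernel_size=p, stride=p)
assert torch.allclose(x, x_rec)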
-import functools -from typing import Tuple -import torch -import torch.nn as nn -import math - -from torch.nn.init import zeros_ -from torch.nn.modules.module import T - -# from torch.nn.attention.flex_attention import flex_attention, create_block_mask -from torch.nn.functional import scaled_dot_product_attention - - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - -class Embed(nn.Module): - def __init__( - self, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = None, - bias: bool = True, - ): - super().__init__() - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Linear(in_chans, embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class TimestepEmbedder(nn.Module): - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[..., None].float() * freqs[None, ...] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - -class LabelEmbedder(nn.Module): - def __init__(self, num_classes, hidden_size): - super().__init__() - self.embedding_table = nn.Embedding(num_classes, hidden_size) - self.num_classes = num_classes - - def forward(self, labels,): - embeddings = self.embedding_table(labels) - return embeddings - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 2*hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - def forward(self, x): - x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) - return x - -def precompute_freqs_cis_2d(dim: int, height: int, 
width:int, theta: float = 10000.0, scale=16.0): - # assert H * H == end - # flat_patch_pos = torch.linspace(-1, 1, end) # N = end - x_pos = torch.linspace(0, scale, width) - y_pos = torch.linspace(0, scale, height) - y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") - y_pos = y_pos.reshape(-1) - x_pos = x_pos.reshape(-1) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 - x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) - y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) - freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 - freqs_cis = freqs_cis.reshape(height*width, -1) - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - freqs_cis = freqs_cis[None, :, None, :] - # xq : B N H Hc - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -class RAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = RMSNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc - q = self.q_norm(q) - k = self.k_norm(k) - q, k = apply_rotary_emb(q, k, freqs_cis=pos) - q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc - k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc - v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() - - x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class FlattenDiTBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * 
self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class FlattenConDiT(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - num_cond_blocks=4, - patch_size=2, - num_classes=1000, - learn_sigma=True, - deep_supervision=0, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.deep_supervision = deep_supervision - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.num_cond_blocks = num_cond_blocks - self.patch_size = patch_size - self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.s_embedder = Embed(in_channels*patch_size**4, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) - - self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) - - self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) - ]) - self.initialize_weights() - self.precompute_pos = dict() - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)].to(device) - else: - pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - self.precompute_pos[(height, width)] = pos - return pos - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.x_embedder.proj.bias, 0) - - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.s_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.s_embedder.proj.bias, 0) - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # # Zero-out adaLN modulation layers in SiT blocks: - # for block in self.blocks: - # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.linear.weight, 0) - nn.init.constant_(self.final_layer.linear.bias, 0) - - - def forward(self, x, t, y, s=None, mask=None): - B, _, H, W = x.shape - pos_x = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - y = self.y_embedder(y).view(B, 1, self.hidden_size) - c = nn.functional.silu(t + y) - if s is None: - pos_s = self.fetch_pos(H//self.patch_size//2, W//self.patch_size//2, x.device) - s = torch.nn.functional.unfold(x, kernel_size=self.patch_size*2, stride=self.patch_size*2).transpose(1, 2) - s = self.s_embedder(s) - for i in range(self.num_cond_blocks): - s = self.blocks[i](s, c, pos_s, mask) - s = s.view(B, H//self.patch_size//2, W//self.patch_size//2, self.hidden_size) - s 
= torch.permute(s, (0, 3, 1, 2)) - s = torch.nn.functional.interpolate(s, scale_factor=2, mode='bilinear', align_corners=False) - s = torch.permute(s, (0, 2, 3, 1)) - s = s.view(B, -1, self.hidden_size) - s = nn.functional.silu(t + s) - - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - x = self.x_embedder(x) - for i in range(self.num_cond_blocks, self.num_blocks): - x = self.blocks[i](x, s, pos_x, None) - x = self.final_layer(x, s) - x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) - return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_dit_fixt_xvout.py b/src/models/denoiser/flatten_dit_fixt_xvout.py deleted file mode 100644 index 4df3393..0000000 --- a/src/models/denoiser/flatten_dit_fixt_xvout.py +++ /dev/null @@ -1,311 +0,0 @@ -import functools -from typing import Tuple -import torch -import torch.nn as nn -import math - -from torch.nn.init import zeros_ -from torch.nn.modules.module import T - -# from torch.nn.attention.flex_attention import flex_attention, create_block_mask -from torch.nn.functional import scaled_dot_product_attention - - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - -class Embed(nn.Module): - def __init__( - self, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = None, - bias: bool = True, - ): - super().__init__() - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Linear(in_chans, embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class TimestepEmbedder(nn.Module): - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[..., None].float() * freqs[None, ...] 
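Returning to flatten_condit_sdown2_dit_fixt.py above: its condition tokens are built on a 2x coarser patch grid and then bilinearly upsampled back to the decoder's token grid before being fused with the timestep embedding. The grid/token reshuffling amounts to the following (upsample_token_grid is a name chosen here; sizes are illustrative):

import torch
import torch.nn.functional as F

def upsample_token_grid(s: torch.Tensor, h: int, w: int, scale: int = 2) -> torch.Tensor:
    # s: (B, h*w, D) condition tokens on a coarse h x w grid.
    # Returns (B, (h*scale)*(w*scale), D) tokens on the finer grid.
    B, N, D = s.shape
    assert N == h * w
    s = s.view(B, h, w, D).permute(0, 3, 1, 2)                    # (B, D, h, w)
    s = F.interpolate(s, scale_factor=scale, mode="bilinear", align_corners=False)
    return s.permute(0, 2, 3, 1).reshape(B, -1, D)                # (B, N*scale**2, D)

# e.g. a coarse 8x8 grid of 1152-d tokens -> 16x16 grid matching the patch tokens:
# upsample_token_grid(torch.randn(2, 64, 1152), 8, 8).shape == (2, 256, 1152)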
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - -class LabelEmbedder(nn.Module): - def __init__(self, num_classes, hidden_size): - super().__init__() - self.embedding_table = nn.Embedding(num_classes, hidden_size) - self.num_classes = num_classes - - def forward(self, labels,): - embeddings = self.embedding_table(labels) - return embeddings - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 2*hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return (self.weight * hidden_states).to(input_dtype) - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - def forward(self, x): - x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) - return x - -def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): - # assert H * H == end - # flat_patch_pos = torch.linspace(-1, 1, end) # N = end - x_pos = torch.linspace(0, scale, width) - y_pos = torch.linspace(0, scale, height) - y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") - y_pos = y_pos.reshape(-1) - x_pos = x_pos.reshape(-1) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 - x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) - y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) - freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 - freqs_cis = freqs_cis.reshape(height*width, -1) - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - freqs_cis = freqs_cis[None, :, None, :] - # xq : B N H Hc - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -class RAttention(nn.Module): - def __init__( - self, - dim: int, - 
num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = RMSNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc - q = self.q_norm(q) - k = self.k_norm(k) - q, k = apply_rotary_emb(q, k, freqs_cis=pos) - q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc - k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc - v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() - - x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class FlattenDiTBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class FlattenDiT(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - patch_size=2, - num_classes=1000, - learn_sigma=True, - deep_supervision=0, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.deep_supervision = deep_supervision - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.patch_size = patch_size - self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) - - self.final_layer = FinalLayer(hidden_size, 2*in_channels*patch_size**2) - - self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) - ]) - self.initialize_weights() - self.precompute_pos = dict() - - def fetch_pos(self, height, width, device, dtype): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)].to(device, dtype) - else: - pos = precompute_freqs_cis_2d(self.hidden_size // 
self.num_groups, height, width).to(device, dtype) - self.precompute_pos[(height, width)] = pos - return pos - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.x_embedder.proj.bias, 0) - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # # Zero-out adaLN modulation layers in SiT blocks: - # for block in self.blocks: - # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.linear.weight, 0) - nn.init.constant_(self.final_layer.linear.bias, 0) - - def forward(self, x, t, y, masks=None): - if masks is None: - masks = [None, ]*self.num_blocks - if isinstance(masks, torch.Tensor): - masks = masks.unbind(0) - if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks: - masks = masks + [None]*(self.num_blocks-len(masks)) - - B, _, H, W = x.shape - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - x = self.x_embedder(x) - pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype) - B, L, C = x.shape - t = self.t_embedder(t.view(-1)).view(B, -1, C) - y = self.y_embedder(y).view(B, 1, C) - condition = nn.functional.silu(t + y) - for i, block in enumerate(self.blocks): - x = block(x, condition, pos, masks[i]) - x = self.final_layer(x, condition) - x0, v = x.chunk(2, dim=-1) - x0 = torch.nn.functional.fold(x0.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) - v = torch.nn.functional.fold(v.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) - if self.training: - return v, x0 - else: - return v \ No newline at end of file diff --git a/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py b/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py deleted file mode 100644 index 4e570b0..0000000 --- a/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py +++ /dev/null @@ -1,308 +0,0 @@ -import functools -from typing import Tuple -import torch -import torch.nn as nn -import math - -from torch.nn.init import zeros_ -from torch.nn.modules.module import T - -# from torch.nn.attention.flex_attention import flex_attention, create_block_mask -from torch.nn.functional import scaled_dot_product_attention - - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - -class Embed(nn.Module): - def __init__( - self, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = None, - bias: bool = True, - ): - super().__init__() - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Linear(in_chans, embed_dim, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class TimestepEmbedder(nn.Module): - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - 
nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10): - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half - ) - args = t[..., None].float() * freqs[None, ...] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - -class LabelEmbedder(nn.Module): - def __init__(self, num_classes, hidden_size): - super().__init__() - self.embedding_table = nn.Embedding(num_classes, hidden_size) - self.num_classes = num_classes - - def forward(self, labels,): - embeddings = self.embedding_table(labels) - return embeddings - -class FinalLayer(nn.Module): - def __init__(self, hidden_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(hidden_size, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 2*hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - def forward(self, x): - x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) - return x - -def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): - # assert H * H == end - # flat_patch_pos = torch.linspace(-1, 1, end) # N = end - x_pos = torch.linspace(0, scale, width) - y_pos = torch.linspace(0, scale, height) - y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") - y_pos = y_pos.reshape(-1) - x_pos = x_pos.reshape(-1) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 - x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) - y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) - freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 - freqs_cis = freqs_cis.reshape(height*width, -1) - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - freqs_cis = freqs_cis[None, :, None, :] - # xq : B N H Hc - xq_ = 
torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -class RAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_norm: bool = True, - attn_drop: float = 0., - proj_drop: float = 0., - norm_layer: nn.Module = RMSNorm, - ) -> None: - super().__init__() - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc - q = self.q_norm(q) - k = self.k_norm(k) - q, k = apply_rotary_emb(q, k, freqs_cis=pos) - q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc - k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc - v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() - - x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - - -class FlattenDiTBlock(nn.Module): - def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): - super().__init__() - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c, pos, mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) - x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) - x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - -class FlattenConDiT(nn.Module): - def __init__( - self, - in_channels=4, - num_groups=12, - hidden_size=1152, - num_blocks=18, - num_cond_blocks=4, - patch_size=2, - num_classes=1000, - learn_sigma=True, - deep_supervision=0, - weight_path=None, - load_ema=False, - ): - super().__init__() - self.deep_supervision = deep_supervision - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels - self.hidden_size = hidden_size - self.num_groups = num_groups - self.num_blocks = num_blocks - self.num_cond_blocks = num_cond_blocks - self.patch_size = patch_size - self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) - - self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) - - 
self.weight_path = weight_path - - self.load_ema = load_ema - self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) - ]) - self.initialize_weights() - self.precompute_pos = dict() - - def fetch_pos(self, height, width, device): - if (height, width) in self.precompute_pos: - return self.precompute_pos[(height, width)].to(device) - else: - pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) - self.precompute_pos[(height, width)] = pos - return pos - - def initialize_weights(self): - # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): - w = self.x_embedder.proj.weight.data - nn.init.xavier_uniform_(w.view([w.shape[0], -1])) - nn.init.constant_(self.x_embedder.proj.bias, 0) - - - # Initialize label embedding table: - nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) - - # Initialize timestep embedding MLP: - nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) - nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - - # # Zero-out adaLN modulation layers in SiT blocks: - # for block in self.blocks: - # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - # Zero-out output layers: - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.linear.weight, 0) - nn.init.constant_(self.final_layer.linear.bias, 0) - - - def forward(self, x, t, y, s=None, mask=None): - B, _, H, W = x.shape - pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) - x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) - y = self.y_embedder(y).view(B, 1, self.hidden_size) - c = nn.functional.silu(t + y) - if s is None: - s = self.x_embedder(x) - for i in range(self.num_cond_blocks): - s = self.blocks[i](s, c, pos, mask) - s = nn.functional.silu(t + s) - - x = self.x_embedder(x) - for i in range(self.num_cond_blocks, self.num_blocks): - x = self.blocks[i](x, s, pos, None) - x = self.final_layer(x, s) - x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) - return x, s \ No newline at end of file diff --git a/src/models/denoiser/flowdcn.py b/src/models/denoiser/flowdcn.py deleted file mode 100644 index 92e2237..0000000 --- a/src/models/denoiser/flowdcn.py +++ /dev/null @@ -1,160 +0,0 @@ -import torch -import torch.nn as nn -import math - -from torch.nn.init import zeros_ -from src.models.denoiser.base_model import BaseModel -from src.ops.triton_kernels.function import DCNFunction - -def modulate(x, shift, scale): - return x * (1 + scale[:, None, None]) + shift[:, None, None] - -class PatchEmbed(nn.Module): - def __init__( - self, - patch_size: int = 16, - in_chans: int = 3, - embed_dim: int = 768, - norm_layer = None, - bias: bool = True, - ): - super().__init__() - self.patch_size = patch_size - self.in_chans = in_chans - self.embed_dim = embed_dim - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - def forward(self, x): - x = self.proj(x) - x = self.norm(x) - return x - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - 
super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - ): - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - def forward(self, x): - b, h, w, c = x.shape - x = x.view(b, h*w, c) - x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) - x = x.view(b, h, w, c) - return x - - -class MultiScaleDCN(nn.Module): - def __init__(self, in_channels, groups, channels, kernels, deformable_biass=True): - super().__init__() - self.in_channels = in_channels - self.groups = groups - self.channels = channels - self.kernels = kernels - self.v = nn.Linear(in_channels, groups * channels, bias=True) - self.qk_deformables = nn.Linear(in_channels, groups * kernels * 2, bias=True) - self.qk_scales = nn.Linear(in_channels, groups * kernels, bias=False) - self.qk_weights = nn.Linear(in_channels, groups*kernels, bias=True) - self.out = nn.Linear(groups * channels, in_channels) - self.deformables_prior = nn.Parameter(torch.randn((1, 1, 1, 1, kernels, 2)), requires_grad=False) - self.deformables_scale = nn.Parameter(torch.ones((1, 1, 1, groups, 1, 1)), requires_grad=True) - self.max_scale = 6 - self._init_weights() - def _init_weights(self): - zeros_(self.qk_deformables.weight.data) - zeros_(self.qk_scales.weight.data) - zeros_(self.qk_deformables.bias.data) - zeros_(self.qk_weights.weight.data) - zeros_(self.v.bias.data) - zeros_(self.out.bias.data) - num_prior = int(self.kernels ** 0.5) - dx = torch.linspace(-1, 1, num_prior, device="cuda") - dy = torch.linspace(-1, 1, num_prior, device="cuda") - dxy = torch.meshgrid([dx, dy], indexing="xy") - dxy = torch.stack(dxy, dim=-1) - dxy = dxy.view(-1, 2) - self.deformables_prior.data[..., :num_prior*num_prior, :] = dxy - for i in range(self.groups): - scale = (i+1)/self.groups - 0.0001 - inv_scale = math.log((scale)/(1-scale)) - self.deformables_scale.data[..., i, :, :] = inv_scale - def forward(self, x): - B, H, W, _ = x.shape - v = self.v(x).view(B, H, W, self.groups, self.channels) - deformables = self.qk_deformables(x).view(B, H, W, self.groups, self.kernels, 2) - scale = self.qk_scales(x).view(B, H, W, self.groups, self.kernels, 1) + self.deformables_scale - deformables = (deformables + self.deformables_prior ) * scale.sigmoid()*self.max_scale - weights = self.qk_weights(x).view(B, H, W, self.groups, self.kernels) - out = DCNFunction.apply(v, deformables, weights) - out = out.view(B, H, W, -1) - out = self.out(out) - return out - -class FlowDCNBlock(nn.Module): - def __init__(self, hidden_size, groups, kernels=9, mlp_ratio=4.0, deformable_biass=True): - super().__init__() - self.norm1 = RMSNorm(hidden_size, eps=1e-6) - self.attn = MultiScaleDCN(hidden_size, groups=groups, channels=hidden_size//groups, kernels=kernels, deformable_biass=deformable_biass) - self.norm2 = RMSNorm(hidden_size, eps=1e-6) - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.mlp = FeedForward(hidden_size, mlp_hidden_dim) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - 
nn.Linear(hidden_size, 6 * hidden_size, bias=True) - ) - - def forward(self, x, c): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) - x = x + gate_msa[:, None, None] * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) - x = x + gate_mlp[:, None, None] * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) - return x - - - - -class FlowDCN(BaseModel): - def __init__(self, deformable_biass=True, *args, **kwargs): - super().__init__(*args, **kwargs) - self.blocks = nn.ModuleList([ - FlowDCNBlock(self.hidden_size, self.num_groups, kernels=9, deformable_biass=deformable_biass) for _ in range(self.num_blocks) - ]) - self.x_embedder = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, bias=True) - self.initialize_weights() - - def forward(self, x, t, y): - batch_size, _, height, width = x.shape[0] - x = self.x_embedder(x) # (N, D, h, w) - x = x.permute(0, 2, 3, 1).reshape(batch_size, height*width//self.patch_size**2, -1) - t = self.t_embedder(t) # (N, D) - y = self.y_embedder(y, self.training) # (N, D) - c = t + y # (N, D) - B, L, C = x.shape - x = x.view(B, height//self.patch_size, width//self.patch_size, C) - for block in self.blocks: - x = block(x, c) # (N, T, D) - x = x.view(B, L, C) - x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) - x = torch.nn.functional.fold(x.transpose(1, 2), (height, width), kernel_size=self.patch_size, stride=self.patch_size) - if self.learn_sigma: - x, _ = torch.split(x, self.out_channels // 2, dim=1) - return x \ No newline at end of file diff --git a/src/models/denoiser/flatten_dit_fixt.py b/src/models/denoiser/improved_dit.py similarity index 96% rename from src/models/denoiser/flatten_dit_fixt.py rename to src/models/denoiser/improved_dit.py index 9412d6e..99e2f5a 100644 --- a/src/models/denoiser/flatten_dit_fixt.py +++ b/src/models/denoiser/improved_dit.py @@ -194,7 +194,7 @@ class RAttention(nn.Module): -class FlattenDiTBlock(nn.Module): +class DiTBlock(nn.Module): def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): super().__init__() self.norm1 = RMSNorm(hidden_size, eps=1e-6) @@ -213,7 +213,7 @@ class FlattenDiTBlock(nn.Module): return x -class FlattenDiT(nn.Module): +class DiT(nn.Module): def __init__( self, in_channels=4, @@ -246,7 +246,7 @@ class FlattenDiT(nn.Module): self.load_ema = load_ema self.blocks = nn.ModuleList([ - FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + DiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) ]) self.initialize_weights() self.precompute_pos = dict() @@ -272,11 +272,6 @@ class FlattenDiT(nn.Module): nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) - # # Zero-out adaLN modulation layers in SiT blocks: - # for block in self.blocks: - # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - # Zero-out output layers: nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) diff --git a/src/ops/cuda_kernels/backward.cu b/src/ops/cuda_kernels/backward.cu deleted file mode 100644 index 2e85d86..0000000 --- a/src/ops/cuda_kernels/backward.cu +++ /dev/null @@ -1,346 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - - -namespace cg = cooperative_groups; - 
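/*
 * Gradient sketch for the bilinear sampling implemented by the backward kernel
 * below (a descriptive note; the math is read off the code, symbol names are
 * illustrative). For one kernel point k of group g at output pixel (h, w), the
 * forward pass computes
 *     x = offset_x + w,  y = offset_y + h,
 *     fx = floor(x), fy = floor(y), cx = fx + 1, cy = fy + 1,
 *     s  = (cx - x)(cy - y) v_tl + (x - fx)(cy - y) v_tr
 *        + (cx - x)(y - fy) v_bl + (x - fx)(y - fy) v_br,
 *     o += weight_k * s.
 * With upstream gradient grad = dL/do, the quantities accumulated below are
 *     dL/dweight_k = sum_c s_c * grad_c                                (dodw),
 *     dL/dx = weight_k * sum_c [ -(cy - y) v_tl + (cy - y) v_tr
 *                                -(y - fy) v_bl + (y - fy) v_br ]_c * grad_c  (dodx),
 *     dL/dy = weight_k * sum_c [ -(cx - x) v_tl - (x - fx) v_tr
 *                                +(cx - x) v_bl + (x - fx) v_br ]_c * grad_c  (dody),
 *     dL/dv_corner = weight_k * bilinear_weight(corner) * grad   (atomicAdd),
 * with out-of-range corners skipped via the tl/tr/bl/br validity flags.
 */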
-template -__device__ __always_inline int toInt(scalar_t val); - -template<> -__device__ __always_inline int toInt(float val){ - return static_cast(val); -} -template<> -__device__ __always_inline int toInt(half val){ - return __half2int_rz(val); -} - -template -__device__ __always_inline scalar_t fromInt(int val); - -template<> -__device__ __always_inline float fromInt(int val){ - return static_cast(val); -} - -template<> -__device__ __always_inline half fromInt(int val){ - return __int2half_rz(val); -} - -template -__device__ __always_inline scalar_t constVal(float val); - -template<> -__device__ __always_inline float constVal(float val) { - return (float)val; -} - -template<> -__device__ __always_inline half constVal(float val) { - return __float2half(val); // Using float to half conversion -} -template<> -__device__ __always_inline nv_bfloat16 constVal(float val){ - return __float2bfloat16(val); -} - - - - - -// B, H, W, C, BLOCK_DIM must be multiple of C -template -__global__ void dcn_backward_pipeline_kernel( - const int H, - const int W, - const int G, - const int K, - const int C, - scalar_t* ptr_values, - scalar_t* ptr_deformables, - scalar_t* ptr_weights, - scalar_t* ptr_grad_out, - scalar_t* ptr_grad_values, - scalar_t* ptr_grad_deformables, - scalar_t* ptr_grad_weights -) { - auto block = cg::this_thread_block(); - auto self_thread = cg::this_thread(); - auto tile_threads = cg::tiled_partition(block); - int local_thread_id = block.thread_rank(); - int local_tile_id = tile_threads.meta_group_rank(); - int num_local_tiles = tile_threads.meta_group_size(); - int global_tile_id = block.group_index().x*num_local_tiles + local_tile_id; - - extern __shared__ int shm[]; - auto GradBuffer = reinterpret_cast(shm); - scalar_t* Buffer = reinterpret_cast(shm) + num_local_tiles*C; - if(global_tile_id >= H*W*G) return; - - int bid = block.group_index().y; - int gid = global_tile_id % G; - int wid = global_tile_id / G % W; - int hid = global_tile_id / G / W; - int globale_offset = bid*H*W*G*C + global_tile_id*C; - cg::memcpy_async(tile_threads, GradBuffer+local_tile_id*C, ptr_grad_out+globale_offset, sizeof(scalar_t)*C); - - int shared_offset[pipeline_stages]; - for (int s = 0; s < pipeline_stages; ++s) { - shared_offset[s] = (s+pipeline_stages*local_thread_id)*(TILE_C*4); - } - - auto pipeline = cuda::make_pipeline(); - const int num_tiles_per_thread = C/TILE_C/TILE_THREADS; - - for(int k=0; k(wid); - y = ptr_deformables[offset*2 + 1] + fromInt(hid); -// x = fromInt(wid); -// y = fromInt(hid); - weight = ptr_weights[offset]; - } - tile_threads.sync(); - x = tile_threads.shfl(x, 0); - y = tile_threads.shfl(y, 0); - weight = tile_threads.shfl(weight, 0); - - int floor_x = toInt(x); - int floor_y = toInt(y); - int ceil_x = floor_x + 1; - int ceil_y = floor_y + 1; - - - scalar_t dodx = constVal(0.0f); - scalar_t dody = constVal(0.0f); - scalar_t dodw = constVal(0.0f); - - int start_c = tile_threads.thread_rank() * (C / TILE_THREADS); - - bool tl_flag = (floor_x >=0) and (floor_x =0) and (floor_y=0) and (ceil_x =0) and (floor_y=0) and (floor_x =0) and (ceil_y=0) and (ceil_x =0) and (ceil_y(ceil_x) - x) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j]; - dodx = dodx + -weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j]; - dody = dody + -weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j]; - dodw = dodw + (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j + 1] * 
GradBuffer[gbuffer_offset + j + 1]; - dodx = dodx + -weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j+ 1] * GradBuffer[gbuffer_offset + j + 1]; - dody = dody + -weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1]; - { - vec2_t vtl_di; - vtl_di.x = weight* (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j]; - vtl_di.y = weight* (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j + 1]; - atomicAdd((vec2_t*)(ptr_grad_values + tl_global_base + compute_n * TILE_C + j), vtl_di); - } - } - - - if(tr_flag){ - // tr - dodw = dodw + (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j]; - dodx = dodx + weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j]; - dody = dody + -weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j]; - dodw = dodw + (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j+1] * GradBuffer[gbuffer_offset + j+1]; - dodx = dodx + weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+ 1]; - dody = dody + -weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+1]; - { - vec2_t vtr_di; - vtr_di.x = weight* (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j]; - vtr_di.y = weight* (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j+1]; - atomicAdd((vec2_t*)(ptr_grad_values + tr_global_base + compute_n * TILE_C + j), vtr_di); - } - } - - if(bl_flag){ - // bl - dodw = dodw + (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j]; - dodx = dodx + -weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j]; - dody = dody + weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j]; - dodw = dodw + (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1]; - dodx = dodx + -weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1]; - dody = dody + weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1]; - { - vec2_t vbl_di; - vbl_di.x = weight* (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j]; - vbl_di.y = weight* (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j+1]; - atomicAdd((vec2_t*)(ptr_grad_values + bl_global_base + compute_n * TILE_C + j), vbl_di); - } - } - - - if(br_flag){ - // tr - dodw = dodw + (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j]; - dodx = dodx + weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j]; - dody = dody + weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j]; - dodw = dodw + (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1]; - dodx = dodx + weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1]; - dody = dody + weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1]; - { - 
vec2_t vbr_di; - vbr_di.x = weight* (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j]; - vbr_di.y = weight* (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j+1]; - atomicAdd((vec2_t*)(ptr_grad_values + br_global_base + compute_n * TILE_C + j), vbr_di); - } - } - } - pipeline.consumer_release(); - } - for (int i = TILE_THREADS>>1; i > 0; i/=2) { - dodx = dodx + tile_threads.shfl_down(dodx, i); - dody = dody + tile_threads.shfl_down(dody, i); - dodw = dodw + tile_threads.shfl_down(dodw, i); - } - if (tile_threads.thread_rank() == 0) { - cuda::memcpy_async(ptr_grad_deformables + offset * 2, &dodx, sizeof(scalar_t), pipeline); - cuda::memcpy_async(ptr_grad_deformables + offset * 2 + 1, &dody, sizeof(scalar_t), pipeline); - cuda::memcpy_async(ptr_grad_weights + offset, &dodw, sizeof(scalar_t), pipeline); - } - } -} - - -using namespace torch; -template -void backward(const int B, - const int H, - const int W, - const int G, - const int K, - const int C, - torch::Tensor values, - torch::Tensor deformables, - torch::Tensor weights, - torch::Tensor grad_out, - torch::Tensor grad_values, - torch::Tensor grad_deformables, - torch::Tensor grad_weights -) { - int num_local_tiles =(THREADS/TILE_THREADS); - int num_global_tiles = (H*W*G+num_local_tiles-1)/num_local_tiles; - dim3 launch_threads_per_block(THREADS); - dim3 launch_blocks(num_global_tiles, B); - - int deformable_shm_size = 0; - int grad_out_shm_size = num_local_tiles*C; - int pipeline_shm_size = (pipeline_stages*TILE_C*4*THREADS); - - int shm_size = deformable_shm_size+grad_out_shm_size+pipeline_shm_size; -// printf("shm_size: %d\n", shm_size/512); -// printf("pipeline_size: %d\n", pipeline_shm_size/512); -// printf("grad_out_size: %d\n", grad_out_shm_size/512); - - - switch (values.type().scalarType()) { - case at::ScalarType::Half: - return dcn_backward_pipeline_kernel<<>>( - H, W, G, K, C, - reinterpret_cast(values.data_ptr()), - reinterpret_cast(deformables.data_ptr()), - reinterpret_cast(weights.data_ptr()), - reinterpret_cast(grad_out.data_ptr()), - reinterpret_cast(grad_values.data_ptr()), - reinterpret_cast(grad_deformables.data_ptr()), - reinterpret_cast(grad_weights.data_ptr()) - ); -// case at::ScalarType::BFloat16: -// return dcn_backward_pipeline_kernel<<>>( -// H, W, G, K, C, -// reinterpret_cast(values.data_ptr()), -// reinterpret_cast(deformables.data_ptr()), -// reinterpret_cast(weights.data_ptr()), -// reinterpret_cast(grad_out.data_ptr()), -// reinterpret_cast(grad_values.data_ptr()), -// reinterpret_cast(grad_deformables.data_ptr()), -// reinterpret_cast(grad_weights.data_ptr()) -// ); - default: - printf("running error"); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("backward_p1_c2_tile16_thread128", &backward<1, 2, 16, 128>, ""); - m.def("backward_p2_c2_tile16_thread128", &backward<2, 2, 16, 128>, ""); - m.def("backward_p1_c4_tile16_thread128", &backward<1, 4, 16, 128>, ""); - m.def("backward_p1_c2_tile16_thread256", &backward<1, 2, 16, 256>, ""); - m.def("backward_p2_c2_tile16_thread256", &backward<2, 2, 16, 256>, ""); - m.def("backward_p1_c4_tile16_thread256", &backward<1, 4, 16, 256>, ""); - m.def("backward_p1_c2_tile16_thread384", &backward<1, 2, 16, 384>, ""); - m.def("backward_p2_c2_tile16_thread384", &backward<2, 2, 16, 384>, ""); - m.def("backward_p1_c4_tile16_thread384", &backward<1, 4, 16, 384>, ""); - m.def("backward_p1_c2_tile16_thread512", &backward<1, 2, 16, 512>, ""); - m.def("backward_p2_c2_tile16_thread512", &backward<2, 2, 16, 
512>, ""); - m.def("backward_p1_c4_tile16_thread512", &backward<1, 4, 16, 512>, ""); - m.def("backward_p1_c2_tile16_thread768", &backward<1, 2, 16, 768>, ""); - m.def("backward_p2_c2_tile16_thread768", &backward<2, 2, 16, 768>, ""); - m.def("backward_p1_c4_tile16_thread768", &backward<1, 4, 16, 768>, ""); -// m.def("backward_p1_c2_tile16_thread1024", &backward<1, 2, 16, 1024>, ""); -// m.def("backward_p2_c2_tile16_thread1024", &backward<2, 2, 16, 1024>, ""); -// m.def("backward_p1_c4_tile16_thread1024", &backward<1, 4, 16, 1024>, ""); - - m.def("backward_p1_c2_tile32_thread128", &backward<1, 2, 32, 128>, ""); - m.def("backward_p1_c2_tile32_thread256", &backward<1, 2, 32, 256>, ""); - m.def("backward_p1_c2_tile32_thread384", &backward<1, 2, 32, 384>, ""); - m.def("backward_p1_c2_tile32_thread512", &backward<1, 2, 32, 512>, ""); -} diff --git a/src/ops/cuda_kernels/bak_forward.cu b/src/ops/cuda_kernels/bak_forward.cu deleted file mode 100644 index 00569f8..0000000 --- a/src/ops/cuda_kernels/bak_forward.cu +++ /dev/null @@ -1,289 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - - -template -__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){ - #pragma unroll - for(int i=0; i -__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){ - #pragma unroll - for(int i=0; i -__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){ -#pragma unroll - for(int i=0; i -__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){ -#pragma unroll - for(int i=0; i -__global__ void dcn_forward_kernel(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){ - int work_id = threadIdx.x; - int bid = blockIdx.z; - int gid = blockIdx.y; - int G = gridDim.y; - int c_blockid = blockIdx.x; - int work_load = (H*W/blockDim.x); - - __shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM - // __shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM - math_t register_bufferA[BLOCK_DIM] = {0}; - int base_c = c_blockid*BLOCK_DIM; - - int num_transfers = BLOCK_DIM; -#pragma unroll - for(int i=0; i(register_bufferA, 1, BLOCK_DIM); - // loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM); -#pragma unroll - for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM); - } - // load top right - math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight; - if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){ - loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM); - } - - // load bottom left - math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight; - if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){ - loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM); - } - // load bottom right - math_t br_weight = (x - floor_x)*(y - floor_y)*weight; - if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){ - loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM); - } - - } - // loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM); - int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c); -#pragma 
unroll - for(int j=0; j -__global__ void dcn_forward_kernel_16(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){ - int work_id = threadIdx.x; - int bid = blockIdx.z; - int gid = blockIdx.y; - int G = gridDim.y; - int c_blockid = blockIdx.x; - int work_load = (H*W/blockDim.x); - - __shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM - __shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM - math_t register_bufferA[BLOCK_DIM] = {0}; - int base_c = c_blockid*BLOCK_DIM; - - int num_transfers = BLOCK_DIM/transfer_length; -#pragma unroll - for(int i=0; i(register_bufferA, 1, BLOCK_DIM); - loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM); -#pragma unroll - for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM); - } - // load top right - math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight; - if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){ - loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM); - } - - // load bottom left - math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight; - if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){ - loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM); - } - // load bottom right - math_t br_weight = (x - floor_x)*(y - floor_y)*weight; - if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){ - loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM); - } - - } - loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM); - - } - - __syncthreads(); - -#pragma unroll - for(int i=0; i -void dcn_forward(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) { - - int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM; - dim3 launch_threads_per_block(THREADS); - dim3 launch_blocks(NUM_C_BLOCK, G, B); - - switch (value.type().scalarType()) { - case at::ScalarType::Half: - return dcn_forward_kernel_16<<>>( - H, W, C, - value.data_ptr(), - deformables.data_ptr(), - weights.data_ptr(), - out.data_ptr()); - case at::ScalarType::BFloat16: - return dcn_forward_kernel_16<<>>( - H, W, C, - value.data_ptr(), - deformables.data_ptr(), - weights.data_ptr(), - out.data_ptr()); - case at::ScalarType::Float: - return dcn_forward_kernel<<>>( - H, W, C, - value.data_ptr(), - deformables.data_ptr(), - weights.data_ptr(), - out.data_ptr()); - default: - printf("running error"); - } -} - - -// PyBind11 bindings -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { -//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward"); -//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward"); -m.def("dcn_forward_l256_c4", &dcn_forward<256, 4, 256>, "CUDA dcn forward"); -m.def("dcn_forward_l256_c8", &dcn_forward<256, 8, 256>, "CUDA dcn forward"); -m.def("dcn_forward_l256_c16", &dcn_forward<256, 16, 256>, "CUDA dcn forward"); -// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward"); -m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward"); -m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward"); -m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward"); -// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward"); -} diff --git 
a/src/ops/cuda_kernels/forward.cu b/src/ops/cuda_kernels/forward.cu deleted file mode 100644 index ac18308..0000000 --- a/src/ops/cuda_kernels/forward.cu +++ /dev/null @@ -1,309 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - - -template -__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){ - #pragma unroll - for(int i=0; i -__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){ - #pragma unroll - for(int i=0; i -__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){ -#pragma unroll - for(int i=0; i -__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){ -#pragma unroll - for(int i=0; i -__global__ void dcn_forward_kernel_register(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){ - int work_id = threadIdx.x; - int bid = blockIdx.z; - int gid = blockIdx.y; - int G = gridDim.y; - int c_blockid = blockIdx.x; - int work_load = (H*W/blockDim.x); - - extern __shared__ int shm[]; - math_t* math_buffer = reinterpret_cast(shm); - - math_t register_bufferA[BLOCK_DIM] = {0}; - int base_c = c_blockid*BLOCK_DIM; - -#pragma unroll - for(int i=0; i(register_bufferA, 1, BLOCK_DIM); -#pragma unroll - for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM); - } - // load top right - math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight; - if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){ - loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM); - } - - // load bottom left - math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight; - if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){ - loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM); - } - // load bottom right - math_t br_weight = (x - floor_x)*(y - floor_y)*weight; - if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){ - loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM); - } - - } - int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c); -#pragma unroll - for(int j=0; j -__global__ void dcn_forward_kernel_pipeline(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){ - int work_id = threadIdx.x; - int bid = blockIdx.z; - int gid = blockIdx.y; - int G = gridDim.y; - int c_blockid = blockIdx.x; - int work_load = (H*W/blockDim.x); - - extern __shared__ int shm[]; - math_t* math_buffer = reinterpret_cast(shm); - scalar_t* io_buffer = reinterpret_cast(shm) + H*W*BLOCK_DIM*sizeof(math_t)/sizeof(scalar_t); - math_t register_bufferA[BLOCK_DIM] = {0}; - int base_c = c_blockid*BLOCK_DIM; - - int num_transfers = BLOCK_DIM/transfer_length; -#pragma unroll - for(int i=0; i(register_bufferA, 1, BLOCK_DIM); - loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM); -#pragma unroll - for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM); - } - // load top right - math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight; - if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){ - loop_mul_add(register_bufferA, 
(math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM); - } - - // load bottom left - math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight; - if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){ - loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM); - } - // load bottom right - math_t br_weight = (x - floor_x)*(y - floor_y)*weight; - if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){ - loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM); - } - - } - loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM); - - } - - __syncthreads(); - -#pragma unroll - for(int i=0; i -void dcn_forward(const int B, const int G, const int C, const int H, const int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) { - - int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM; - dim3 launch_threads_per_block(THREADS); - dim3 launch_blocks(NUM_C_BLOCK, G, B); - int shm_size = H*W*C_BLOCK_DIM*sizeof(at::Half); - switch (value.type().scalarType()) { - case at::ScalarType::Half: - return dcn_forward_kernel_register<<>>( - H, W, C, - value.data_ptr(), - deformables.data_ptr(), - weights.data_ptr(), - out.data_ptr()); - case at::ScalarType::Float: - return dcn_forward_kernel_register<<>>( - H, W, C, - value.data_ptr(), - deformables.data_ptr(), - weights.data_ptr(), - out.data_ptr()); - default: - printf("running error"); - } -} - -template -void dcn_forward_pipeline(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) { - - int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM; - dim3 launch_threads_per_block(THREADS); - dim3 launch_blocks(NUM_C_BLOCK, G, B); - int shm_size = 2*H*W*C_BLOCK_DIM*sizeof(at::Half); - switch (value.type().scalarType()) { - case at::ScalarType::Half: - return dcn_forward_kernel_pipeline<<>>( - H, W, C, - value.data_ptr(), - deformables.data_ptr(), - weights.data_ptr(), - out.data_ptr()); - case at::ScalarType::BFloat16: - return dcn_forward_kernel_pipeline<<>>( - H, W, C, - value.data_ptr(), - deformables.data_ptr(), - weights.data_ptr(), - out.data_ptr()); - default: - printf("running error"); - } -} - -// PyBind11 bindings -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { -//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward"); -//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward"); -m.def("dcn_forward_l256_c4", &dcn_forward<4, 256>, "CUDA dcn forward"); -m.def("dcn_forward_l256_c8", &dcn_forward<8, 256>, "CUDA dcn forward"); -m.def("dcn_forward_l256_c16", &dcn_forward<16, 256>, "CUDA dcn forward"); -m.def("dcn_forward_pipeline_l256_c4", &dcn_forward_pipeline<4, 256>, "CUDA dcn forward"); -m.def("dcn_forward_pipeline_l256_c8", &dcn_forward_pipeline<8, 256>, "CUDA dcn forward"); -m.def("dcn_forward_pipeline_l256_c16", &dcn_forward_pipeline<16, 256>, "CUDA dcn forward"); -// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward"); -// m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward"); -// m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward"); -// m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward"); -// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward"); -} diff --git a/src/ops/cuda_kernels/forward.py b/src/ops/cuda_kernels/forward.py deleted file mode 100644 index 
4ea9c5e..0000000 --- a/src/ops/cuda_kernels/forward.py +++ /dev/null @@ -1,95 +0,0 @@ -import triton -import triton.language as tl - -@triton.autotune( - configs=[ - # triton.Config({'BLOCK_SIZE': 32,}, num_stages=1, num_warps=1), - triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), - triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1), - ], - key=['B', 'H', 'W', 'G', 'C', 'K'], -) -@triton.jit -def forward_kernel( - B: tl.constexpr, - H: tl.constexpr, # image_size_h - W: tl.constexpr, # image_size_w - G: tl.constexpr, # num_channels_per_group - C: tl.constexpr, # num_groups - K: tl.constexpr, # kernel size - input_ptr, # input features [B, H, W, G, C] - deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K] - weights_ptr, # weights [B, H, W, G, K] - out_ptr, # out [B, H, W, G, C] - BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group -): - pid = tl.program_id(0) - wid = pid % W - hid = pid // W % H - gid = pid // (W * H) % G - bid = pid // (W * H * G) - - id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B) - common_offset = bid*H*W*G + hid*W*G + wid*G + gid - batch_base = bid * H * W * G * C - - for block_base in tl.static_range(0, C, BLOCK_SIZE): - buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) - block_offset = tl.arange(0, BLOCK_SIZE) + block_base - block_mask = (block_offset < C) & id_mask - for k in tl.static_range(K): - deformable_offset = (common_offset * K + k) * 2 - - x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid - y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid - - floor_x = x.to(tl.int32) - floor_y = y.to(tl.int32) - ceil_x = floor_x + 1 - ceil_y = floor_y + 1 - - # load top left - tl_weight = (ceil_x - x) * (ceil_y - y) - tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE - tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H) - - # load top right - tr_weight = (x - floor_x) * (ceil_y - y) - tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE - tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0) - # load bottom left - bl_weight = (ceil_x - x) * (y - floor_y) - bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE - bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0) - # load bottom right - br_weight = (x - floor_x) * (y - floor_y) - br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE - br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0) - - # load dynamic weight and mask - weights_offset = common_offset*K + k - weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) - - - - tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0) - tl_block_input = tl_block_input * tl_weight - - # load top right - tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0) - tr_block_input = tr_block_input * tr_weight - # load bottom left - bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0) - bl_block_input = bl_block_input * bl_weight - # load bottom right - br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0) - br_block_input = br_block_input * br_weight - 
- # sampled - sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input - - weighted_sampled_input = sampled_input * weight - buffer = buffer + weighted_sampled_input - # store to out_ptr - tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask) - diff --git a/src/ops/cuda_kernels/function.py b/src/ops/cuda_kernels/function.py deleted file mode 100644 index 9d4bfad..0000000 --- a/src/ops/cuda_kernels/function.py +++ /dev/null @@ -1,126 +0,0 @@ -import time -import dcn_cuda_backward -import dcn_cuda_forward - -import math -import torch -from typing import Any -from torch.autograd import Function -from torch.autograd.function import once_differentiable -from torch.cuda.amp import custom_fwd, custom_bwd -from .forward import forward_kernel - - -class DCNFunction(Function): - BP_FUNCS = [ - dcn_cuda_backward.backward_p1_c2_tile16_thread128, - dcn_cuda_backward.backward_p1_c4_tile16_thread128, - dcn_cuda_backward.backward_p2_c2_tile16_thread128, - dcn_cuda_backward.backward_p1_c2_tile16_thread256, - dcn_cuda_backward.backward_p1_c4_tile16_thread256, - dcn_cuda_backward.backward_p2_c2_tile16_thread256, - dcn_cuda_backward.backward_p1_c2_tile16_thread384, - dcn_cuda_backward.backward_p1_c4_tile16_thread384, - dcn_cuda_backward.backward_p2_c2_tile16_thread384, - dcn_cuda_backward.backward_p1_c2_tile16_thread512, - dcn_cuda_backward.backward_p1_c4_tile16_thread512, - dcn_cuda_backward.backward_p2_c2_tile16_thread512, - dcn_cuda_backward.backward_p1_c2_tile16_thread768, - dcn_cuda_backward.backward_p1_c4_tile16_thread768, - dcn_cuda_backward.backward_p2_c2_tile16_thread768, - dcn_cuda_backward.backward_p1_c2_tile32_thread128, - dcn_cuda_backward.backward_p1_c2_tile32_thread256, - dcn_cuda_backward.backward_p1_c2_tile32_thread384, - dcn_cuda_backward.backward_p1_c2_tile32_thread512, - ] - FW_FUNCS = [ - dcn_cuda_forward.dcn_forward_l256_c4, - dcn_cuda_forward.dcn_forward_l256_c8, - dcn_cuda_forward.dcn_forward_l256_c16, - ] - BP_TABLES = dict() - FW_TABLES = dict() - - - @staticmethod - @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: Any, values, deformables, weights) -> Any: - B, H, W, G, C = values.shape - func = DCNFunction.find_fw_funcs(values, deformables, weights) - out = torch.zeros_like(values) - func(B, G, C, H, W, values, deformables, weights, out) - return out - - @staticmethod - def find_fw_funcs(values, deformables, weights): - B, H, W, G, C = values.shape - B, H, W, G, K = weights.shape - hash_value = 10000 * B + 100 * H + W + 1000 * G - if hash_value in DCNFunction.FW_TABLES.keys(): - return DCNFunction.FW_TABLES[hash_value] - print("missing") - candicate_func = None - min_t = 999.0 - outs = torch.zeros_like(values) - for func in DCNFunction.FW_FUNCS: - t = [] - for i in range(100): - torch.cuda.synchronize() - start_t = time.time() - func(B, G, C, H, W, values, deformables, weights, outs) - torch.cuda.synchronize() - t.append(time.time() - start_t) - t = t[-50:] - t = sum(t) / len(t) - if t < min_t: - min_t = t - DCNFunction.FW_TABLES[hash_value] = func - candicate_func = func - assert candicate_func is not None - print(candicate_func) - return candicate_func - @staticmethod - def find_bp_funcs(values, deformables, weights, grad_out): - B, H, W, G, C = values.shape - B, H, W, G, K = weights.shape - hash_value = 10000 * B + 100 * H + W + 1000 * G - if hash_value in DCNFunction.BP_TABLES.keys(): - return DCNFunction.BP_TABLES[hash_value] - print("missing") - candicate_func = None - min_t = 999.0 - grad_values = 
torch.zeros_like(values) - grad_deformables = torch.zeros_like(deformables) - grad_weights = torch.zeros_like(weights) - for func in DCNFunction.BP_FUNCS: - t = [] - for i in range(100): - torch.cuda.synchronize() - start_t = time.time() - func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights) - torch.cuda.synchronize() - t.append(time.time() - start_t) - t = t[-50:] - t = sum(t) / len(t) - if t < min_t: - min_t = t - DCNFunction.BP_TABLES[hash_value] = func - candicate_func = func - assert candicate_func is not None - print(candicate_func) - return candicate_func - - @staticmethod - @once_differentiable - @custom_bwd - def backward(ctx: Any, *grad_outputs: Any) -> Any: - grad_out = grad_outputs[0] - values, deformables, weights = ctx.saved_tensors - B, H, W, G, C = values.shape - B, H, W, G, K = weights.shape - func = DCNFunction.find_bp_funcs(values, deformables, weights, grad_out) - grad_values = torch.zeros_like(values) - grad_deformables = torch.zeros_like(deformables) - grad_weights = torch.zeros_like(weights) - func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights) - return grad_values, grad_deformables, grad_weights \ No newline at end of file diff --git a/src/ops/cuda_kernels/setup.py b/src/ops/cuda_kernels/setup.py deleted file mode 100644 index 34079d4..0000000 --- a/src/ops/cuda_kernels/setup.py +++ /dev/null @@ -1,59 +0,0 @@ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension - -setup( - name='dcn_cuda_forward', - ext_modules=[ - CUDAExtension('dcn_cuda_forward', ['./forward.cu',], - extra_compile_args={'cxx': [], 'nvcc': [ - "-DCUDA_HAS_FP16=1", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "--use_fast_math", - "-O3", - ]} - ), - ], - cmdclass={ - 'build_ext': BuildExtension - } -) - -setup( - name='dcn_cuda_backward', - ext_modules=[ - CUDAExtension('dcn_cuda_backward', ['./backward.cu',], - extra_compile_args={'cxx': [], 'nvcc': [ - "-DCUDA_HAS_FP16=1", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", - "--use_fast_math", - "-O3", - ]} - ), - ], - cmdclass={ - 'build_ext': BuildExtension - } -) - - -# setup( -# name='mycuda', -# ext_modules=[ -# CUDAExtension('mycuda', ['./backward.cu',], -# extra_compile_args={'cxx': [], 'nvcc': [ -# "-O3", -# "-DCUDA_HAS_FP16=1", -# "-D__CUDA_NO_HALF_OPERATORS__", -# "-D__CUDA_NO_HALF_CONVERSIONS__", -# "-D__CUDA_NO_HALF2_OPERATORS__", -# ]} -# ), -# ], -# cmdclass={ -# 'build_ext': BuildExtension -# } -# ) \ No newline at end of file diff --git a/src/ops/triton_kernels/__init__.py b/src/ops/triton_kernels/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/ops/triton_kernels/backward.py b/src/ops/triton_kernels/backward.py deleted file mode 100644 index e886aa2..0000000 --- a/src/ops/triton_kernels/backward.py +++ /dev/null @@ -1,124 +0,0 @@ -import triton -import triton.language as tl - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=1), - triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), - ], - key=['B', 'H', 'W', 'G', 'C', 'K'], -) -@triton.jit -def backward_kernel( - B: tl.constexpr, - H: tl.constexpr, # image_size_h - W: tl.constexpr, # image_size_w - G: tl.constexpr, # num_groups - C: tl.constexpr, # num_channels_per_group - K: tl.constexpr, # kernel size - input_ptr, # input features 
[B, H, W, G, C] - deformable_ptr, # deformable offsets [B, H, W, G, K, 2] - weights_ptr, # weights [B, H, W, G, K] - grad_ptr, # out [B, H, W, G, C] - grad_input_ptr, # input features [B, H, W, G, C] - grad_deformable_ptr, # deformable offsets [B, H, W, G, K, 2] - grad_weights_ptr, # weights [B, H, W, G, K] - BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group -): - - pid = tl.program_id(0) - wid = pid % W - hid = pid // W % H - gid = pid // (W * H) % G - bid = pid // (W * H * G) - - id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B) - - common_offset = bid*H*W*G + hid*W*G + wid*G + gid - batch_base = bid * H * W * G * C - for k in tl.static_range(K): - # load dynamic weight and mask - weights_offset = common_offset*K + k - weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) - dodx = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty) - dody = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty) - dodw = tl.zeros((1,), dtype=grad_weights_ptr.type.element_ty) - deformable_offset = (common_offset * K + k)*2 - x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid - y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid - for block_base in tl.static_range(0, C, BLOCK_SIZE): - block_offset = tl.arange(0, BLOCK_SIZE) + block_base - block_mask = (block_offset < C) & id_mask - grad = tl.load(grad_ptr+common_offset*C + block_offset, mask=block_mask, other=0.0) - dods = weight*grad - - floor_x = x.to(tl.int32) - floor_y = y.to(tl.int32) - ceil_x = floor_x + 1 - ceil_y = floor_y + 1 - - # load top left - tl_weight = (ceil_x - x) * (ceil_y - y) - tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) + block_offset - tl_block_mask = ((floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)) - tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0) - tl_block_input_dot_grad = tl.sum(tl_block_input*grad, axis=0) - dodx = dodx + -1 * tl_block_input_dot_grad * (ceil_y - y) - dody = dody + -1 * tl_block_input_dot_grad * (ceil_x - x) - dodw = dodw + tl_block_input_dot_grad * tl_weight - - dodtl = dods * tl_weight - tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dodtl) - - - # load top right - tr_weight = (x - floor_x) * (ceil_y - y) - tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) + block_offset - tr_block_mask = ((floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)) - tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0) - tr_block_input_dot_grad = tl.sum(tr_block_input*grad, axis=0) - dodx = dodx + 1 * tr_block_input_dot_grad * (ceil_y - y) - dody = dody + -1 * tr_block_input_dot_grad * (x - floor_x) - dodw = dodw + tr_block_input_dot_grad*tr_weight - - dodtr = dods * tr_weight - tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dodtr) - - - # load bottom left - bl_weight = (ceil_x - x) * (y - floor_y) - bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) + block_offset - bl_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)) - bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0) - bl_block_input_dot_grad = tl.sum(bl_block_input*grad, axis=0) - dodx = dodx + -1 * bl_block_input_dot_grad * (y - floor_y) - dody = dody + 1 * bl_block_input_dot_grad * (ceil_x - x) - dodw = dodw + 
bl_block_input_dot_grad*bl_weight - - dodbl = dods * bl_weight - tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dodbl) - - - # load bottom right - br_weight = (x - floor_x) * (y - floor_y) - br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) + block_offset - br_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)) - br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0) - br_block_input_dot_grad = tl.sum(br_block_input*grad, axis=0)*br_block_mask - - dodx = dodx + 1 * br_block_input_dot_grad * (y - floor_y) - dody = dody + 1 * br_block_input_dot_grad * (x - floor_x) - dodw = dodw + br_block_input_dot_grad*br_weight - - dodbr = dods * br_weight - tl.atomic_add(grad_input_ptr + br_block_offset, mask=br_block_mask & block_mask, val=dodbr) - dodx = dodx * weight - dody = dody * weight - tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask) - tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask) - tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, mask=id_mask) - - - - - diff --git a/src/ops/triton_kernels/forward.py b/src/ops/triton_kernels/forward.py deleted file mode 100644 index cf7c243..0000000 --- a/src/ops/triton_kernels/forward.py +++ /dev/null @@ -1,94 +0,0 @@ -import triton -import triton.language as tl - -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), - # triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1), - ], - key=['B', 'H', 'W', 'G', 'C', 'K'], -) -@triton.jit -def forward_kernel( - B: tl.constexpr, - H: tl.constexpr, # image_size_h - W: tl.constexpr, # image_size_w - G: tl.constexpr, # num_channels_per_group - C: tl.constexpr, # num_groups - K: tl.constexpr, # kernel size - input_ptr, # input features [B, H, W, G, C] - deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K] - weights_ptr, # weights [B, H, W, G, K] - out_ptr, # out [B, H, W, G, C] - BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group -): - pid = tl.program_id(0) - wid = pid % W - hid = pid // W % H - gid = pid // (W * H) % G - bid = pid // (W * H * G) - - id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B) - common_offset = bid*H*W*G + hid*W*G + wid*G + gid - batch_base = bid * H * W * G * C - - for block_base in tl.static_range(0, C, BLOCK_SIZE): - buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) - block_offset = tl.arange(0, BLOCK_SIZE) + block_base - block_mask = (block_offset < C) & id_mask - for k in tl.static_range(K): - deformable_offset = (common_offset * K + k) * 2 - - x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid - y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid - - floor_x = x.to(tl.int32) - floor_y = y.to(tl.int32) - ceil_x = floor_x + 1 - ceil_y = floor_y + 1 - - # load top left - tl_weight = (ceil_x - x) * (ceil_y - y) - tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE - tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H) - - # load top right - tr_weight = (x - floor_x) * (ceil_y - y) - tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE - tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0) - # load bottom left - bl_weight = (ceil_x - x) * (y - floor_y) - bl_block_offset = (batch_base + 
ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
-            bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
-            # load bottom right
-            br_weight = (x - floor_x) * (y - floor_y)
-            br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
-            br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
-
-            # load dynamic weight and mask
-            weights_offset = common_offset*K + k
-            weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
-
-
-
-            tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
-            tl_block_input = tl_block_input * tl_weight
-
-            # load top right
-            tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
-            tr_block_input = tr_block_input * tr_weight
-            # load bottom left
-            bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
-            bl_block_input = bl_block_input * bl_weight
-            # load bottom right
-            br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
-            br_block_input = br_block_input * br_weight
-
-            # sampled
-            sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input
-
-            weighted_sampled_input = sampled_input * weight
-            buffer = buffer + weighted_sampled_input
-        # store to out_ptr
-        tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask)
-
diff --git a/src/ops/triton_kernels/function.py b/src/ops/triton_kernels/function.py
deleted file mode 100644
index 84987a1..0000000
--- a/src/ops/triton_kernels/function.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import torch
-import triton
-from typing import Any
-from torch.autograd import Function
-from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd
-from .forward import forward_kernel
-from .backward import backward_kernel
-
-
-
-class DCNFunction(Function):
-
-    @staticmethod
-    @custom_fwd
-    def forward(ctx: Any, inputs, deformables, weights) -> Any:
-        B, H, W, G, C = inputs.shape
-        _, _, _, _, K, _ = deformables.shape
-        out = torch.zeros_like(inputs)
-        grid = lambda META: (B * H * W * G,)
-
-        forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out)
-        ctx.save_for_backward(inputs, deformables, weights)
-        return out
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx: Any, *grad_outputs: Any) -> Any:
-        grad_output = grad_outputs[0].contiguous()
-
-        inputs, deformables, weights = ctx.saved_tensors
-        B, H, W, G, C = inputs.shape
-        _, _, _, _, K, _ = deformables.shape
-
-        grad_inputs = torch.zeros_like(inputs)
-        grad_deformables = torch.zeros_like(deformables)
-        grad_weights = torch.zeros_like(weights)
-        grid = lambda META: (B * H * W * G,)
-        backward_kernel[grid](
-            B, H, W, G, C, K,
-            inputs,
-            deformables,
-            weights,
-            grad_output,
-            grad_inputs,
-            grad_deformables,
-            grad_weights,
-        )
-        return (grad_inputs, grad_deformables, grad_weights)
\ No newline at end of file
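All of the deleted CUDA and Triton kernels above implement the same multi-scale deformable sampling: for every output location and group, K offset points are bilinearly interpolated from `values` (layout [B, H, W, G, C]) and accumulated with per-point `weights` ([B, H, W, G, K]); the offsets live in `deformables` ([B, H, W, G, K, 2], pixel units added to the output coordinate), and out-of-range corners contribute zero. A minimal pure-PyTorch sketch of that operation, assuming this layout (`dcn_reference` is a hypothetical helper, not a file in this repository), can serve as a slow reference when re-validating or porting `DCNFunction`:

import torch

def dcn_reference(values, deformables, weights):
    """Slow bilinear deformable sampling matching the kernel layout above.

    values:      [B, H, W, G, C]
    deformables: [B, H, W, G, K, 2] pixel offsets (dx, dy) added to (w, h)
    weights:     [B, H, W, G, K] per-point scalar weights
    returns:     [B, H, W, G, C]
    """
    B, H, W, G, C = values.shape
    dev, dt = values.device, values.dtype
    ys, xs = torch.meshgrid(
        torch.arange(H, device=dev, dtype=dt),
        torch.arange(W, device=dev, dtype=dt),
        indexing="ij",
    )
    x = deformables[..., 0] + xs.view(1, H, W, 1, 1)      # [B, H, W, G, K]
    y = deformables[..., 1] + ys.view(1, H, W, 1, 1)
    fx, fy = x.floor(), y.floor()
    cx, cy = fx + 1, fy + 1

    b_idx = torch.arange(B, device=dev).view(B, 1, 1, 1, 1)
    g_idx = torch.arange(G, device=dev).view(1, 1, 1, G, 1)
    out = torch.zeros_like(values)
    corners = (
        (fx, fy, (cx - x) * (cy - y)),   # top-left
        (cx, fy, (x - fx) * (cy - y)),   # top-right
        (fx, cy, (cx - x) * (y - fy)),   # bottom-left
        (cx, cy, (x - fx) * (y - fy)),   # bottom-right
    )
    for gx, gy, bilin in corners:
        # zero out corners that fall outside the feature map, as the kernels do
        valid = ((gx >= 0) & (gx < W) & (gy >= 0) & (gy < H)).to(dt)
        gxc = gx.clamp(0, W - 1).long()
        gyc = gy.clamp(0, H - 1).long()
        corner = values[b_idx, gyc, gxc, g_idx]            # [B, H, W, G, K, C]
        out += ((bilin * valid * weights).unsqueeze(-1) * corner).sum(dim=-2)
    return out

Under these assumptions, `dcn_reference(values, deformables, weights)` should agree with `DCNFunction.apply(values, deformables, weights)` up to the kernels' numerical precision, and comparing autograd gradients of the reference against the hand-written backward kernels on small shapes is a cheap sanity check.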