decoupled dit src code

This commit is contained in:
wangshuai6
2025-04-09 11:07:29 +08:00
parent 06499f1caa
commit d1b6da1f0a
44 changed files with 14 additions and 8633 deletions

View File

@@ -1,68 +0,0 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class PyramidTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
output_pyramid = []
def feature_hook(module, input, output):
output_pyramid.extend(output)
handle = net.decoder.register_forward_hook(feature_hook)
net(x_t, t, y)
handle.remove()
loss = 0.0
out_dict = dict()
cur_v_t = v_t
for i in range(len(output_pyramid)):
cur_out = output_pyramid[i]
loss_i = (cur_v_t - cur_out) ** 2
loss += loss_i.mean()
out_dict["loss_{}".format(i)] = loss_i.mean()
cur_v_t = torch.nn.functional.interpolate(cur_v_t, scale_factor=0.5, mode='bilinear', align_corners=False)
out_dict["loss"] = loss
return out_dict

View File

@@ -1,152 +0,0 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class REPATrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
mask_ratio=0.0,
mask_patch_size=2,
encoder_weight_path=None,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.mask_ratio = mask_ratio
self.mask_patch_size = mask_patch_size
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.encoder = DINOv2(encoder_weight_path)
no_grad(self.encoder)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
patch_mask = torch.rand((batch_size, 1, height//self.mask_patch_size, width//self.mask_patch_size), device=x.device)
patch_mask = (patch_mask < self.mask_ratio).float()
mask = torch.nn.functional.interpolate(patch_mask, size=(height, width), mode='nearest')
masked_x = x*(1-mask)# + torch.randn_like(x)*(mask)
x_t = alpha*masked_x + sigma*noise
v_t = dalpha*x + dsigma*noise
src_feature = []
def forward_hook(net, input, output):
src_feature.append(output)
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
v_t_out, x0_out = net(x_t, t, y)
src_feature = self.proj(src_feature[0])
handle.remove()
with torch.no_grad():
dst_feature = self.encoder(raw_images)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = (1-mask)*weight*(v_t_out - v_t)**2/(1-mask.mean())
mask_loss = mask*weight*(x0_out - x)**2/(mask.mean())
out = dict(
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
mask_loss=mask_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + mask_loss.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)

View File

@@ -1,122 +0,0 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class Discriminator(nn.Module):
def __init__(self, in_channels, hidden_size):
super().__init__()
self.head = nn.Sequential(
nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
)
def forward(self, feature):
B, L, C = feature.shape
H = W = int(math.sqrt(L))
feature = feature.permute(0, 2, 1)
feature = feature.view(B, C, H, W)
out = self.head(feature).sigmoid().clamp(0.01, 0.99)
return out
class AdvTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
adv_weight=1.0,
adv_encoder_layer=4,
adv_in_channels=768,
adv_hidden_size=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.adv_weight = adv_weight
self.adv_encoder_layer = adv_encoder_layer
self.dis_head = Discriminator(
in_channels=adv_in_channels,
hidden_size=adv_hidden_size,
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
adv_feature = []
def forward_hook(net, input, output):
adv_feature.append(output)
handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook)
out, _ = net(x_t, t, y)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
pred_x0 = (x_t + out * sigma)
pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma
real_feature = adv_feature.pop()
net(pred_xt, t, y, classify_layer=self.adv_encoder_layer)
fake_feature = adv_feature.pop()
handle.remove()
real_score_gan = self.dis_head(real_feature.detach())
fake_score_gan = self.dis_head(fake_feature.detach())
fake_score = self.dis_head(fake_feature)
loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
acc_real = (real_score_gan > 0.5).float()
acc_fake = (fake_score_gan < 0.5).float()
loss_adv = -torch.log(fake_score)
loss_adv_hack = torch.log(fake_score_gan)
out = dict(
adv_loss=loss_adv.mean(),
gan_loss=loss_gan.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
acc_real=acc_real.mean(),
acc_fake=acc_fake.mean(),
)
return out

View File

@@ -1,127 +0,0 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class Discriminator(nn.Module):
def __init__(self, in_channels, hidden_size):
super().__init__()
self.head = nn.Sequential(
nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
)
def forward(self, feature):
B, L, C = feature.shape
H = W = int(math.sqrt(L))
feature = feature.permute(0, 2, 1)
feature = feature.view(B, C, H, W)
out = self.head(feature).sigmoid().clamp(0.01, 0.99)
return out
class AdvTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
adv_weight=1.0,
lpips_weight=1.0,
adv_encoder_layer=4,
adv_in_channels=768,
adv_hidden_size=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.adv_weight = adv_weight
self.lpips_weight = lpips_weight
self.adv_encoder_layer = adv_encoder_layer
self.dis_head = Discriminator(
in_channels=adv_in_channels,
hidden_size=adv_hidden_size,
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
pred_x0 = (x_t + out * sigma)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
with torch.no_grad():
_, real_features = net(x, clean_t, y, classify_layer=self.adv_encoder_layer)
_, fake_features = net(pred_x0, clean_t, y, classify_layer=self.adv_encoder_layer)
real_score_gan = self.dis_head(real_features[-1].detach())
fake_score_gan = self.dis_head(fake_features[-1].detach())
fake_score = self.dis_head(fake_features[-1])
loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
acc_real = (real_score_gan > 0.5).float()
acc_fake = (fake_score_gan < 0.5).float()
loss_adv = -torch.log(fake_score)
loss_adv_hack = torch.log(fake_score_gan)
lpips_loss = []
for r, f in zip(real_features, fake_features):
r = torch.nn.functional.normalize(r, dim=-1)
f = torch.nn.functional.normalize(f, dim=-1)
lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
lpips_loss = sum(lpips_loss)
out = dict(
adv_loss=loss_adv.mean(),
gan_loss=loss_gan.mean(),
lpips_loss=lpips_loss.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean() + self.lpips_weight*lpips_loss.mean(),
acc_real=acc_real.mean(),
acc_fake=acc_fake.mean(),
)
return out

View File

@@ -1,159 +0,0 @@
import random
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class MaskREPATrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
encoder_weight_path=None,
mask_groups=4,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.mask_groups = mask_groups
self.encoder = DINOv2(encoder_weight_path)
no_grad(self.encoder)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def fetch_mask(self, length=256, groups=4, device=torch.device('cuda')):
mask = torch.zeros(1, length, length, device=device, dtype=torch.bool)
random_seq = torch.randperm(length, device=device)
for i in range(groups):
group_start = (length+groups-1)//groups*i
group_end = (length+groups-1)//groups*(i+1)
group_random_seq = random_seq[group_start:group_end]
y, x = torch.meshgrid(group_random_seq, group_random_seq)
mask[:, y, x] = True
return mask
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
src_feature = []
def forward_hook(net, input, output):
src_feature.append(output)
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
mask_groups = random.randint(1, self.mask_groups)
mask = self.fetch_mask(length=256, groups=mask_groups, device=x.device)
out, _ = net(x_t, t, y, mask=mask)
src_feature = self.proj(src_feature[0])
handle.remove()
with torch.no_grad():
dst_feature = self.encoder(raw_images)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = weight*(out - v_t)**2
out = dict(
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)

View File

@@ -1,179 +0,0 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t, mul=1000):
t_freq = self.timestep_embedding(t * mul, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class BatchNormWithTimeEmbedding(nn.Module):
def __init__(self, num_features):
super().__init__()
# self.bn = nn.BatchNorm2d(num_features, affine=False)
self.bn = nn.GroupNorm(16, num_features, affine=False)
# self.bn = nn.SyncBatchNorm(num_features, affine=False)
self.embedder = TimestepEmbedder(num_features * 2)
# nn.init.zeros_(self.embedder.mlp[-1].weight)
nn.init.trunc_normal_(self.embedder.mlp[-1].weight, std=0.01)
nn.init.zeros_(self.embedder.mlp[-1].bias)
def forward(self, x, t):
embed = self.embedder(t)
embed = embed[:, :, None, None]
gamma, beta = embed.chunk(2, dim=1)
gamma = 1.0 + gamma
normed = self.bn(x)
out = normed * gamma + beta
return out
class DisBlock(nn.Module):
def __init__(self, in_channels, hidden_size):
super().__init__()
self.conv = nn.Conv2d(
kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=4, padding=0
)
self.norm = BatchNormWithTimeEmbedding(hidden_size)
self.act = nn.SiLU()
def forward(self, x, t):
x = self.conv(x)
x = self.norm(x, t)
x = self.act(x)
return x
class Discriminator(nn.Module):
def __init__(self, num_blocks, in_channels, hidden_size):
super().__init__()
self.blocks = nn.ModuleList()
for i in range(num_blocks):
self.blocks.append(
DisBlock(
in_channels=in_channels,
hidden_size=hidden_size,
)
)
in_channels = hidden_size
self.classifier = nn.Conv2d(
kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=1
)
def forward(self, feature, t):
B, C, H, W = feature.shape
for block in self.blocks:
feature = block(feature, t)
out = self.classifier(feature).view(B, -1)
out = out.sigmoid().clamp(0.01, 0.99)
return out
class AdvTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
adv_weight=1.0,
adv_blocks=3,
adv_in_channels=3,
adv_hidden_size=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.adv_weight = adv_weight
self.discriminator = Discriminator(
num_blocks=adv_blocks,
in_channels=adv_in_channels*2,
hidden_size=adv_hidden_size,
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
pred_x0 = x_t + sigma * out
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
real_feature = torch.cat([x_t, x], dim=1)
fake_feature = torch.cat([x_t, pred_x0], dim=1)
real_score_gan = self.discriminator(real_feature.detach(), t)
fake_score_gan = self.discriminator(fake_feature.detach(), t)
fake_score = self.discriminator(fake_feature, t)
loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
acc_real = (real_score_gan > 0.5).float()
acc_fake = (fake_score_gan < 0.5).float()
loss_adv = -torch.log(fake_score)
loss_adv_hack = torch.log(fake_score_gan)
out = dict(
adv_loss=loss_adv.mean(),
gan_loss=loss_gan.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
acc_real=acc_real.mean(),
acc_fake=acc_fake.mean(),
)
return out

View File

@@ -1,154 +0,0 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class REPAJiTTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
jit_deltas=0.01,
encoder_weight_path=None,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.jit_deltas = jit_deltas
self.encoder = DINOv2(encoder_weight_path)
no_grad(self.encoder)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
t2 = base_t + (torch.rand_like(base_t)-0.5) * self.jit_deltas
t2 = torch.clip(t2, 0, 1)
alpha = self.scheduler.alpha(t2)
dalpha = self.scheduler.dalpha(t2)
sigma = self.scheduler.sigma(t2)
dsigma = self.scheduler.dsigma(t2)
x_t2 = alpha * x + noise * sigma
v_t2 = dalpha * x + dsigma * noise
src_feature = []
def forward_hook(net, input, output):
src_feature.append(output)
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
_, s = net(x_t, t, y, only_s=True)
out, _ = net(x_t2, t2, y, s)
src_feature = self.proj(src_feature[0])
handle.remove()
with torch.no_grad():
dst_feature = self.encoder(raw_images)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = weight*(out - v_t2)**2
out = dict(
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)

View File

@@ -1,90 +0,0 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class SelfConsistentTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
lpips_weight=1.0,
lpips_encoder_layer=4,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.lpips_encoder_layer = lpips_encoder_layer
self.lpips_weight = lpips_weight
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
real_features = []
def forward_hook(net, input, output):
real_features.append(output)
handles = []
for i in range(self.lpips_encoder_layer):
handle = net.encoder.blocks[i].register_forward_hook(forward_hook)
handles.append(handle)
out, _ = net(x_t, t, y)
for handle in handles:
handle.remove()
pred_x0 = (x_t + out * sigma)
pred_xt = alpha * pred_x0 + noise * sigma
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
_, fake_features = net(pred_xt, t, y, classify_layer=self.lpips_encoder_layer)
lpips_loss = []
for r, f in zip(real_features, fake_features):
r = torch.nn.functional.normalize(r, dim=-1)
f = torch.nn.functional.normalize(f, dim=-1)
lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
lpips_loss = sum(lpips_loss)
out = dict(
lpips_loss=lpips_loss.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + self.lpips_weight*lpips_loss.mean(),
)
return out

View File

@@ -1,81 +0,0 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class SelfLPIPSTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
lpips_weight=1.0,
lpips_encoder_layer=4,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.lpips_encoder_layer = lpips_encoder_layer
self.lpips_weight = lpips_weight
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
pred_x0 = (x_t + out * sigma)
pred_xt = alpha * pred_x0 + noise * sigma
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
with torch.no_grad():
_, real_features = net(x, clean_t, y, classify_layer=self.lpips_encoder_layer)
_, fake_features = net(pred_x0, clean_t, y, classify_layer=self.lpips_encoder_layer)
lpips_loss = []
for r, f in zip(real_features, fake_features):
r = torch.nn.functional.normalize(r, dim=-1)
f = torch.nn.functional.normalize(f, dim=-1)
lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
lpips_loss = sum(lpips_loss)
out = dict(
lpips_loss=lpips_loss.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + self.lpips_weight*lpips_loss.mean(),
)
return out

View File

@@ -1,78 +0,0 @@
import torch
from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *
from typing import Callable
def shift_respace_fn(t, shift=3.0):
return t / (t + (1 - t) * shift)
def ode_step_fn(x, v, dt, s, w):
return x + v * dt
import logging
logger = logging.getLogger(__name__)
class CMSampler(BaseSampler):
def __init__(
self,
w_scheduler: BaseScheduler = None,
timeshift=1.0,
guidance_interval_min: float = 0.0,
guidance_interval_max: float = 1.0,
state_refresh_rate=1,
last_step=None,
step_fn=None,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.last_step = last_step
self.timeshift = timeshift
self.state_refresh_rate = state_refresh_rate
self.guidance_interval_min = guidance_interval_min
self.guidance_interval_max = guidance_interval_max
if self.last_step is None or self.num_steps == 1:
self.last_step = 1.0 / self.num_steps
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
self.timesteps = shift_respace_fn(timesteps, self.timeshift)
assert self.last_step > 0.0
assert self.scheduler is not None
def _impl_sampling(self, net, noise, condition, uncondition):
"""
sampling process of Euler sampler
-
"""
batch_size = noise.shape[0]
steps = self.timesteps.to(noise.device)
cfg_condition = torch.cat([uncondition, condition], dim=0)
x = noise
state = None
for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
cfg_t = t_cur.repeat(batch_size*2)
cfg_x = torch.cat([x, x], dim=0)
if i % self.state_refresh_rate == 0:
state = None
out, state = net(cfg_x, cfg_t, cfg_condition, state)
if t_cur > self.guidance_interval_min and t_cur < self.guidance_interval_max:
out = self.guidance_fn(out, self.guidance)
else:
out = self.guidance_fn(out, 1.0)
v = out
x0 = x + v * (1-t_cur)
alpha_next = self.scheduler.alpha(t_next)
sigma_next = self.scheduler.sigma(t_next)
x = alpha_next * x0 + sigma_next * torch.randn_like(x)
# print(alpha_next, sigma_next)
return x

View File

@@ -1,149 +0,0 @@
import torch
from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *
from typing import Callable
def shift_respace_fn(t, shift=3.0):
return t / (t + (1 - t) * shift)
def ode_step_fn(x, v, dt, s, w):
return x + v * dt
import logging
logger = logging.getLogger(__name__)
class EulerSampler(BaseSampler):
def __init__(
self,
w_scheduler: BaseScheduler = None,
timeshift=1.0,
guidance_interval_min: float = 0.0,
guidance_interval_max: float = 1.0,
state_refresh_rate=1,
step_fn: Callable = ode_step_fn,
last_step=None,
last_step_fn: Callable = ode_step_fn,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.step_fn = step_fn
self.last_step = last_step
self.last_step_fn = last_step_fn
self.w_scheduler = w_scheduler
self.timeshift = timeshift
self.state_refresh_rate = state_refresh_rate
self.guidance_interval_min = guidance_interval_min
self.guidance_interval_max = guidance_interval_max
if self.last_step is None or self.num_steps == 1:
self.last_step = 1.0 / self.num_steps
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
self.timesteps = shift_respace_fn(timesteps, self.timeshift)
assert self.last_step > 0.0
assert self.scheduler is not None
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
if self.w_scheduler is not None:
if self.step_fn == ode_step_fn:
logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
# init recompute
self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
self.recompute_timesteps = list(range(self.num_steps))
def sharing_dp(self, net, noise, condition, uncondition):
_, C, H, W = noise.shape
B = 8
template_noise = torch.randn((B, C, H, W), generator=torch.Generator("cuda").manual_seed(0), device=noise.device)
template_condition = torch.randint(0, 1000, (B,), generator=torch.Generator("cuda").manual_seed(0), device=condition.device)
template_uncondition = torch.full((B, ), 1000, device=condition.device)
_, state_list = self._impl_sampling(net, template_noise, template_condition, template_uncondition)
states = torch.stack(state_list)
N, B, L, C = states.shape
states = states.view(N, B*L, C )
states = states.permute(1, 0, 2)
states = torch.nn.functional.normalize(states, dim=-1)
with torch.autocast(device_type="cuda", dtype=torch.float64):
sim = torch.bmm(states, states.transpose(1, 2))
sim = torch.mean(sim, dim=0).cpu()
error_map = (1-sim).tolist()
# init cum-error
for i in range(1, self.num_steps):
for j in range(0, i):
error_map[i][j] = error_map[i-1][j] + error_map[i][j]
# init dp and force 0 start
C = [[0.0, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
P = [[-1, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
for i in range(1, self.num_steps+1):
C[1][i] = error_map[i - 1][0]
P[1][i] = 0
# dp state
for step in range(2, self.num_recompute_timesteps+1):
for i in range(step, self.num_steps+1):
min_value = 99999
min_index = -1
for j in range(step-1, i):
value = C[step-1][j] + error_map[i-1][j]
if value < min_value:
min_value = value
min_index = j
C[step][i] = min_value
P[step][i] = min_index
# trace back
timesteps = [self.num_steps,]
for i in range(self.num_recompute_timesteps, 0, -1):
idx = timesteps[-1]
timesteps.append(P[i][idx])
timesteps.reverse()
print("recompute timesteps solved by DP: ", timesteps)
return timesteps[:-1]
def _impl_sampling(self, net, noise, condition, uncondition):
"""
sampling process of Euler sampler
-
"""
batch_size = noise.shape[0]
steps = self.timesteps.to(noise.device)
cfg_condition = torch.cat([uncondition, condition], dim=0)
x = noise
state = None
pooled_state_list = []
for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
dt = t_next - t_cur
t_cur = t_cur.repeat(batch_size)
cfg_x = torch.cat([x, x], dim=0)
cfg_t = t_cur.repeat(2)
if i in self.recompute_timesteps:
state = None
out, state = net(cfg_x, cfg_t, cfg_condition, state)
if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
out = self.guidance_fn(out, self.guidance)
else:
out = self.guidance_fn(out, 1.0)
v = out
if i < self.num_steps -1 :
x = self.step_fn(x, v, dt, s=0.0, w=0.0)
else:
x = self.last_step_fn(x, v, dt, s=0.0, w=0.0)
pooled_state_list.append(state)
return x, pooled_state_list
def __call__(self, net, noise, condition, uncondition):
if len(self.recompute_timesteps) != self.num_recompute_timesteps:
self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
return denoised

View File

@@ -1,122 +0,0 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class Discriminator(nn.Module):
def __init__(self, in_channels, hidden_size):
super().__init__()
self.head = nn.Sequential(
nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
)
def forward(self, feature):
B, L, C = feature.shape
H = W = int(math.sqrt(L))
feature = feature.permute(0, 2, 1)
feature = feature.view(B, C, H, W)
out = self.head(feature).sigmoid().clamp(0.01, 0.99)
return out
class AdvTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
adv_weight=1.0,
adv_encoder_layer=4,
adv_in_channels=768,
adv_hidden_size=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.adv_weight = adv_weight
self.adv_encoder_layer = adv_encoder_layer
self.dis_head = Discriminator(
in_channels=adv_in_channels,
hidden_size=adv_hidden_size,
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
adv_feature = []
def forward_hook(net, input, output):
adv_feature.append(output)
handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook)
out, _ = net(x_t, t, y)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
pred_x0 = (x_t + out * sigma)
pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma
real_feature = adv_feature.pop()
net(pred_xt, t, y, classify_layer=self.adv_encoder_layer)
fake_feature = adv_feature.pop()
handle.remove()
real_score_gan = self.dis_head(real_feature.detach())
fake_score_gan = self.dis_head(fake_feature.detach())
fake_score = self.dis_head(fake_feature)
loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
acc_real = (real_score_gan > 0.5).float()
acc_fake = (fake_score_gan < 0.5).float()
loss_adv = -torch.log(fake_score)
loss_adv_hack = torch.log(fake_score_gan)
out = dict(
adv_loss=loss_adv.mean(),
gan_loss=loss_gan.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
acc_real=acc_real.mean(),
acc_fake=acc_fake.mean(),
)
return out

View File

@@ -1,141 +0,0 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bilinear')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class DistillDINOTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
encoder_weight_path=None,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.encoder = DINOv2(encoder_weight_path)
self.proj_encoder_dim = proj_encoder_dim
no_grad(self.encoder)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
sigma = self.scheduler.sigma(t)
x_t = alpha * x + noise * sigma
_, s = net(x_t, t, y)
src_feature = self.proj(s)
with torch.no_grad():
dst_feature = self.encoder(raw_images)
if dst_feature.shape[1] != src_feature.shape[1]:
dst_length = dst_feature.shape[1]
rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
dst_height = (dst_length)**0.5 * (height/width)**0.5
dst_width = (dst_length)**0.5 * (width/height)**0.5
dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
dst_feature = dst_feature.permute(0, 3, 1, 2)
dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
dst_feature = dst_feature.permute(0, 2, 3, 1)
dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
out = dict(
cos_loss=cos_loss.mean(),
loss=cos_loss.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)

View File

@@ -1,71 +0,0 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class LPIPSTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
lpips_weight=1.0,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.lpips_weight = lpips_weight
self.lpips = _NoTrainLpips(net="vgg")
self.lpips = self.lpips.to(torch.bfloat16)
# self.lpips = torch.compile(self.lpips)
no_grad(self.lpips)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
pred_x0 = (x_t + out*sigma)
target_x0 = x
# fixbug lpips std
lpips = self.lpips(pred_x0*0.5, target_x0*0.5)
out = dict(
lpips_loss=lpips.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + lpips.mean()*self.lpips_weight,
)
return out
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
return

View File

@@ -1,74 +0,0 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class LPIPSTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
lpips_weight=1.0,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = False
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.lpips_weight = lpips_weight
self.lpips = _NoTrainLpips(net="vgg")
self.lpips = self.lpips.to(torch.bfloat16)
# self.lpips = torch.compile(self.lpips)
no_grad(self.lpips)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
fm_weight = t*(1-t)**2/0.25
lpips_weight = t
loss = (out - v_t)**2 * fm_weight[:, None, None, None]
pred_x0 = (x_t + out*sigma)
target_x0 = x
# fixbug lpips std
lpips = self.lpips(pred_x0*0.5, target_x0*0.5)*lpips_weight[:, None, None, None]
out = dict(
lpips_loss=lpips.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + lpips.mean()*self.lpips_weight,
)
return out
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
return

View File

@@ -1,170 +0,0 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class REPALPIPSTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
lpips_weight=1.0,
encoder_weight_path=None,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.encoder = DINOv2(encoder_weight_path)
self.proj_encoder_dim = proj_encoder_dim
no_grad(self.encoder)
self.lpips_weight = lpips_weight
self.lpips = _NoTrainLpips(net="vgg")
self.lpips = self.lpips.to(torch.bfloat16)
no_grad(self.lpips)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
src_feature = []
def forward_hook(net, input, output):
src_feature.append(output)
if getattr(net, "blocks", None) is not None:
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
else:
handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
out, _ = net(x_t, t, y)
src_feature = self.proj(src_feature[0])
handle.remove()
with torch.no_grad():
dst_feature = self.encoder(raw_images)
if dst_feature.shape[1] != src_feature.shape[1]:
dst_length = dst_feature.shape[1]
rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
dst_height = (dst_length)**0.5 * (height/width)**0.5
dst_width = (dst_length)**0.5 * (width/height)**0.5
dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
dst_feature = dst_feature.permute(0, 3, 1, 2)
dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
dst_feature = dst_feature.permute(0, 2, 3, 1)
dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = weight*(out - v_t)**2
pred_x0 = (x_t + out * sigma)
target_x0 = x
# fixbug lpips std
lpips = self.lpips(pred_x0 * 0.5, target_x0 * 0.5)
out = dict(
lpips_loss=lpips.mean(),
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + self.lpips_weight*lpips.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)

View File

@@ -1,383 +0,0 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
from src.utils.model_loader import ModelLoader
from src.utils.no_grad import no_grad
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenDiTEncoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
return (pos_rope, pos_ape)
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
def forward(self, x, t, y, mask=None):
B, _, H, W = x.shape
pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
s = self.s_embedder(x)
# s = s + pos_ape.to(s.dtype)
for i in range(self.num_blocks):
s = self.blocks[i](s, c, pos_rope, mask)
return None, s
class FlattenDiTDecoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s, mask=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
c = torch.nn.functional.silu(t + y)
x = torch.cat([x, s], dim=-1)
x = self.x_embedder(x)
for i in range(self.num_blocks):
x = self.blocks[i](x, c, pos, None)
x = self.final_layer(x, c)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size,
stride=self.patch_size)
return x
class FlattenDiT(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
ModelLoader().load(encoder)
self.encoder = self.encoder.to(torch.bfloat16)
no_grad(self.encoder)
def forward(self, x, t, y, s=None):
if s is None:
with torch.no_grad():
_, s = self.encoder(x, t, y)
x = self.decoder(x, t, y, s)
return x, s

View File

@@ -1,447 +0,0 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
from src.utils.model_loader import ModelLoader
from src.utils.no_grad import no_grad
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
class ResBlock(nn.Module):
def __init__(self, dim:int, groups:int=8, hidden_dim:int=256):
super().__init__()
self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
self.norm1 = nn.GroupNorm(groups, dim)
self.norm2 = nn.GroupNorm(groups, dim)
self.embed_proj = nn.Linear(hidden_dim, dim)
def forward(self, x, c):
c = self.embed_proj(c)[:, :, None, None]
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = torch.nn.functional.silu(x)
x = x * c
x = self.conv2(x)
x = self.norm2(x)
x = torch.nn.functional.silu(x)
return residual + x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenDiTEncoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
return (pos_rope, pos_ape)
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
def forward(self, x, t, y, mask=None, classify_layer=None):
B, _, H, W = x.shape
pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
s = self.s_embedder(x)
# s = s + pos_ape.to(s.dtype)
classify_feats = []
for i in range(self.num_blocks):
s = self.blocks[i](s, c, pos_rope, mask)
if classify_layer is not None and i < classify_layer:
classify_feats.append(s)
if i == classify_layer - 1:
return _, classify_feats
return None, s
class FlattenDiTDecoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_mid_blocks=18,
num_res_blocks=[1, 1, 1],
num_res_channels=[64, 384, 768],
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_mid_blocks = num_mid_blocks
self.num_res_blocks = num_res_blocks
self.num_res_channels = num_res_channels
self.patch_size = 2**(len(num_res_blocks))
self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.down_res_blocks = nn.ModuleList()
previous_channel = self.in_channels
for num, channels in zip(num_res_blocks, num_res_channels):
self.down_res_blocks.append(
nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
)
self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
previous_channel = channels
self.up_res_blocks = []
previous_channel = self.in_channels
for num, channels in zip(num_res_blocks, num_res_channels):
self.up_res_blocks.append(
nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
)
self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
previous_channel = channels
self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
self.weight_path = weight_path
self.load_ema = load_ema
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
def forward(self, x, t, y, s, mask=None):
B, _, H, W = x.shape
t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
y = self.y_embedder(y).view(B, self.hidden_size)
s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
c = torch.nn.functional.silu(t + y)
residual = []
for i, block in enumerate(self.down_res_blocks):
if isinstance(block, nn.Conv2d):
residual.append(x)
x = block(x)
else:
x = block(x, c)
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = x.view(B, self.hidden_size, -1).transpose(1, 2)
mid_c = torch.nn.functional.silu(t[:, None, :] + s)
for i in range(self.num_mid_blocks):
x = self.blocks[i](x, mid_c, pos, None)
x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size)
residual[0] = 0.0
for i, block in enumerate(self.up_res_blocks):
if isinstance(block, nn.ConvTranspose2d):
x = block(x) + residual.pop()
else:
x = block(x, c)
return x
class FlattenDiT(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
ModelLoader().load(encoder)
self.encoder = self.encoder.to(torch.bfloat16)
no_grad(self.encoder)
def forward(self, x, t, y, s=None, classify_layer=None):
if s is None:
_, s = self.encoder(x, t, y, classify_layer=classify_layer)
if classify_layer is not None:
return None, s
x = self.decoder(x, t, y, s)
return x, s
class FlattenDiT_jointtraining(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, x, t, y, s=None):
if s is None:
_, s = self.encoder(x, t, y)
x = self.decoder(x, t, y, s)
return x, s

View File

@@ -1,448 +0,0 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
from src.utils.model_loader import ModelLoader
from src.utils.no_grad import no_grad
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
class ResBlock(nn.Module):
def __init__(self, dim:int, groups:int=8, hidden_dim:int=256):
super().__init__()
self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
self.norm1 = nn.GroupNorm(groups, dim)
self.norm2 = nn.GroupNorm(groups, dim)
self.embed_proj = nn.Linear(hidden_dim, dim)
def forward(self, x, c):
c = self.embed_proj(c)[:, :, None, None]
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = torch.nn.functional.silu(x)
x = x * c
x = self.conv2(x)
x = self.norm2(x)
x = torch.nn.functional.silu(x)
return residual + x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenDiTEncoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
return (pos_rope, pos_ape)
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
def forward(self, x, t, y, mask=None, classify_layer=None):
B, _, H, W = x.shape
pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
s = self.s_embedder(x)
# s = s + pos_ape.to(s.dtype)
classify_feats = []
for i in range(self.num_blocks):
s = self.blocks[i](s, c, pos_rope, mask)
if classify_layer is not None and i < classify_layer:
classify_feats.append(s)
if i == classify_layer - 1:
return _, classify_feats
return None, s
class FlattenDiTDecoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_mid_blocks=18,
num_res_blocks=[1, 1, 1],
num_res_channels=[64, 384, 768],
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_mid_blocks = num_mid_blocks
self.num_res_blocks = num_res_blocks
self.num_res_channels = num_res_channels
self.patch_size = 2**(len(num_res_blocks))
self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.down_res_blocks = nn.ModuleList()
previous_channel = self.in_channels
for num, channels in zip(num_res_blocks, num_res_channels):
self.down_res_blocks.append(
nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
)
self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
previous_channel = channels
self.up_res_blocks = []
previous_channel = self.in_channels
for num, channels in zip(num_res_blocks, num_res_channels):
self.up_res_blocks.append(
nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
)
self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
previous_channel = channels
self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
self.weight_path = weight_path
self.load_ema = load_ema
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
def forward(self, x, t, y, s, mask=None):
B, _, H, W = x.shape
t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
y = self.y_embedder(y).view(B, self.hidden_size)
s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
c = torch.nn.functional.silu(t + y)
residual = []
for i, block in enumerate(self.down_res_blocks):
if isinstance(block, nn.Conv2d):
residual.append(x)
x = block(x)
else:
x = block(x, c)
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = x.view(B, self.hidden_size, -1).transpose(1, 2)
mid_c = torch.nn.functional.silu(t[:, None, :] + s)
for i in range(self.num_mid_blocks):
x = self.blocks[i](x, mid_c, pos, None)
x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size)
residual[0] = 0.0
for i, block in enumerate(self.up_res_blocks):
if isinstance(block, nn.ConvTranspose2d):
x = block(x) + residual.pop()
else:
x = block(x, c)
return x
class FlattenDiT(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
ModelLoader().load(encoder, "encoder.")
ModelLoader().load(decoder, "decoder.")
self.encoder = self.encoder.to(torch.bfloat16)
no_grad(self.encoder)
def forward(self, x, t, y, s=None, classify_layer=None):
if s is None:
_, s = self.encoder(x, t, y, classify_layer=classify_layer)
if classify_layer is not None:
return None, s
x = self.decoder(x, t, y, s)
return x, s
class FlattenDiT_jointtraining(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, x, t, y, s=None):
if s is None:
_, s = self.encoder(x, t, y)
x = self.decoder(x, t, y, s)
return x, s

View File

@@ -1,464 +0,0 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
from src.utils.model_loader import ModelLoader
from src.utils.no_grad import no_grad
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
class ResBlock(nn.Module):
def __init__(self, dim:int, groups:int=8, hidden_dim:int=256):
super().__init__()
self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
self.norm1 = nn.GroupNorm(groups, dim)
self.norm2 = nn.GroupNorm(groups, dim)
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = torch.nn.functional.silu(x)
x = self.conv2(x)
x = self.norm2(x)
x = torch.nn.functional.silu(x)
return residual + x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenDiTEncoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
return (pos_rope, pos_ape)
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
def forward(self, x, t, y, mask=None, classify_layer=None):
B, _, H, W = x.shape
pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
s = self.s_embedder(x)
# s = s + pos_ape.to(s.dtype)
classify_feats = []
for i in range(self.num_blocks):
s = self.blocks[i](s, c, pos_rope, mask)
if classify_layer is not None and i < classify_layer:
classify_feats.append(s)
if i == classify_layer - 1:
return _, classify_feats
return None, s
class FlattenDiTDecoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_mid_blocks=18,
num_res_blocks=[1, 1, 1],
num_res_channels=[64, 384, 768],
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_mid_blocks = num_mid_blocks
self.num_res_blocks = num_res_blocks
self.num_res_channels = num_res_channels
self.patch_size = 2**(len(num_res_blocks))
self.t_embedder = TimestepEmbedder(hidden_size)
self.down_res_blocks = nn.ModuleList()
previous_channel = self.in_channels
for num, channels in zip(num_res_blocks, num_res_channels):
self.down_res_blocks.append(
nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
)
self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
previous_channel = channels
self.up_res_blocks = []
previous_channel = self.in_channels
for num, channels in zip(num_res_blocks, num_res_channels):
self.up_res_blocks.append(
nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
)
self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
previous_channel = channels
self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
self.weight_path = weight_path
self.load_ema = load_ema
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in SiT blocks:
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
for block in self.down_res_blocks:
if isinstance(block, ResBlock):
nn.init.constant_(block.conv1.weight, 0)
nn.init.constant_(block.conv1.bias, 0)
nn.init.constant_(block.norm1.weight, 0)
nn.init.constant_(block.norm2.weight, 0)
nn.init.constant_(block.conv2.weight, 0)
nn.init.constant_(block.conv2.bias, 0)
for block in self.up_res_blocks:
if isinstance(block, ResBlock):
nn.init.constant_(block.conv1.weight, 0)
nn.init.constant_(block.conv1.bias, 0)
nn.init.constant_(block.norm1.weight, 0)
nn.init.constant_(block.norm2.weight, 0)
nn.init.constant_(block.conv2.weight, 0)
nn.init.constant_(block.conv2.bias, 0)
def forward(self, x, t, y, s, mask=None):
B, _, H, W = x.shape
t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
residual = []
for i, block in enumerate(self.down_res_blocks):
if isinstance(block, nn.Conv2d):
residual.append(x)
x = block(x)
else:
x = block(x)
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = x.view(B, self.hidden_size, -1).transpose(1, 2)
mid_c = torch.nn.functional.silu(t[:, None, :] + s)
for i in range(self.num_mid_blocks):
x = self.blocks[i](x, mid_c, pos, None)
x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size)
residual[0] = 0.0
for i, block in enumerate(self.up_res_blocks):
if isinstance(block, nn.ConvTranspose2d):
x = block(x) + residual.pop()
else:
x = block(x)
return x
class FlattenDiT(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
ModelLoader().load(encoder, "encoder.")
ModelLoader().load(decoder, "decoder.")
self.encoder = self.encoder.to(torch.bfloat16)
no_grad(self.encoder)
def forward(self, x, t, y, s=None, classify_layer=None):
if s is None:
_, s = self.encoder(x, t, y, classify_layer=classify_layer)
if classify_layer is not None:
return None, s
x = self.decoder(x, t, y, s)
return x, s
class FlattenDiT_jointtraining(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, x, t, y, s=None):
if s is None:
_, s = self.encoder(x, t, y)
x = self.decoder(x, t, y, s)
return x, s

View File

@@ -1,274 +0,0 @@
import torch
import torch.nn as nn
import math
from numba.cuda.cudadrv.devicearray import lru_cache
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
flex_attention = torch.compile(flex_attention)
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = False,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm = nn.LayerNorm(hidden_size , elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = self.norm(x)
x = modulate(x, shift, scale)
x = self.linear(x)
return x
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
self.fc1 = nn.Linear(dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, dim)
self.act = nn.GELU(approximate="tanh")
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale: float=16):
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
x_pos = x_pos.reshape(-1)
y_pos = y_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
freqs_cis = torch.cat([x_freqs.sin(), x_freqs.cos(), y_freqs.sin(), y_freqs.cos()], dim=1)
return freqs_cis
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
qk_norm: bool = False,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
# import pdb; pdb.set_trace()
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q).to(q.dtype)
k = self.k_norm(k).to(k.dtype)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
# x = flex_attention(q, k, v, block_mask=mask)
# with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
x = scaled_dot_product_attention(q, k, v, mask)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class DiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size , elementwise_affine=False, eps=1e-6)
self.attn = Attention(hidden_size, num_heads=groups, qkv_bias=True, qk_norm=False)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class ConDiT(nn.Module):
def __init__(
self,
in_channels=4,
out_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
self.final_layer = FinalLayer(hidden_size, out_channels * patch_size ** 2)
self.num_cond_blocks = num_cond_blocks
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
DiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
@lru_cache
def fetch_pos(self, height, width, device):
pos = precompute_freqs_cis_2d(self.hidden_size, height//self.patch_size, width//self.patch_size).to(device)[None, ...]
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H, W, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
if s is None:
# semantic encoder
s = self.s_embedder(x) + pos
c = nn.functional.silu(t + y)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s

View File

@@ -194,7 +194,7 @@ class RAttention(nn.Module):
class FlattenDiTBlock(nn.Module):
class DDTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
@@ -213,14 +213,14 @@ class FlattenDiTBlock(nn.Module):
return x
class FlattenConDiT(nn.Module):
class DDT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
num_encoder_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
@@ -236,7 +236,7 @@ class FlattenConDiT(nn.Module):
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.num_encoder_blocks = num_encoder_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
@@ -249,7 +249,7 @@ class FlattenConDiT(nn.Module):
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
DDTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
@@ -280,11 +280,6 @@ class FlattenConDiT(nn.Module):
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
@@ -301,12 +296,12 @@ class FlattenConDiT(nn.Module):
c = nn.functional.silu(t + y)
if s is None:
s = self.s_embedder(x)
for i in range(self.num_cond_blocks):
for i in range(self.num_encoder_blocks):
s = self.blocks[i](s, c, pos, mask)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
for i in range(self.num_encoder_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)

View File

@@ -1,314 +0,0 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None, mask=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
if s is None:
s = self.s_embedder(x)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos, mask)
# s = nn.functional.silu(t + s)
s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6)
x = torch.cat((x, s), dim=-1)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, c, pos, None)
x = self.final_layer(x, c)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s

View File

@@ -1,340 +0,0 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConvBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, kernel_size=3):
super().__init__()
self.hidden_size = hidden_size
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = nn.Conv2d(hidden_size, hidden_size, groups=groups, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
attn_x = modulate(self.norm1(x), shift_msa, scale_msa)
attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous()
attn_x = self.attn(attn_x)
attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2)
x = x + gate_msa * attn_x
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
kernel_size=3,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([])
for i in range(self.num_cond_blocks):
self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
for i in range(self.num_blocks-self.num_cond_blocks):
self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups, kernel_size))
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
if s is None:
s = self.s_embedder(x)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos, None)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s

View File

@@ -1,339 +0,0 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConvBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.hidden_size = hidden_size
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = nn.Conv2d(hidden_size, hidden_size, groups=hidden_size, kernel_size=7, stride=1, padding=3)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
attn_x = modulate(self.norm1(x), shift_msa, scale_msa)
attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous()
attn_x = self.attn(attn_x)
attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2)
x = x + gate_msa * attn_x
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([])
for i in range(self.num_cond_blocks):
self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
for i in range(self.num_blocks-self.num_cond_blocks):
self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups))
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
if s is None:
s = self.s_embedder(x)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos, None)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s

View File

@@ -1,314 +0,0 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None, mask=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
if s is None:
s = self.s_embedder(x)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos, mask)
s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s

View File

@@ -1,429 +0,0 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
from src.utils.model_loader import ModelLoader
from src.utils.no_grad import no_grad
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenDiTEncoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
return (pos_rope, pos_ape)
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
def forward(self, x, t, y, mask=None):
B, _, H, W = x.shape
pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
s = self.s_embedder(x)
# s = s + pos_ape.to(s.dtype)
for i in range(self.num_blocks):
s = self.blocks[i](s, c, pos_rope, mask)
return None, s
class FlattenDiTDecoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s, mask=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
# s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
s = torch.nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size,
stride=self.patch_size)
return x
class FlattenDiT(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
joint_training=False,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
ModelLoader().load(encoder)
if not joint_training:
self.encoder = self.encoder.to(torch.bfloat16)
no_grad(self.encoder)
self.joint_training = joint_training
def forward(self, x, t, y, s=None):
if s is None:
_, s = self.encoder(x, t, y)
x = self.decoder(x, t, y, s)
return x, s
class FlattenDiTScalingEncoder(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
no_grad(self.decoder)
if self.encoder.weight_path:
weight = torch.load(self.encoder.weight_path, map_location=torch.device('cpu'))
if self.encoder.load_ema:
prefix = "ema_denoiser."
else:
prefix = "denoiser."
for k, v in self.encoder.state_dict().items():
try:
v.copy_(weight["state_dict"][prefix+k])
except:
print(f"Failed to copy {prefix+k} to denoiser weight")
if self.decoder.weight_path:
weight = torch.load(self.decoder.weight_path, map_location=torch.device('cpu'))
if self.decoder.load_ema:
prefix = "ema_denoiser."
else:
prefix = "denoiser."
for k, v in self.decoder.state_dict().items():
if "blocks." in k:
blockid = int(k.split("blocks.")[-1][0])
k = k.replace(f"blocks.{blockid}", f"blocks.{int(blockid)+8}")
try:
v.copy_(weight["state_dict"][prefix+k])
except:
print(f"Failed to copy {prefix+k} to denoiser weight")
self.decoder = decoder.to(torch.bfloat16)
def forward(self, x, t, y, s=None):
if s is None:
_, s = self.encoder(x, t, y)
x = self.decoder(x, t, y, s)
return x, s

View File

@@ -1,334 +0,0 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenMLPBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = FeedForward(hidden_size, mlp_hidden_dim)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([])
for i in range(self.num_cond_blocks):
self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
for i in range(self.num_blocks-self.num_cond_blocks):
self.blocks.append(FlattenMLPBlock(self.hidden_size, self.num_groups))
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
if s is None:
s = self.s_embedder(x)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos, None)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s

View File

@@ -1,321 +0,0 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*patch_size**4, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None, mask=None):
B, _, H, W = x.shape
pos_x = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
if s is None:
pos_s = self.fetch_pos(H//self.patch_size//2, W//self.patch_size//2, x.device)
s = torch.nn.functional.unfold(x, kernel_size=self.patch_size*2, stride=self.patch_size*2).transpose(1, 2)
s = self.s_embedder(s)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos_s, mask)
s = s.view(B, H//self.patch_size//2, W//self.patch_size//2, self.hidden_size)
s = torch.permute(s, (0, 3, 1, 2))
s = torch.nn.functional.interpolate(s, scale_factor=2, mode='bilinear', align_corners=False)
s = torch.permute(s, (0, 2, 3, 1))
s = s.view(B, -1, self.hidden_size)
s = nn.functional.silu(t + s)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos_x, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s

View File

@@ -1,311 +0,0 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return (self.weight * hidden_states).to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, 2*in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device, dtype):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device, dtype)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, masks=None):
if masks is None:
masks = [None, ]*self.num_blocks
if isinstance(masks, torch.Tensor):
masks = masks.unbind(0)
if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks:
masks = masks + [None]*(self.num_blocks-len(masks))
B, _, H, W = x.shape
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
x = self.x_embedder(x)
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
B, L, C = x.shape
t = self.t_embedder(t.view(-1)).view(B, -1, C)
y = self.y_embedder(y).view(B, 1, C)
condition = nn.functional.silu(t + y)
for i, block in enumerate(self.blocks):
x = block(x, condition, pos, masks[i])
x = self.final_layer(x, condition)
x0, v = x.chunk(2, dim=-1)
x0 = torch.nn.functional.fold(x0.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
v = torch.nn.functional.fold(v.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
if self.training:
return v, x0
else:
return v

View File

@@ -1,308 +0,0 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None, mask=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
if s is None:
s = self.x_embedder(x)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos, mask)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s

View File

@@ -1,160 +0,0 @@
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from src.models.denoiser.base_model import BaseModel
from src.ops.triton_kernels.function import DCNFunction
def modulate(x, shift, scale):
return x * (1 + scale[:, None, None]) + shift[:, None, None]
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
b, h, w, c = x.shape
x = x.view(b, h*w, c)
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
x = x.view(b, h, w, c)
return x
class MultiScaleDCN(nn.Module):
def __init__(self, in_channels, groups, channels, kernels, deformable_biass=True):
super().__init__()
self.in_channels = in_channels
self.groups = groups
self.channels = channels
self.kernels = kernels
self.v = nn.Linear(in_channels, groups * channels, bias=True)
self.qk_deformables = nn.Linear(in_channels, groups * kernels * 2, bias=True)
self.qk_scales = nn.Linear(in_channels, groups * kernels, bias=False)
self.qk_weights = nn.Linear(in_channels, groups*kernels, bias=True)
self.out = nn.Linear(groups * channels, in_channels)
self.deformables_prior = nn.Parameter(torch.randn((1, 1, 1, 1, kernels, 2)), requires_grad=False)
self.deformables_scale = nn.Parameter(torch.ones((1, 1, 1, groups, 1, 1)), requires_grad=True)
self.max_scale = 6
self._init_weights()
def _init_weights(self):
zeros_(self.qk_deformables.weight.data)
zeros_(self.qk_scales.weight.data)
zeros_(self.qk_deformables.bias.data)
zeros_(self.qk_weights.weight.data)
zeros_(self.v.bias.data)
zeros_(self.out.bias.data)
num_prior = int(self.kernels ** 0.5)
dx = torch.linspace(-1, 1, num_prior, device="cuda")
dy = torch.linspace(-1, 1, num_prior, device="cuda")
dxy = torch.meshgrid([dx, dy], indexing="xy")
dxy = torch.stack(dxy, dim=-1)
dxy = dxy.view(-1, 2)
self.deformables_prior.data[..., :num_prior*num_prior, :] = dxy
for i in range(self.groups):
scale = (i+1)/self.groups - 0.0001
inv_scale = math.log((scale)/(1-scale))
self.deformables_scale.data[..., i, :, :] = inv_scale
def forward(self, x):
B, H, W, _ = x.shape
v = self.v(x).view(B, H, W, self.groups, self.channels)
deformables = self.qk_deformables(x).view(B, H, W, self.groups, self.kernels, 2)
scale = self.qk_scales(x).view(B, H, W, self.groups, self.kernels, 1) + self.deformables_scale
deformables = (deformables + self.deformables_prior ) * scale.sigmoid()*self.max_scale
weights = self.qk_weights(x).view(B, H, W, self.groups, self.kernels)
out = DCNFunction.apply(v, deformables, weights)
out = out.view(B, H, W, -1)
out = self.out(out)
return out
class FlowDCNBlock(nn.Module):
def __init__(self, hidden_size, groups, kernels=9, mlp_ratio=4.0, deformable_biass=True):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = MultiScaleDCN(hidden_size, groups=groups, channels=hidden_size//groups, kernels=kernels, deformable_biass=deformable_biass)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa[:, None, None] * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_mlp[:, None, None] * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlowDCN(BaseModel):
def __init__(self, deformable_biass=True, *args, **kwargs):
super().__init__(*args, **kwargs)
self.blocks = nn.ModuleList([
FlowDCNBlock(self.hidden_size, self.num_groups, kernels=9, deformable_biass=deformable_biass) for _ in range(self.num_blocks)
])
self.x_embedder = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, bias=True)
self.initialize_weights()
def forward(self, x, t, y):
batch_size, _, height, width = x.shape[0]
x = self.x_embedder(x) # (N, D, h, w)
x = x.permute(0, 2, 3, 1).reshape(batch_size, height*width//self.patch_size**2, -1)
t = self.t_embedder(t) # (N, D)
y = self.y_embedder(y, self.training) # (N, D)
c = t + y # (N, D)
B, L, C = x.shape
x = x.view(B, height//self.patch_size, width//self.patch_size, C)
for block in self.blocks:
x = block(x, c) # (N, T, D)
x = x.view(B, L, C)
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
x = torch.nn.functional.fold(x.transpose(1, 2), (height, width), kernel_size=self.patch_size, stride=self.patch_size)
if self.learn_sigma:
x, _ = torch.split(x, self.out_channels // 2, dim=1)
return x

View File

@@ -194,7 +194,7 @@ class RAttention(nn.Module):
class FlattenDiTBlock(nn.Module):
class DiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
@@ -213,7 +213,7 @@ class FlattenDiTBlock(nn.Module):
return x
class FlattenDiT(nn.Module):
class DiT(nn.Module):
def __init__(
self,
in_channels=4,
@@ -246,7 +246,7 @@ class FlattenDiT(nn.Module):
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
DiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
@@ -272,11 +272,6 @@ class FlattenDiT(nn.Module):
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)

View File

@@ -1,346 +0,0 @@
#include <iostream>
#include <string>
#include <fstream>
#include <chrono>
#include <iostream>
#include <vector_types.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <cuda_runtime.h>
#include <cuda_fp16.hpp>
#include <cuda_bf16.hpp>
#include <ATen/cuda/CUDAContext.h>
#include <pybind11/pybind11.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda/pipeline>
namespace cg = cooperative_groups;
template<typename scalar_t>
__device__ __always_inline int toInt(scalar_t val);
template<>
__device__ __always_inline int toInt(float val){
return static_cast<int>(val);
}
template<>
__device__ __always_inline int toInt(half val){
return __half2int_rz(val);
}
template<typename scalar_t>
__device__ __always_inline scalar_t fromInt(int val);
template<>
__device__ __always_inline float fromInt(int val){
return static_cast<float>(val);
}
template<>
__device__ __always_inline half fromInt(int val){
return __int2half_rz(val);
}
template<typename scalar_t>
__device__ __always_inline scalar_t constVal(float val);
template<>
__device__ __always_inline float constVal<float>(float val) {
return (float)val;
}
template<>
__device__ __always_inline half constVal<half>(float val) {
return __float2half(val); // Using float to half conversion
}
template<>
__device__ __always_inline nv_bfloat16 constVal<nv_bfloat16>(float val){
return __float2bfloat16(val);
}
// B, H, W, C, BLOCK_DIM must be multiple of C
template <typename scalar_t, typename vec2_t, int pipeline_stages, int TILE_C, int TILE_THREADS>
__global__ void dcn_backward_pipeline_kernel(
const int H,
const int W,
const int G,
const int K,
const int C,
scalar_t* ptr_values,
scalar_t* ptr_deformables,
scalar_t* ptr_weights,
scalar_t* ptr_grad_out,
scalar_t* ptr_grad_values,
scalar_t* ptr_grad_deformables,
scalar_t* ptr_grad_weights
) {
auto block = cg::this_thread_block();
auto self_thread = cg::this_thread();
auto tile_threads = cg::tiled_partition<TILE_THREADS>(block);
int local_thread_id = block.thread_rank();
int local_tile_id = tile_threads.meta_group_rank();
int num_local_tiles = tile_threads.meta_group_size();
int global_tile_id = block.group_index().x*num_local_tiles + local_tile_id;
extern __shared__ int shm[];
auto GradBuffer = reinterpret_cast<scalar_t*>(shm);
scalar_t* Buffer = reinterpret_cast<scalar_t*>(shm) + num_local_tiles*C;
if(global_tile_id >= H*W*G) return;
int bid = block.group_index().y;
int gid = global_tile_id % G;
int wid = global_tile_id / G % W;
int hid = global_tile_id / G / W;
int globale_offset = bid*H*W*G*C + global_tile_id*C;
cg::memcpy_async(tile_threads, GradBuffer+local_tile_id*C, ptr_grad_out+globale_offset, sizeof(scalar_t)*C);
int shared_offset[pipeline_stages];
for (int s = 0; s < pipeline_stages; ++s) {
shared_offset[s] = (s+pipeline_stages*local_thread_id)*(TILE_C*4);
}
auto pipeline = cuda::make_pipeline();
const int num_tiles_per_thread = C/TILE_C/TILE_THREADS;
for(int k=0; k<K; k++) {
int offset = bid * K * H * W * G + hid * W * K * G + wid * K * G + gid * K + k;
scalar_t x, y, weight;
if (tile_threads.thread_rank() == 0) {
x = ptr_deformables[offset*2] + fromInt<scalar_t>(wid);
y = ptr_deformables[offset*2 + 1] + fromInt<scalar_t>(hid);
// x = fromInt<scalar_t>(wid);
// y = fromInt<scalar_t>(hid);
weight = ptr_weights[offset];
}
tile_threads.sync();
x = tile_threads.shfl(x, 0);
y = tile_threads.shfl(y, 0);
weight = tile_threads.shfl(weight, 0);
int floor_x = toInt<scalar_t>(x);
int floor_y = toInt<scalar_t>(y);
int ceil_x = floor_x + 1;
int ceil_y = floor_y + 1;
scalar_t dodx = constVal<scalar_t>(0.0f);
scalar_t dody = constVal<scalar_t>(0.0f);
scalar_t dodw = constVal<scalar_t>(0.0f);
int start_c = tile_threads.thread_rank() * (C / TILE_THREADS);
bool tl_flag = (floor_x >=0) and (floor_x <W) and (floor_y>=0) and (floor_y<H);
bool tr_flag = (ceil_x >=0) and (ceil_x <W) and (floor_y>=0) and (floor_y<H);
bool bl_flag = (floor_x >=0) and (floor_x <W) and (ceil_y>=0) and (ceil_y<H);
bool br_flag = (ceil_x >=0) and (ceil_x <W) and (ceil_y>=0) and (ceil_y<H);
int tl_global_base = (bid * H * W * G + floor_y * W * G + floor_x * G + gid)*C + start_c;
int tr_global_base = (bid * H * W * G + floor_y * W * G + ceil_x * G + gid)*C + start_c;
int bl_global_base = (bid * H * W * G + ceil_y * W * G + floor_x * G + gid)*C +start_c;
int br_global_base = (bid * H * W * G + ceil_y * W * G + ceil_x * G + gid)*C +start_c;
auto asmem_load_fn = [&](int shm_offset, int hbm_offset, bool flag){
if(flag){
cuda::memcpy_async(Buffer + shm_offset, ptr_values + hbm_offset,
TILE_C * sizeof(scalar_t), pipeline);
}else{
memset(Buffer+shm_offset, TILE_C, sizeof(scalar_t));
}
};
// pipeline-compute&load
for (int compute_n = 0, fetch_n=0; compute_n < num_tiles_per_thread; compute_n++) {
for (; fetch_n < compute_n + pipeline_stages and fetch_n < num_tiles_per_thread; fetch_n++) {
pipeline.producer_acquire();
int buffer_offset = shared_offset[fetch_n % pipeline_stages];
// tl
asmem_load_fn(buffer_offset, tl_global_base + fetch_n * TILE_C, tl_flag);
// tr
asmem_load_fn(buffer_offset+TILE_C, tr_global_base + fetch_n * TILE_C, tr_flag);
// bl
asmem_load_fn(buffer_offset+TILE_C*2, bl_global_base + fetch_n * TILE_C, bl_flag);
// br
asmem_load_fn(buffer_offset+TILE_C*3, br_global_base + fetch_n * TILE_C, br_flag);
pipeline.producer_commit();
}
pipeline.consumer_wait();
int buffer_id = compute_n % pipeline_stages;
int ibuffer_offset = shared_offset[buffer_id];
int gbuffer_offset = local_tile_id * C + start_c + compute_n * TILE_C;
for (int j = 0; j < TILE_C; j+=2) {
if(tl_flag){
// tl
dodw = dodw + (fromInt<scalar_t>(ceil_x) - x) * (fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
dodx = dodx + -weight*(fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
dody = dody + -weight*(fromInt<scalar_t>(ceil_x) - x) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
dodw = dodw + (fromInt<scalar_t>(ceil_x) - x) * (fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1];
dodx = dodx + -weight*(fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+j+ 1] * GradBuffer[gbuffer_offset + j + 1];
dody = dody + -weight*(fromInt<scalar_t>(ceil_x) - x) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1];
{
vec2_t vtl_di;
vtl_di.x = weight* (fromInt<scalar_t>(ceil_x) - x) * (fromInt<scalar_t>(ceil_y) - y) * GradBuffer[gbuffer_offset + j];
vtl_di.y = weight* (fromInt<scalar_t>(ceil_x) - x) * (fromInt<scalar_t>(ceil_y) - y) * GradBuffer[gbuffer_offset + j + 1];
atomicAdd((vec2_t*)(ptr_grad_values + tl_global_base + compute_n * TILE_C + j), vtl_di);
}
}
if(tr_flag){
// tr
dodw = dodw + (x - fromInt<scalar_t>(floor_x)) * (fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
dodx = dodx + weight*(fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
dody = dody + -weight*(x - fromInt<scalar_t>(floor_x)) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
dodw = dodw + (x - fromInt<scalar_t>(floor_x)) * (fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j+1] * GradBuffer[gbuffer_offset + j+1];
dodx = dodx + weight*(fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+ 1];
dody = dody + -weight*(x - fromInt<scalar_t>(floor_x)) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+1];
{
vec2_t vtr_di;
vtr_di.x = weight* (x - fromInt<scalar_t>(floor_x)) * (fromInt<scalar_t>(ceil_y) - y) * GradBuffer[gbuffer_offset + j];
vtr_di.y = weight* (x - fromInt<scalar_t>(floor_x)) * (fromInt<scalar_t>(ceil_y) - y) * GradBuffer[gbuffer_offset + j+1];
atomicAdd((vec2_t*)(ptr_grad_values + tr_global_base + compute_n * TILE_C + j), vtr_di);
}
}
if(bl_flag){
// bl
dodw = dodw + (fromInt<scalar_t>(ceil_x) - x) * (y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
dodx = dodx + -weight*(y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
dody = dody + weight*(fromInt<scalar_t>(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
dodw = dodw + (fromInt<scalar_t>(ceil_x) - x) * (y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
dodx = dodx + -weight*(y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
dody = dody + weight*(fromInt<scalar_t>(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
{
vec2_t vbl_di;
vbl_di.x = weight* (fromInt<scalar_t>(ceil_x) - x) * (y - fromInt<scalar_t>(floor_y)) * GradBuffer[gbuffer_offset + j];
vbl_di.y = weight* (fromInt<scalar_t>(ceil_x) - x) * (y - fromInt<scalar_t>(floor_y)) * GradBuffer[gbuffer_offset + j+1];
atomicAdd((vec2_t*)(ptr_grad_values + bl_global_base + compute_n * TILE_C + j), vbl_di);
}
}
if(br_flag){
// tr
dodw = dodw + (x - fromInt<scalar_t>(floor_x)) * (y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
dodx = dodx + weight*(y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
dody = dody + weight*(x - fromInt<scalar_t>(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
dodw = dodw + (x - fromInt<scalar_t>(floor_x)) * (y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
dodx = dodx + weight*(y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
dody = dody + weight*(x - fromInt<scalar_t>(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
{
vec2_t vbr_di;
vbr_di.x = weight* (x - fromInt<scalar_t>(floor_x)) * (y - fromInt<scalar_t>(floor_y)) * GradBuffer[gbuffer_offset + j];
vbr_di.y = weight* (x - fromInt<scalar_t>(floor_x)) * (y - fromInt<scalar_t>(floor_y)) * GradBuffer[gbuffer_offset + j+1];
atomicAdd((vec2_t*)(ptr_grad_values + br_global_base + compute_n * TILE_C + j), vbr_di);
}
}
}
pipeline.consumer_release();
}
for (int i = TILE_THREADS>>1; i > 0; i/=2) {
dodx = dodx + tile_threads.shfl_down(dodx, i);
dody = dody + tile_threads.shfl_down(dody, i);
dodw = dodw + tile_threads.shfl_down(dodw, i);
}
if (tile_threads.thread_rank() == 0) {
cuda::memcpy_async(ptr_grad_deformables + offset * 2, &dodx, sizeof(scalar_t), pipeline);
cuda::memcpy_async(ptr_grad_deformables + offset * 2 + 1, &dody, sizeof(scalar_t), pipeline);
cuda::memcpy_async(ptr_grad_weights + offset, &dodw, sizeof(scalar_t), pipeline);
}
}
}
using namespace torch;
template<int pipeline_stages, int TILE_C, int TILE_THREADS, int THREADS>
void backward(const int B,
const int H,
const int W,
const int G,
const int K,
const int C,
torch::Tensor values,
torch::Tensor deformables,
torch::Tensor weights,
torch::Tensor grad_out,
torch::Tensor grad_values,
torch::Tensor grad_deformables,
torch::Tensor grad_weights
) {
int num_local_tiles =(THREADS/TILE_THREADS);
int num_global_tiles = (H*W*G+num_local_tiles-1)/num_local_tiles;
dim3 launch_threads_per_block(THREADS);
dim3 launch_blocks(num_global_tiles, B);
int deformable_shm_size = 0;
int grad_out_shm_size = num_local_tiles*C;
int pipeline_shm_size = (pipeline_stages*TILE_C*4*THREADS);
int shm_size = deformable_shm_size+grad_out_shm_size+pipeline_shm_size;
// printf("shm_size: %d\n", shm_size/512);
// printf("pipeline_size: %d\n", pipeline_shm_size/512);
// printf("grad_out_size: %d\n", grad_out_shm_size/512);
switch (values.type().scalarType()) {
case at::ScalarType::Half:
return dcn_backward_pipeline_kernel<half, half2, pipeline_stages, TILE_C, TILE_THREADS><<<launch_blocks, launch_threads_per_block, shm_size*sizeof(half)>>>(
H, W, G, K, C,
reinterpret_cast<half*>(values.data_ptr<at::Half>()),
reinterpret_cast<half*>(deformables.data_ptr<at::Half>()),
reinterpret_cast<half*>(weights.data_ptr<at::Half>()),
reinterpret_cast<half*>(grad_out.data_ptr<at::Half>()),
reinterpret_cast<half*>(grad_values.data_ptr<at::Half>()),
reinterpret_cast<half*>(grad_deformables.data_ptr<at::Half>()),
reinterpret_cast<half*>(grad_weights.data_ptr<at::Half>())
);
// case at::ScalarType::BFloat16:
// return dcn_backward_pipeline_kernel<nv_bfloat16, nv_bfloat162, pipeline_stages, TILE_C, TILE_THREADS><<<launch_blocks, launch_threads_per_block, shm_size*sizeof(nv_bfloat16)>>>(
// H, W, G, K, C,
// reinterpret_cast<nv_bfloat16*>(values.data_ptr<at::BFloat16>()),
// reinterpret_cast<nv_bfloat16*>(deformables.data_ptr<at::BFloat16>()),
// reinterpret_cast<nv_bfloat16*>(weights.data_ptr<at::BFloat16>()),
// reinterpret_cast<nv_bfloat16*>(grad_out.data_ptr<at::BFloat16>()),
// reinterpret_cast<nv_bfloat16*>(grad_values.data_ptr<at::BFloat16>()),
// reinterpret_cast<nv_bfloat16*>(grad_deformables.data_ptr<at::BFloat16>()),
// reinterpret_cast<nv_bfloat16*>(grad_weights.data_ptr<at::BFloat16>())
// );
default:
printf("running error");
}
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("backward_p1_c2_tile16_thread128", &backward<1, 2, 16, 128>, "");
m.def("backward_p2_c2_tile16_thread128", &backward<2, 2, 16, 128>, "");
m.def("backward_p1_c4_tile16_thread128", &backward<1, 4, 16, 128>, "");
m.def("backward_p1_c2_tile16_thread256", &backward<1, 2, 16, 256>, "");
m.def("backward_p2_c2_tile16_thread256", &backward<2, 2, 16, 256>, "");
m.def("backward_p1_c4_tile16_thread256", &backward<1, 4, 16, 256>, "");
m.def("backward_p1_c2_tile16_thread384", &backward<1, 2, 16, 384>, "");
m.def("backward_p2_c2_tile16_thread384", &backward<2, 2, 16, 384>, "");
m.def("backward_p1_c4_tile16_thread384", &backward<1, 4, 16, 384>, "");
m.def("backward_p1_c2_tile16_thread512", &backward<1, 2, 16, 512>, "");
m.def("backward_p2_c2_tile16_thread512", &backward<2, 2, 16, 512>, "");
m.def("backward_p1_c4_tile16_thread512", &backward<1, 4, 16, 512>, "");
m.def("backward_p1_c2_tile16_thread768", &backward<1, 2, 16, 768>, "");
m.def("backward_p2_c2_tile16_thread768", &backward<2, 2, 16, 768>, "");
m.def("backward_p1_c4_tile16_thread768", &backward<1, 4, 16, 768>, "");
// m.def("backward_p1_c2_tile16_thread1024", &backward<1, 2, 16, 1024>, "");
// m.def("backward_p2_c2_tile16_thread1024", &backward<2, 2, 16, 1024>, "");
// m.def("backward_p1_c4_tile16_thread1024", &backward<1, 4, 16, 1024>, "");
m.def("backward_p1_c2_tile32_thread128", &backward<1, 2, 32, 128>, "");
m.def("backward_p1_c2_tile32_thread256", &backward<1, 2, 32, 256>, "");
m.def("backward_p1_c2_tile32_thread384", &backward<1, 2, 32, 384>, "");
m.def("backward_p1_c2_tile32_thread512", &backward<1, 2, 32, 512>, "");
}

View File

@@ -1,289 +0,0 @@
#include <iostream>
#include <string>
#include <fstream>
#include <chrono>
#include <iostream>
#include <vector_types.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_pipeline_primitives.h>
#include <cuda_fp16.hpp>
#include <cuda_fp16.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
template <typename TA, typename TB>
__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = (TA)(*ptr_a + (*ptr_b) * weight);
ptr_a += stride_a;
ptr_b += stride_b;
}
}
template <typename TA, typename TB>
__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = (TA)((*ptr_b) * weight);
ptr_a += stride_a;
ptr_b += stride_b;
}
}
template <typename TA, typename TB>
__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = (TA)((*ptr_b));
ptr_a += stride_a;
ptr_b += stride_b;
}
}
template <typename TA>
__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = 0;
ptr_a += stride;
}
}
// B, H, W, C, BLOCK_DIM must be multiple of C
template <typename math_t, typename scalar_t, int transfer_length, int K, int L, int BLOCK_DIM>
__global__ void dcn_forward_kernel(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
int work_id = threadIdx.x;
int bid = blockIdx.z;
int gid = blockIdx.y;
int G = gridDim.y;
int c_blockid = blockIdx.x;
int work_load = (H*W/blockDim.x);
__shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM
// __shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM
math_t register_bufferA[BLOCK_DIM] = {0};
int base_c = c_blockid*BLOCK_DIM;
int num_transfers = BLOCK_DIM;
#pragma unroll
for(int i=0; i<work_load; i++){
int job_id = work_load*work_id + i;
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
for(int j=0; j<num_transfers; j++){
if((base_c+j) < C){
// __pipeline_memcpy_async((long*)(&math_buffer[job_id]) + j, (long *)ptr_value + offset2 + j, sizeof(long));
math_buffer[job_id][j] = (math_t)*(ptr_value + offset2 +j);
}
}
}
__syncthreads();
work_load = (H*W)/blockDim.x;
int offset = 0;
for(int i=0; i<work_load; i++){
int job_id = (work_id*work_load+i);
int hid = job_id/W;
int wid = job_id%W;
loop_reset<math_t>(register_bufferA, 1, BLOCK_DIM);
// loop_reset<scalar_t>((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM);
#pragma unroll
for(int k=0; k<K; k++){
// read weights to register
offset = bid*K*H*W*G + hid*W*K*G + wid*K*G + gid*K +k;
math_t weight = *(ptr_weights + offset);
// read deformables to register
offset = offset*2;
math_t x = *(ptr_deformables + offset) + wid;
math_t y = *(ptr_deformables + offset + 1) + hid;
int floor_x = x;
int floor_y = y;
int ceil_x = floor_x + 1;
int ceil_y = floor_y + 1;
// reset A buffer and top left
math_t tl_weight = (ceil_x - x)*(ceil_y - y)*weight;
if( (0<= floor_y) and (floor_y < H) and (0<= floor_x) and (floor_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
}
// load top right
math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
}
// load bottom left
math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
}
// load bottom right
math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
}
}
// loop_load<scalar_t, math_t>((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM);
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
#pragma unroll
for(int j=0; j<BLOCK_DIM; j++){
if((base_c+j) < C){
*(ptr_out + offset2 + j) = (scalar_t)register_bufferA[j];
}
}
}
__syncthreads();
}
// B, H, W, C, BLOCK_DIM must be multiple of C
template <typename math_t, typename scalar_t, int transfer_length, int K, int L, int BLOCK_DIM>
__global__ void dcn_forward_kernel_16(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
int work_id = threadIdx.x;
int bid = blockIdx.z;
int gid = blockIdx.y;
int G = gridDim.y;
int c_blockid = blockIdx.x;
int work_load = (H*W/blockDim.x);
__shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM
__shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM
math_t register_bufferA[BLOCK_DIM] = {0};
int base_c = c_blockid*BLOCK_DIM;
int num_transfers = BLOCK_DIM/transfer_length;
#pragma unroll
for(int i=0; i<work_load; i++){
int job_id = work_load*work_id + i;
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c)/transfer_length;
for(int j=0; j<num_transfers; j++){
if((base_c+j*transfer_length) < C){
__pipeline_memcpy_async((long*)(&math_buffer[job_id]) + j, (long *)ptr_value + offset2 + j, sizeof(long));
}
}
}
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
work_load = (H*W)/blockDim.x;
int offset = 0;
for(int i=0; i<work_load; i++){
int job_id = (work_id*work_load+i);
int hid = job_id/W;
int wid = job_id%W;
loop_reset<math_t>(register_bufferA, 1, BLOCK_DIM);
loop_reset<scalar_t>((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM);
#pragma unroll
for(int k=0; k<K; k++){
// read weights to register
offset = bid*K*H*W*G + hid*W*K*G + wid*K*G + gid*K +k;
math_t weight = *(ptr_weights + offset);
// read deformables to register
offset = offset*2;
math_t x = *(ptr_deformables + offset) + wid;
math_t y = *(ptr_deformables + offset + 1) + hid;
int floor_x = x;
int floor_y = y;
int ceil_x = floor_x + 1;
int ceil_y = floor_y + 1;
// reset A buffer and top left
math_t tl_weight = (ceil_x - x)*(ceil_y - y)*weight;
if( (0<= floor_y) and (floor_y < H) and (0<= floor_x) and (floor_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
}
// load top right
math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
}
// load bottom left
math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
}
// load bottom right
math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
}
}
loop_load<scalar_t, math_t>((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM);
}
__syncthreads();
#pragma unroll
for(int i=0; i<work_load; i++){
int job_id = work_load*work_id + i;
// int offset1 = job_id*num_transfers;
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c)/transfer_length;
#pragma unroll
for(int j=0; j<num_transfers; j++){
if((base_c+j*transfer_length) < C){
*((long *)ptr_out + offset2 + j) = *((long *)(&io_buffer[job_id]) +j);
}
}
}
}
template<int L, int C_BLOCK_DIM, int THREADS>
void dcn_forward(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
dim3 launch_threads_per_block(THREADS);
dim3 launch_blocks(NUM_C_BLOCK, G, B);
switch (value.type().scalarType()) {
case at::ScalarType::Half:
return dcn_forward_kernel_16<at::Half, at::Half, 4, 9, L, (C_BLOCK_DIM)><<<launch_blocks, launch_threads_per_block>>>(
H, W, C,
value.data_ptr<at::Half>(),
deformables.data_ptr<at::Half>(),
weights.data_ptr<at::Half>(),
out.data_ptr<at::Half>());
case at::ScalarType::BFloat16:
return dcn_forward_kernel_16<at::BFloat16, at::BFloat16, 4, 9, L, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block>>>(
H, W, C,
value.data_ptr<at::BFloat16>(),
deformables.data_ptr<at::BFloat16>(),
weights.data_ptr<at::BFloat16>(),
out.data_ptr<at::BFloat16>());
case at::ScalarType::Float:
return dcn_forward_kernel<at::Half, float, 2, 9, L, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block>>>(
H, W, C,
value.data_ptr<float>(),
deformables.data_ptr<float>(),
weights.data_ptr<float>(),
out.data_ptr<float>());
default:
printf("running error");
}
}
// PyBind11 bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward");
//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward");
m.def("dcn_forward_l256_c4", &dcn_forward<256, 4, 256>, "CUDA dcn forward");
m.def("dcn_forward_l256_c8", &dcn_forward<256, 8, 256>, "CUDA dcn forward");
m.def("dcn_forward_l256_c16", &dcn_forward<256, 16, 256>, "CUDA dcn forward");
// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward");
m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward");
m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward");
m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward");
// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward");
}

View File

@@ -1,309 +0,0 @@
#include <iostream>
#include <string>
#include <fstream>
#include <chrono>
#include <iostream>
#include <vector_types.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_pipeline_primitives.h>
#include <cuda_fp16.hpp>
#include <cuda_fp16.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
template <typename TA, typename TB>
__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = (TA)(*ptr_a + (*ptr_b) * weight);
ptr_a += stride_a;
ptr_b += stride_b;
}
}
template <typename TA, typename TB>
__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = (TA)((*ptr_b) * weight);
ptr_a += stride_a;
ptr_b += stride_b;
}
}
template <typename TA, typename TB>
__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = (TA)((*ptr_b));
ptr_a += stride_a;
ptr_b += stride_b;
}
}
template <typename TA>
__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = 0;
ptr_a += stride;
}
}
// B, H, W, C, BLOCK_DIM must be multiple of C
template <typename math_t, typename scalar_t, int K, int BLOCK_DIM>
__global__ void dcn_forward_kernel_register(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
int work_id = threadIdx.x;
int bid = blockIdx.z;
int gid = blockIdx.y;
int G = gridDim.y;
int c_blockid = blockIdx.x;
int work_load = (H*W/blockDim.x);
extern __shared__ int shm[];
math_t* math_buffer = reinterpret_cast<math_t*>(shm);
math_t register_bufferA[BLOCK_DIM] = {0};
int base_c = c_blockid*BLOCK_DIM;
#pragma unroll
for(int i=0; i<work_load; i++){
int job_id = work_load*work_id + i;
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
for(int j=0; j<BLOCK_DIM; j++){
if((base_c+j) < C){
math_buffer[job_id*BLOCK_DIM +j] = (math_t)*(ptr_value + offset2 +j);
}
}
}
__syncthreads();
work_load = (H*W)/blockDim.x;
int offset = 0;
for(int i=0; i<work_load; i++){
int job_id = (work_id*work_load+i);
int hid = job_id/W;
int wid = job_id%W;
loop_reset<math_t>(register_bufferA, 1, BLOCK_DIM);
#pragma unroll
for(int k=0; k<K; k++){
// read weights to register
offset = bid*K*H*W*G + hid*W*K*G + wid*K*G + gid*K +k;
math_t weight = *(ptr_weights + offset);
// read deformables to register
offset = offset*2;
math_t x = *(ptr_deformables + offset) + wid;
math_t y = *(ptr_deformables + offset + 1) + hid;
int floor_x = x;
int floor_y = y;
int ceil_x = floor_x + 1;
int ceil_y = floor_y + 1;
// reset A buffer and top left
math_t tl_weight = (ceil_x - x)*(ceil_y - y)*weight;
if( (0<= floor_y) and (floor_y < H) and (0<= floor_x) and (floor_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
}
// load top right
math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
}
// load bottom left
math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
}
// load bottom right
math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
}
}
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
#pragma unroll
for(int j=0; j<BLOCK_DIM; j++){
if((base_c+j) < C){
*(ptr_out + offset2 + j) = (scalar_t)register_bufferA[j];
}
}
}
__syncthreads();
}
// B, H, W, C, BLOCK_DIM must be multiple of C
template <typename math_t, typename scalar_t, int transfer_length, int K, int BLOCK_DIM>
__global__ void dcn_forward_kernel_pipeline(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
int work_id = threadIdx.x;
int bid = blockIdx.z;
int gid = blockIdx.y;
int G = gridDim.y;
int c_blockid = blockIdx.x;
int work_load = (H*W/blockDim.x);
extern __shared__ int shm[];
math_t* math_buffer = reinterpret_cast<math_t*>(shm);
scalar_t* io_buffer = reinterpret_cast<scalar_t*>(shm) + H*W*BLOCK_DIM*sizeof(math_t)/sizeof(scalar_t);
math_t register_bufferA[BLOCK_DIM] = {0};
int base_c = c_blockid*BLOCK_DIM;
int num_transfers = BLOCK_DIM/transfer_length;
#pragma unroll
for(int i=0; i<work_load; i++){
int job_id = work_load*work_id + i;
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c)/transfer_length;
for(int j=0; j<num_transfers; j++){
if((base_c+j*transfer_length) < C){
__pipeline_memcpy_async((long*)(&math_buffer[job_id]) + j, (long *)ptr_value + offset2 + j, sizeof(long));
}
}
}
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
work_load = (H*W)/blockDim.x;
int offset = 0;
for(int i=0; i<work_load; i++){
int job_id = (work_id*work_load+i);
int hid = job_id/W;
int wid = job_id%W;
loop_reset<math_t>(register_bufferA, 1, BLOCK_DIM);
loop_reset<scalar_t>((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM);
#pragma unroll
for(int k=0; k<K; k++){
// read weights to register
offset = bid*K*H*W*G + hid*W*K*G + wid*K*G + gid*K +k;
math_t weight = *(ptr_weights + offset);
// read deformables to register
offset = offset*2;
math_t x = *(ptr_deformables + offset) + wid;
math_t y = *(ptr_deformables + offset + 1) + hid;
int floor_x = x;
int floor_y = y;
int ceil_x = floor_x + 1;
int ceil_y = floor_y + 1;
// reset A buffer and top left
math_t tl_weight = (ceil_x - x)*(ceil_y - y)*weight;
if( (0<= floor_y) and (floor_y < H) and (0<= floor_x) and (floor_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
}
// load top right
math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
}
// load bottom left
math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
}
// load bottom right
math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
}
}
loop_load<scalar_t, math_t>((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM);
}
__syncthreads();
#pragma unroll
for(int i=0; i<work_load; i++){
int job_id = work_load*work_id + i;
// int offset1 = job_id*num_transfers;
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c)/transfer_length;
#pragma unroll
for(int j=0; j<num_transfers; j++){
if((base_c+j*transfer_length) < C){
*((long *)ptr_out + offset2 + j) = *((long *)(&io_buffer[job_id]) +j);
}
}
}
}
template<int C_BLOCK_DIM, int THREADS>
void dcn_forward(const int B, const int G, const int C, const int H, const int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
dim3 launch_threads_per_block(THREADS);
dim3 launch_blocks(NUM_C_BLOCK, G, B);
int shm_size = H*W*C_BLOCK_DIM*sizeof(at::Half);
switch (value.type().scalarType()) {
case at::ScalarType::Half:
return dcn_forward_kernel_register<at::Half, at::Half, 9, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block, shm_size>>>(
H, W, C,
value.data_ptr<at::Half>(),
deformables.data_ptr<at::Half>(),
weights.data_ptr<at::Half>(),
out.data_ptr<at::Half>());
case at::ScalarType::Float:
return dcn_forward_kernel_register<at::Half, float, 9, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block, shm_size>>>(
H, W, C,
value.data_ptr<float>(),
deformables.data_ptr<float>(),
weights.data_ptr<float>(),
out.data_ptr<float>());
default:
printf("running error");
}
}
template<int C_BLOCK_DIM, int THREADS>
void dcn_forward_pipeline(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
dim3 launch_threads_per_block(THREADS);
dim3 launch_blocks(NUM_C_BLOCK, G, B);
int shm_size = 2*H*W*C_BLOCK_DIM*sizeof(at::Half);
switch (value.type().scalarType()) {
case at::ScalarType::Half:
return dcn_forward_kernel_pipeline<at::Half, at::Half, 4, 9, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block, shm_size>>>(
H, W, C,
value.data_ptr<at::Half>(),
deformables.data_ptr<at::Half>(),
weights.data_ptr<at::Half>(),
out.data_ptr<at::Half>());
case at::ScalarType::BFloat16:
return dcn_forward_kernel_pipeline<at::BFloat16, at::BFloat16, 4, 9, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block, shm_size>>>(
H, W, C,
value.data_ptr<at::BFloat16>(),
deformables.data_ptr<at::BFloat16>(),
weights.data_ptr<at::BFloat16>(),
out.data_ptr<at::BFloat16>());
default:
printf("running error");
}
}
// PyBind11 bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward");
//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward");
m.def("dcn_forward_l256_c4", &dcn_forward<4, 256>, "CUDA dcn forward");
m.def("dcn_forward_l256_c8", &dcn_forward<8, 256>, "CUDA dcn forward");
m.def("dcn_forward_l256_c16", &dcn_forward<16, 256>, "CUDA dcn forward");
m.def("dcn_forward_pipeline_l256_c4", &dcn_forward_pipeline<4, 256>, "CUDA dcn forward");
m.def("dcn_forward_pipeline_l256_c8", &dcn_forward_pipeline<8, 256>, "CUDA dcn forward");
m.def("dcn_forward_pipeline_l256_c16", &dcn_forward_pipeline<16, 256>, "CUDA dcn forward");
// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward");
// m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward");
// m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward");
// m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward");
// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward");
}

View File

@@ -1,95 +0,0 @@
import triton
import triton.language as tl
@triton.autotune(
configs=[
# triton.Config({'BLOCK_SIZE': 32,}, num_stages=1, num_warps=1),
triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1),
],
key=['B', 'H', 'W', 'G', 'C', 'K'],
)
@triton.jit
def forward_kernel(
B: tl.constexpr,
H: tl.constexpr, # image_size_h
W: tl.constexpr, # image_size_w
G: tl.constexpr, # num_channels_per_group
C: tl.constexpr, # num_groups
K: tl.constexpr, # kernel size
input_ptr, # input features [B, H, W, G, C]
deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K]
weights_ptr, # weights [B, H, W, G, K]
out_ptr, # out [B, H, W, G, C]
BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
):
pid = tl.program_id(0)
wid = pid % W
hid = pid // W % H
gid = pid // (W * H) % G
bid = pid // (W * H * G)
id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
common_offset = bid*H*W*G + hid*W*G + wid*G + gid
batch_base = bid * H * W * G * C
for block_base in tl.static_range(0, C, BLOCK_SIZE):
buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
block_offset = tl.arange(0, BLOCK_SIZE) + block_base
block_mask = (block_offset < C) & id_mask
for k in tl.static_range(K):
deformable_offset = (common_offset * K + k) * 2
x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
floor_x = x.to(tl.int32)
floor_y = y.to(tl.int32)
ceil_x = floor_x + 1
ceil_y = floor_y + 1
# load top left
tl_weight = (ceil_x - x) * (ceil_y - y)
tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
# load top right
tr_weight = (x - floor_x) * (ceil_y - y)
tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
# load bottom left
bl_weight = (ceil_x - x) * (y - floor_y)
bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
# load bottom right
br_weight = (x - floor_x) * (y - floor_y)
br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
# load dynamic weight and mask
weights_offset = common_offset*K + k
weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
tl_block_input = tl_block_input * tl_weight
# load top right
tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
tr_block_input = tr_block_input * tr_weight
# load bottom left
bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
bl_block_input = bl_block_input * bl_weight
# load bottom right
br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
br_block_input = br_block_input * br_weight
# sampled
sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input
weighted_sampled_input = sampled_input * weight
buffer = buffer + weighted_sampled_input
# store to out_ptr
tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask)

View File

@@ -1,126 +0,0 @@
import time
import dcn_cuda_backward
import dcn_cuda_forward
import math
import torch
from typing import Any
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_fwd, custom_bwd
from .forward import forward_kernel
class DCNFunction(Function):
BP_FUNCS = [
dcn_cuda_backward.backward_p1_c2_tile16_thread128,
dcn_cuda_backward.backward_p1_c4_tile16_thread128,
dcn_cuda_backward.backward_p2_c2_tile16_thread128,
dcn_cuda_backward.backward_p1_c2_tile16_thread256,
dcn_cuda_backward.backward_p1_c4_tile16_thread256,
dcn_cuda_backward.backward_p2_c2_tile16_thread256,
dcn_cuda_backward.backward_p1_c2_tile16_thread384,
dcn_cuda_backward.backward_p1_c4_tile16_thread384,
dcn_cuda_backward.backward_p2_c2_tile16_thread384,
dcn_cuda_backward.backward_p1_c2_tile16_thread512,
dcn_cuda_backward.backward_p1_c4_tile16_thread512,
dcn_cuda_backward.backward_p2_c2_tile16_thread512,
dcn_cuda_backward.backward_p1_c2_tile16_thread768,
dcn_cuda_backward.backward_p1_c4_tile16_thread768,
dcn_cuda_backward.backward_p2_c2_tile16_thread768,
dcn_cuda_backward.backward_p1_c2_tile32_thread128,
dcn_cuda_backward.backward_p1_c2_tile32_thread256,
dcn_cuda_backward.backward_p1_c2_tile32_thread384,
dcn_cuda_backward.backward_p1_c2_tile32_thread512,
]
FW_FUNCS = [
dcn_cuda_forward.dcn_forward_l256_c4,
dcn_cuda_forward.dcn_forward_l256_c8,
dcn_cuda_forward.dcn_forward_l256_c16,
]
BP_TABLES = dict()
FW_TABLES = dict()
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, values, deformables, weights) -> Any:
B, H, W, G, C = values.shape
func = DCNFunction.find_fw_funcs(values, deformables, weights)
out = torch.zeros_like(values)
func(B, G, C, H, W, values, deformables, weights, out)
return out
@staticmethod
def find_fw_funcs(values, deformables, weights):
B, H, W, G, C = values.shape
B, H, W, G, K = weights.shape
hash_value = 10000 * B + 100 * H + W + 1000 * G
if hash_value in DCNFunction.FW_TABLES.keys():
return DCNFunction.FW_TABLES[hash_value]
print("missing")
candicate_func = None
min_t = 999.0
outs = torch.zeros_like(values)
for func in DCNFunction.FW_FUNCS:
t = []
for i in range(100):
torch.cuda.synchronize()
start_t = time.time()
func(B, G, C, H, W, values, deformables, weights, outs)
torch.cuda.synchronize()
t.append(time.time() - start_t)
t = t[-50:]
t = sum(t) / len(t)
if t < min_t:
min_t = t
DCNFunction.FW_TABLES[hash_value] = func
candicate_func = func
assert candicate_func is not None
print(candicate_func)
return candicate_func
@staticmethod
def find_bp_funcs(values, deformables, weights, grad_out):
B, H, W, G, C = values.shape
B, H, W, G, K = weights.shape
hash_value = 10000 * B + 100 * H + W + 1000 * G
if hash_value in DCNFunction.BP_TABLES.keys():
return DCNFunction.BP_TABLES[hash_value]
print("missing")
candicate_func = None
min_t = 999.0
grad_values = torch.zeros_like(values)
grad_deformables = torch.zeros_like(deformables)
grad_weights = torch.zeros_like(weights)
for func in DCNFunction.BP_FUNCS:
t = []
for i in range(100):
torch.cuda.synchronize()
start_t = time.time()
func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights)
torch.cuda.synchronize()
t.append(time.time() - start_t)
t = t[-50:]
t = sum(t) / len(t)
if t < min_t:
min_t = t
DCNFunction.BP_TABLES[hash_value] = func
candicate_func = func
assert candicate_func is not None
print(candicate_func)
return candicate_func
@staticmethod
@once_differentiable
@custom_bwd
def backward(ctx: Any, *grad_outputs: Any) -> Any:
grad_out = grad_outputs[0]
values, deformables, weights = ctx.saved_tensors
B, H, W, G, C = values.shape
B, H, W, G, K = weights.shape
func = DCNFunction.find_bp_funcs(values, deformables, weights, grad_out)
grad_values = torch.zeros_like(values)
grad_deformables = torch.zeros_like(deformables)
grad_weights = torch.zeros_like(weights)
func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights)
return grad_values, grad_deformables, grad_weights

View File

@@ -1,59 +0,0 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='dcn_cuda_forward',
ext_modules=[
CUDAExtension('dcn_cuda_forward', ['./forward.cu',],
extra_compile_args={'cxx': [], 'nvcc': [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
"--use_fast_math",
"-O3",
]}
),
],
cmdclass={
'build_ext': BuildExtension
}
)
setup(
name='dcn_cuda_backward',
ext_modules=[
CUDAExtension('dcn_cuda_backward', ['./backward.cu',],
extra_compile_args={'cxx': [], 'nvcc': [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
"--use_fast_math",
"-O3",
]}
),
],
cmdclass={
'build_ext': BuildExtension
}
)
# setup(
# name='mycuda',
# ext_modules=[
# CUDAExtension('mycuda', ['./backward.cu',],
# extra_compile_args={'cxx': [], 'nvcc': [
# "-O3",
# "-DCUDA_HAS_FP16=1",
# "-D__CUDA_NO_HALF_OPERATORS__",
# "-D__CUDA_NO_HALF_CONVERSIONS__",
# "-D__CUDA_NO_HALF2_OPERATORS__",
# ]}
# ),
# ],
# cmdclass={
# 'build_ext': BuildExtension
# }
# )

View File

@@ -1,124 +0,0 @@
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=1),
triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
],
key=['B', 'H', 'W', 'G', 'C', 'K'],
)
@triton.jit
def backward_kernel(
B: tl.constexpr,
H: tl.constexpr, # image_size_h
W: tl.constexpr, # image_size_w
G: tl.constexpr, # num_groups
C: tl.constexpr, # num_channels_per_group
K: tl.constexpr, # kernel size
input_ptr, # input features [B, H, W, G, C]
deformable_ptr, # deformable offsets [B, H, W, G, K, 2]
weights_ptr, # weights [B, H, W, G, K]
grad_ptr, # out [B, H, W, G, C]
grad_input_ptr, # input features [B, H, W, G, C]
grad_deformable_ptr, # deformable offsets [B, H, W, G, K, 2]
grad_weights_ptr, # weights [B, H, W, G, K]
BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
):
pid = tl.program_id(0)
wid = pid % W
hid = pid // W % H
gid = pid // (W * H) % G
bid = pid // (W * H * G)
id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
common_offset = bid*H*W*G + hid*W*G + wid*G + gid
batch_base = bid * H * W * G * C
for k in tl.static_range(K):
# load dynamic weight and mask
weights_offset = common_offset*K + k
weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
dodx = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
dody = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
dodw = tl.zeros((1,), dtype=grad_weights_ptr.type.element_ty)
deformable_offset = (common_offset * K + k)*2
x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
for block_base in tl.static_range(0, C, BLOCK_SIZE):
block_offset = tl.arange(0, BLOCK_SIZE) + block_base
block_mask = (block_offset < C) & id_mask
grad = tl.load(grad_ptr+common_offset*C + block_offset, mask=block_mask, other=0.0)
dods = weight*grad
floor_x = x.to(tl.int32)
floor_y = y.to(tl.int32)
ceil_x = floor_x + 1
ceil_y = floor_y + 1
# load top left
tl_weight = (ceil_x - x) * (ceil_y - y)
tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) + block_offset
tl_block_mask = ((floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H))
tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0)
tl_block_input_dot_grad = tl.sum(tl_block_input*grad, axis=0)
dodx = dodx + -1 * tl_block_input_dot_grad * (ceil_y - y)
dody = dody + -1 * tl_block_input_dot_grad * (ceil_x - x)
dodw = dodw + tl_block_input_dot_grad * tl_weight
dodtl = dods * tl_weight
tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dodtl)
# load top right
tr_weight = (x - floor_x) * (ceil_y - y)
tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
tr_block_mask = ((floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0))
tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0)
tr_block_input_dot_grad = tl.sum(tr_block_input*grad, axis=0)
dodx = dodx + 1 * tr_block_input_dot_grad * (ceil_y - y)
dody = dody + -1 * tr_block_input_dot_grad * (x - floor_x)
dodw = dodw + tr_block_input_dot_grad*tr_weight
dodtr = dods * tr_weight
tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dodtr)
# load bottom left
bl_weight = (ceil_x - x) * (y - floor_y)
bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) + block_offset
bl_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0))
bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0)
bl_block_input_dot_grad = tl.sum(bl_block_input*grad, axis=0)
dodx = dodx + -1 * bl_block_input_dot_grad * (y - floor_y)
dody = dody + 1 * bl_block_input_dot_grad * (ceil_x - x)
dodw = dodw + bl_block_input_dot_grad*bl_weight
dodbl = dods * bl_weight
tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dodbl)
# load bottom right
br_weight = (x - floor_x) * (y - floor_y)
br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
br_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0))
br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0)
br_block_input_dot_grad = tl.sum(br_block_input*grad, axis=0)*br_block_mask
dodx = dodx + 1 * br_block_input_dot_grad * (y - floor_y)
dody = dody + 1 * br_block_input_dot_grad * (x - floor_x)
dodw = dodw + br_block_input_dot_grad*br_weight
dodbr = dods * br_weight
tl.atomic_add(grad_input_ptr + br_block_offset, mask=br_block_mask & block_mask, val=dodbr)
dodx = dodx * weight
dody = dody * weight
tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask)
tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask)
tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, mask=id_mask)

View File

@@ -1,94 +0,0 @@
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
# triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1),
],
key=['B', 'H', 'W', 'G', 'C', 'K'],
)
@triton.jit
def forward_kernel(
B: tl.constexpr,
H: tl.constexpr, # image_size_h
W: tl.constexpr, # image_size_w
G: tl.constexpr, # num_channels_per_group
C: tl.constexpr, # num_groups
K: tl.constexpr, # kernel size
input_ptr, # input features [B, H, W, G, C]
deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K]
weights_ptr, # weights [B, H, W, G, K]
out_ptr, # out [B, H, W, G, C]
BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
):
pid = tl.program_id(0)
wid = pid % W
hid = pid // W % H
gid = pid // (W * H) % G
bid = pid // (W * H * G)
id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
common_offset = bid*H*W*G + hid*W*G + wid*G + gid
batch_base = bid * H * W * G * C
for block_base in tl.static_range(0, C, BLOCK_SIZE):
buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
block_offset = tl.arange(0, BLOCK_SIZE) + block_base
block_mask = (block_offset < C) & id_mask
for k in tl.static_range(K):
deformable_offset = (common_offset * K + k) * 2
x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
floor_x = x.to(tl.int32)
floor_y = y.to(tl.int32)
ceil_x = floor_x + 1
ceil_y = floor_y + 1
# load top left
tl_weight = (ceil_x - x) * (ceil_y - y)
tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
# load top right
tr_weight = (x - floor_x) * (ceil_y - y)
tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
# load bottom left
bl_weight = (ceil_x - x) * (y - floor_y)
bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
# load bottom right
br_weight = (x - floor_x) * (y - floor_y)
br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
# load dynamic weight and mask
weights_offset = common_offset*K + k
weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
tl_block_input = tl_block_input * tl_weight
# load top right
tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
tr_block_input = tr_block_input * tr_weight
# load bottom left
bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
bl_block_input = bl_block_input * bl_weight
# load bottom right
br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
br_block_input = br_block_input * br_weight
# sampled
sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input
weighted_sampled_input = sampled_input * weight
buffer = buffer + weighted_sampled_input
# store to out_ptr
tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask)

View File

@@ -1,48 +0,0 @@
import torch
import triton
from typing import Any
from torch.autograd import Function
from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd
from .forward import forward_kernel
from .backward import backward_kernel
class DCNFunction(Function):
@staticmethod
@custom_fwd
def forward(ctx: Any, inputs, deformables, weights) -> Any:
B, H, W, G, C = inputs.shape
_, _, _, _, K, _ = deformables.shape
out = torch.zeros_like(inputs)
grid = lambda META: (B * H * W * G,)
forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out)
ctx.save_for_backward(inputs, deformables, weights)
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, *grad_outputs: Any) -> Any:
grad_output = grad_outputs[0].contiguous()
inputs, deformables, weights = ctx.saved_tensors
B, H, W, G, C = inputs.shape
_, _, _, _, K, _ = deformables.shape
grad_inputs = torch.zeros_like(inputs)
grad_deformables = torch.zeros_like(deformables)
grad_weights = torch.zeros_like(weights)
grid = lambda META: (B * H * W * G,)
backward_kernel[grid](
B, H, W, G, C, K,
inputs,
deformables,
weights,
grad_output,
grad_inputs,
grad_deformables,
grad_weights,
)
return (grad_inputs, grad_deformables, grad_weights)