decoupled DiT src code

This commit is contained in:
wangshuai6
2025-04-09 11:07:29 +08:00
parent 06499f1caa
commit d1b6da1f0a
44 changed files with 14 additions and 8633 deletions

View File

@@ -1,122 +0,0 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
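# Candidate loss weights w(alpha, sigma) for the flow-matching term; `constant`
# is the default, and the snr variants floor or cap alpha/sigma at `threshold`.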
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class Discriminator(nn.Module):
def __init__(self, in_channels, hidden_size):
super().__init__()
self.head = nn.Sequential(
nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 4x4 -> 2x2
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
)
def forward(self, feature):
B, L, C = feature.shape
H = W = int(math.sqrt(L))
feature = feature.permute(0, 2, 1)
feature = feature.reshape(B, C, H, W)  # reshape, not view: permute leaves the tensor non-contiguous
out = self.head(feature).sigmoid().clamp(0.01, 0.99)
return out
class AdvTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
adv_weight=1.0,
adv_encoder_layer=4,
adv_in_channels=768,
adv_hidden_size=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.adv_weight = adv_weight
self.adv_encoder_layer = adv_encoder_layer
self.dis_head = Discriminator(
in_channels=adv_in_channels,
hidden_size=adv_hidden_size,
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
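# interpolant x_t = alpha(t) x + sigma(t) eps; its time derivative
# v_t = dalpha(t) x + dsigma(t) eps is the flow-matching regression target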
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
adv_feature = []
def forward_hook(net, input, output):
adv_feature.append(output)
handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook)
out, _ = net(x_t, t, y)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
pred_x0 = (x_t + out * sigma)
pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma
real_feature = adv_feature.pop()
net(pred_xt, t, y, classify_layer=self.adv_encoder_layer)
fake_feature = adv_feature.pop()
handle.remove()
real_score_gan = self.dis_head(real_feature.detach())
fake_score_gan = self.dis_head(fake_feature.detach())
fake_score = self.dis_head(fake_feature)
loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
acc_real = (real_score_gan > 0.5).float()
acc_fake = (fake_score_gan < 0.5).float()
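# Adversarial (generator) term: fake_score and fake_score_gan share the same
# value, so adding log(fake_score_gan) cancels this term's gradient into
# dis_head (which is trained only by loss_gan) while leaving its gradient
# into the denoiser intact.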
loss_adv = -torch.log(fake_score)
loss_adv_hack = torch.log(fake_score_gan)
out = dict(
adv_loss=loss_adv.mean(),
gan_loss=loss_gan.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
acc_real=acc_real.mean(),
acc_fake=acc_fake.mean(),
)
return out

View File

@@ -1,127 +0,0 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class Discriminator(nn.Module):
def __init__(self, in_channels, hidden_size):
super().__init__()
self.head = nn.Sequential(
nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 4x4 -> 2x2
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
)
def forward(self, feature):
B, L, C = feature.shape
H = W = int(math.sqrt(L))
feature = feature.permute(0, 2, 1)
feature = feature.reshape(B, C, H, W)  # reshape, not view: permute leaves the tensor non-contiguous
out = self.head(feature).sigmoid().clamp(0.01, 0.99)
return out
class AdvTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
adv_weight=1.0,
lpips_weight=1.0,
adv_encoder_layer=4,
adv_in_channels=768,
adv_hidden_size=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.adv_weight = adv_weight
self.lpips_weight = lpips_weight
self.adv_encoder_layer = adv_encoder_layer
self.dis_head = Discriminator(
in_channels=adv_in_channels,
hidden_size=adv_hidden_size,
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
pred_x0 = (x_t + out * sigma)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
with torch.no_grad():
_, real_features = net(x, clean_t, y, classify_layer=self.adv_encoder_layer)
_, fake_features = net(pred_x0, clean_t, y, classify_layer=self.adv_encoder_layer)
real_score_gan = self.dis_head(real_features[-1].detach())
fake_score_gan = self.dis_head(fake_features[-1].detach())
fake_score = self.dis_head(fake_features[-1])
loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
acc_real = (real_score_gan > 0.5).float()
acc_fake = (fake_score_gan < 0.5).float()
loss_adv = -torch.log(fake_score)
loss_adv_hack = torch.log(fake_score_gan)
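# self-perceptual (LPIPS-style) loss using the network's own encoder features:
# per-token squared distance on unit-normalized features, summed over layers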
lpips_loss = []
for r, f in zip(real_features, fake_features):
r = torch.nn.functional.normalize(r, dim=-1)
f = torch.nn.functional.normalize(f, dim=-1)
lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
lpips_loss = sum(lpips_loss)
out = dict(
adv_loss=loss_adv.mean(),
gan_loss=loss_gan.mean(),
lpips_loss=lpips_loss.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean() + self.lpips_weight*lpips_loss.mean(),
acc_real=acc_real.mean(),
acc_fake=acc_fake.mean(),
)
return out

View File

@@ -1,159 +0,0 @@
import random
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
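# forward: ImageNet-normalize the input (assumed to be in [0, 1]), resize so a
# 256-pixel side maps to 224, then resample (and cache) the position embedding
# to match the resulting patch grid before extracting patch tokens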
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class MaskREPATrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
encoder_weight_path=None,
mask_groups=4,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.mask_groups = mask_groups
self.encoder = DINOv2(encoder_weight_path)
no_grad(self.encoder)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
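# fetch_mask builds a block-diagonal attention mask: a random permutation of
# the token indices is split into `groups` chunks, and attention is allowed
# only between tokens within the same chunk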
def fetch_mask(self, length=256, groups=4, device=torch.device('cuda')):
mask = torch.zeros(1, length, length, device=device, dtype=torch.bool)
random_seq = torch.randperm(length, device=device)
for i in range(groups):
group_start = (length+groups-1)//groups*i
group_end = (length+groups-1)//groups*(i+1)
group_random_seq = random_seq[group_start:group_end]
y, x = torch.meshgrid(group_random_seq, group_random_seq, indexing='ij')
mask[:, y, x] = True
return mask
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
src_feature = []
def forward_hook(net, input, output):
src_feature.append(output)
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
mask_groups = random.randint(1, self.mask_groups)
mask = self.fetch_mask(length=256, groups=mask_groups, device=x.device)
out, _ = net(x_t, t, y, mask=mask)
src_feature = self.proj(src_feature[0])
handle.remove()
with torch.no_grad():
dst_feature = self.encoder(raw_images)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = weight*(out - v_t)**2
out = dict(
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
# persist only the projector; the frozen DINOv2 encoder is not saved
return self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)

View File

@@ -1,179 +0,0 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t, mul=1000):
t_freq = self.timestep_embedding(t * mul, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
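# GroupNorm whose affine parameters are predicted from the timestep embedding
# (FiLM-style conditioning), making the discriminator noise-level aware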
class BatchNormWithTimeEmbedding(nn.Module):
def __init__(self, num_features):
super().__init__()
# self.bn = nn.BatchNorm2d(num_features, affine=False)
self.bn = nn.GroupNorm(16, num_features, affine=False)
# self.bn = nn.SyncBatchNorm(num_features, affine=False)
self.embedder = TimestepEmbedder(num_features * 2)
# nn.init.zeros_(self.embedder.mlp[-1].weight)
nn.init.trunc_normal_(self.embedder.mlp[-1].weight, std=0.01)
nn.init.zeros_(self.embedder.mlp[-1].bias)
def forward(self, x, t):
embed = self.embedder(t)
embed = embed[:, :, None, None]
gamma, beta = embed.chunk(2, dim=1)
gamma = 1.0 + gamma
normed = self.bn(x)
out = normed * gamma + beta
return out
class DisBlock(nn.Module):
def __init__(self, in_channels, hidden_size):
super().__init__()
self.conv = nn.Conv2d(
kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=4, padding=0
)
self.norm = BatchNormWithTimeEmbedding(hidden_size)
self.act = nn.SiLU()
def forward(self, x, t):
x = self.conv(x)
x = self.norm(x, t)
x = self.act(x)
return x
class Discriminator(nn.Module):
def __init__(self, num_blocks, in_channels, hidden_size):
super().__init__()
self.blocks = nn.ModuleList()
for i in range(num_blocks):
self.blocks.append(
DisBlock(
in_channels=in_channels,
hidden_size=hidden_size,
)
)
in_channels = hidden_size
self.classifier = nn.Conv2d(
kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0  # 1x1 conv: padding would only append bias-valued border logits
)
def forward(self, feature, t):
B, C, H, W = feature.shape
for block in self.blocks:
feature = block(feature, t)
out = self.classifier(feature).view(B, -1)
out = out.sigmoid().clamp(0.01, 0.99)
return out
class AdvTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
adv_weight=1.0,
adv_blocks=3,
adv_in_channels=3,
adv_hidden_size=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.adv_weight = adv_weight
self.discriminator = Discriminator(
num_blocks=adv_blocks,
in_channels=adv_in_channels*2,
hidden_size=adv_hidden_size,
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
pred_x0 = x_t + sigma * out
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
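# conditional discriminator input: concatenate the noisy x_t with either the
# real x or the predicted x0 along the channel dimension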
real_feature = torch.cat([x_t, x], dim=1)
fake_feature = torch.cat([x_t, pred_x0], dim=1)
real_score_gan = self.discriminator(real_feature.detach(), t)
fake_score_gan = self.discriminator(fake_feature.detach(), t)
fake_score = self.discriminator(fake_feature, t)
loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
acc_real = (real_score_gan > 0.5).float()
acc_fake = (fake_score_gan < 0.5).float()
loss_adv = -torch.log(fake_score)
loss_adv_hack = torch.log(fake_score_gan)
out = dict(
adv_loss=loss_adv.mean(),
gan_loss=loss_gan.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
acc_real=acc_real.mean(),
acc_fake=acc_fake.mean(),
)
return out

View File

@@ -1,154 +0,0 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class REPAJiTTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
jit_deltas=0.01,
encoder_weight_path=None,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.jit_deltas = jit_deltas
self.encoder = DINOv2(encoder_weight_path)
no_grad(self.encoder)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
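# jittered second timestep t2 close to t: the encoder state s computed at t
# is reused at t2, presumably to make state sharing robust to small time
# offsets during fast sampling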
t2 = base_t + (torch.rand_like(base_t)-0.5) * self.jit_deltas
t2 = torch.clip(t2, 0, 1)
alpha = self.scheduler.alpha(t2)
dalpha = self.scheduler.dalpha(t2)
sigma = self.scheduler.sigma(t2)
dsigma = self.scheduler.dsigma(t2)
x_t2 = alpha * x + noise * sigma
v_t2 = dalpha * x + dsigma * noise
src_feature = []
def forward_hook(net, input, output):
src_feature.append(output)
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
_, s = net(x_t, t, y, only_s=True)
out, _ = net(x_t2, t2, y, s)
src_feature = self.proj(src_feature[0])
handle.remove()
with torch.no_grad():
dst_feature = self.encoder(raw_images)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = weight*(out - v_t2)**2
out = dict(
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
return self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)

View File

@@ -1,90 +0,0 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class SelfConsistentTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
lpips_weight=1.0,
lpips_encoder_layer=4,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.lpips_encoder_layer = lpips_encoder_layer
self.lpips_weight = lpips_weight
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
real_features = []
def forward_hook(net, input, output):
real_features.append(output)
handles = []
for i in range(self.lpips_encoder_layer):
handle = net.encoder.blocks[i].register_forward_hook(forward_hook)
handles.append(handle)
out, _ = net(x_t, t, y)
for handle in handles:
handle.remove()
pred_x0 = (x_t + out * sigma)
pred_xt = alpha * pred_x0 + noise * sigma
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
_, fake_features = net(pred_xt, t, y, classify_layer=self.lpips_encoder_layer)
lpips_loss = []
for r, f in zip(real_features, fake_features):
r = torch.nn.functional.normalize(r, dim=-1)
f = torch.nn.functional.normalize(f, dim=-1)
lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
lpips_loss = sum(lpips_loss)
out = dict(
lpips_loss=lpips_loss.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + self.lpips_weight*lpips_loss.mean(),
)
return out

View File

@@ -1,81 +0,0 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class SelfLPIPSTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
lpips_weight=1.0,
lpips_encoder_layer=4,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.lpips_encoder_layer = lpips_encoder_layer
self.lpips_weight = lpips_weight
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
pred_x0 = (x_t + out * sigma)
pred_xt = alpha * pred_x0 + noise * sigma
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
with torch.no_grad():
_, real_features = net(x, clean_t, y, classify_layer=self.lpips_encoder_layer)
_, fake_features = net(pred_x0, clean_t, y, classify_layer=self.lpips_encoder_layer)
lpips_loss = []
for r, f in zip(real_features, fake_features):
r = torch.nn.functional.normalize(r, dim=-1)
f = torch.nn.functional.normalize(f, dim=-1)
lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
lpips_loss = sum(lpips_loss)
out = dict(
lpips_loss=lpips_loss.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + self.lpips_weight*lpips_loss.mean(),
)
return out

View File

@@ -1,78 +0,0 @@
import torch
from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *
from typing import Callable
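# timestep re-spacing: shift > 1 biases the schedule toward small t (the
# noisy end), as in resolution-dependent timeshift for flow-based models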
def shift_respace_fn(t, shift=3.0):
return t / (t + (1 - t) * shift)
def ode_step_fn(x, v, dt, s, w):
return x + v * dt
import logging
logger = logging.getLogger(__name__)
class CMSampler(BaseSampler):
def __init__(
self,
w_scheduler: BaseScheduler = None,
timeshift=1.0,
guidance_interval_min: float = 0.0,
guidance_interval_max: float = 1.0,
state_refresh_rate=1,
last_step=None,
step_fn=None,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.last_step = last_step
self.timeshift = timeshift
self.state_refresh_rate = state_refresh_rate
self.guidance_interval_min = guidance_interval_min
self.guidance_interval_max = guidance_interval_max
if self.last_step is None or self.num_steps == 1:
self.last_step = 1.0 / self.num_steps
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
self.timesteps = shift_respace_fn(timesteps, self.timeshift)
assert self.last_step > 0.0
assert self.scheduler is not None
def _impl_sampling(self, net, noise, condition, uncondition):
"""
sampling process of Euler sampler
-
"""
batch_size = noise.shape[0]
steps = self.timesteps.to(noise.device)
cfg_condition = torch.cat([uncondition, condition], dim=0)
x = noise
state = None
for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
cfg_t = t_cur.repeat(batch_size*2)
cfg_x = torch.cat([x, x], dim=0)
if i % self.state_refresh_rate == 0:
state = None
out, state = net(cfg_x, cfg_t, cfg_condition, state)
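# apply classifier-free guidance only inside the configured interval;
# outside it, fall back to the unguided combination (scale 1.0)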
if t_cur > self.guidance_interval_min and t_cur < self.guidance_interval_max:
out = self.guidance_fn(out, self.guidance)
else:
out = self.guidance_fn(out, 1.0)
v = out
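# consistency-style update: jump to the predicted clean sample along v,
# then re-noise it to the next timestep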
x0 = x + v * (1-t_cur)
alpha_next = self.scheduler.alpha(t_next)
sigma_next = self.scheduler.sigma(t_next)
x = alpha_next * x0 + sigma_next * torch.randn_like(x)
return x

View File

@@ -1,149 +0,0 @@
import torch
from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *
from typing import Callable
def shift_respace_fn(t, shift=3.0):
return t / (t + (1 - t) * shift)
def ode_step_fn(x, v, dt, s, w):
return x + v * dt
import logging
logger = logging.getLogger(__name__)
class EulerSampler(BaseSampler):
def __init__(
self,
w_scheduler: BaseScheduler = None,
timeshift=1.0,
guidance_interval_min: float = 0.0,
guidance_interval_max: float = 1.0,
state_refresh_rate=1,
step_fn: Callable = ode_step_fn,
last_step=None,
last_step_fn: Callable = ode_step_fn,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.step_fn = step_fn
self.last_step = last_step
self.last_step_fn = last_step_fn
self.w_scheduler = w_scheduler
self.timeshift = timeshift
self.state_refresh_rate = state_refresh_rate
self.guidance_interval_min = guidance_interval_min
self.guidance_interval_max = guidance_interval_max
if self.last_step is None or self.num_steps == 1:
self.last_step = 1.0 / self.num_steps
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
self.timesteps = shift_respace_fn(timesteps, self.timeshift)
assert self.last_step > 0.0
assert self.scheduler is not None
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
if self.w_scheduler is not None:
if self.step_fn == ode_step_fn:
logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
# init recompute
self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
self.recompute_timesteps = list(range(self.num_steps))
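# sharing_dp picks the steps at which the shared state is recomputed: it runs
# a small template batch, measures cosine drift between the pooled states of
# every pair of steps, accumulates the drift, and solves the resulting 1-D
# segmentation by dynamic programming (P stores the traceback pointers).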
def sharing_dp(self, net, noise, condition, uncondition):
_, C, H, W = noise.shape
B = 8
template_noise = torch.randn((B, C, H, W), generator=torch.Generator("cuda").manual_seed(0), device=noise.device)
template_condition = torch.randint(0, 1000, (B,), generator=torch.Generator("cuda").manual_seed(0), device=condition.device)
template_uncondition = torch.full((B, ), 1000, device=condition.device)
_, state_list = self._impl_sampling(net, template_noise, template_condition, template_uncondition)
states = torch.stack(state_list)
N, B, L, C = states.shape
states = states.view(N, B*L, C )
states = states.permute(1, 0, 2)
states = torch.nn.functional.normalize(states, dim=-1)
with torch.autocast(device_type="cuda", dtype=torch.float64):
sim = torch.bmm(states, states.transpose(1, 2))
sim = torch.mean(sim, dim=0).cpu()
error_map = (1-sim).tolist()
# init cum-error
for i in range(1, self.num_steps):
for j in range(0, i):
error_map[i][j] = error_map[i-1][j] + error_map[i][j]
# init dp and force 0 start
C = [[0.0, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
P = [[-1, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
for i in range(1, self.num_steps+1):
C[1][i] = error_map[i - 1][0]
P[1][i] = 0
# dp state
for step in range(2, self.num_recompute_timesteps+1):
for i in range(step, self.num_steps+1):
min_value = float("inf")
min_index = -1
for j in range(step-1, i):
value = C[step-1][j] + error_map[i-1][j]
if value < min_value:
min_value = value
min_index = j
C[step][i] = min_value
P[step][i] = min_index
# trace back
timesteps = [self.num_steps,]
for i in range(self.num_recompute_timesteps, 0, -1):
idx = timesteps[-1]
timesteps.append(P[i][idx])
timesteps.reverse()
print("recompute timesteps solved by DP: ", timesteps)
return timesteps[:-1]
def _impl_sampling(self, net, noise, condition, uncondition):
"""
sampling process of Euler sampler
-
"""
batch_size = noise.shape[0]
steps = self.timesteps.to(noise.device)
cfg_condition = torch.cat([uncondition, condition], dim=0)
x = noise
state = None
pooled_state_list = []
for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
dt = t_next - t_cur
t_cur = t_cur.repeat(batch_size)
cfg_x = torch.cat([x, x], dim=0)
cfg_t = t_cur.repeat(2)
if i in self.recompute_timesteps:
state = None
out, state = net(cfg_x, cfg_t, cfg_condition, state)
if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
out = self.guidance_fn(out, self.guidance)
else:
out = self.guidance_fn(out, 1.0)
v = out
if i < self.num_steps - 1:
x = self.step_fn(x, v, dt, s=0.0, w=0.0)
else:
x = self.last_step_fn(x, v, dt, s=0.0, w=0.0)
pooled_state_list.append(state)
return x, pooled_state_list
def __call__(self, net, noise, condition, uncondition):
if len(self.recompute_timesteps) != self.num_recompute_timesteps:
self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
return denoised

View File

@@ -1,122 +0,0 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class Discriminator(nn.Module):
def __init__(self, in_channels, hidden_size):
super().__init__()
self.head = nn.Sequential(
nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 4x4 -> 2x2
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
)
def forward(self, feature):
B, L, C = feature.shape
H = W = int(math.sqrt(L))
feature = feature.permute(0, 2, 1)
feature = feature.reshape(B, C, H, W)  # reshape, not view: permute leaves the tensor non-contiguous
out = self.head(feature).sigmoid().clamp(0.01, 0.99)
return out
class AdvTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
adv_weight=1.0,
adv_encoder_layer=4,
adv_in_channels=768,
adv_hidden_size=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.adv_weight = adv_weight
self.adv_encoder_layer = adv_encoder_layer
self.dis_head = Discriminator(
in_channels=adv_in_channels,
hidden_size=adv_hidden_size,
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
adv_feature = []
def forward_hook(net, input, output):
adv_feature.append(output)
handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook)
out, _ = net(x_t, t, y)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
pred_x0 = (x_t + out * sigma)
pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma
real_feature = adv_feature.pop()
net(pred_xt, t, y, classify_layer=self.adv_encoder_layer)
fake_feature = adv_feature.pop()
handle.remove()
real_score_gan = self.dis_head(real_feature.detach())
fake_score_gan = self.dis_head(fake_feature.detach())
fake_score = self.dis_head(fake_feature)
loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
acc_real = (real_score_gan > 0.5).float()
acc_fake = (fake_score_gan < 0.5).float()
loss_adv = -torch.log(fake_score)
loss_adv_hack = torch.log(fake_score_gan)
out = dict(
adv_loss=loss_adv.mean(),
gan_loss=loss_gan.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
acc_real=acc_real.mean(),
acc_fake=acc_fake.mean(),
)
return out

View File

@@ -1,141 +0,0 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bilinear')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class DistillDINOTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
encoder_weight_path=None,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.encoder = DINOv2(encoder_weight_path)
self.proj_encoder_dim = proj_encoder_dim
no_grad(self.encoder)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
sigma = self.scheduler.sigma(t)
x_t = alpha * x + noise * sigma
_, s = net(x_t, t, y)
src_feature = self.proj(s)
with torch.no_grad():
dst_feature = self.encoder(raw_images)
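# token-count mismatch: reshape the DINO tokens to their 2-D grid and
# bilinearly resample them to the denoiser's token resolution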
if dst_feature.shape[1] != src_feature.shape[1]:
dst_length = dst_feature.shape[1]
rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
dst_height = (dst_length)**0.5 * (height/width)**0.5
dst_width = (dst_length)**0.5 * (width/height)**0.5
dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
dst_feature = dst_feature.permute(0, 3, 1, 2)
dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
dst_feature = dst_feature.permute(0, 2, 3, 1)
dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
out = dict(
cos_loss=cos_loss.mean(),
loss=cos_loss.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
return self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)

View File

@@ -1,71 +0,0 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class LPIPSTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
lpips_weight=1.0,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.lpips_weight = lpips_weight
self.lpips = _NoTrainLpips(net="vgg")
self.lpips = self.lpips.to(torch.bfloat16)
# self.lpips = torch.compile(self.lpips)
no_grad(self.lpips)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
pred_x0 = (x_t + out*sigma)
target_x0 = x
# scale by 0.5 so the inputs better match the range LPIPS expects ("lpips std" fix)
lpips = self.lpips(pred_x0*0.5, target_x0*0.5)
out = dict(
lpips_loss=lpips.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + lpips.mean()*self.lpips_weight,
)
return out
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
return

View File

@@ -1,74 +0,0 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class LPIPSTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
lpips_weight=1.0,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.lpips_weight = lpips_weight
self.lpips = _NoTrainLpips(net="vgg")
self.lpips = self.lpips.to(torch.bfloat16)
# self.lpips = torch.compile(self.lpips)
no_grad(self.lpips)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
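# hand-tuned schedules: the flow-matching weight t(1-t)^2/0.25 is bell-shaped
# in t, while the LPIPS weight grows linearly toward the clean end (t = 1)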
fm_weight = t*(1-t)**2/0.25
lpips_weight = t
loss = (out - v_t)**2 * fm_weight[:, None, None, None]
pred_x0 = (x_t + out*sigma)
target_x0 = x
# scale by 0.5 so the inputs better match the range LPIPS expects ("lpips std" fix)
lpips = self.lpips(pred_x0*0.5, target_x0*0.5)*lpips_weight[:, None, None, None]
out = dict(
lpips_loss=lpips.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + lpips.mean()*self.lpips_weight,
)
return out
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
return

View File

@@ -1,170 +0,0 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class REPALPIPSTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
lpips_weight=1.0,
encoder_weight_path=None,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.encoder = DINOv2(encoder_weight_path)
self.proj_encoder_dim = proj_encoder_dim
no_grad(self.encoder)
self.lpips_weight = lpips_weight
self.lpips = _NoTrainLpips(net="vgg")
self.lpips = self.lpips.to(torch.bfloat16)
no_grad(self.lpips)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
src_feature = []
def forward_hook(net, input, output):
src_feature.append(output)
if getattr(net, "blocks", None) is not None:
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
else:
handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
out, _ = net(x_t, t, y)
src_feature = self.proj(src_feature[0])
handle.remove()
with torch.no_grad():
dst_feature = self.encoder(raw_images)
if dst_feature.shape[1] != src_feature.shape[1]:
dst_length = dst_feature.shape[1]
rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
dst_height = (dst_length)**0.5 * (height/width)**0.5
dst_width = (dst_length)**0.5 * (width/height)**0.5
dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
dst_feature = dst_feature.permute(0, 3, 1, 2)
dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
dst_feature = dst_feature.permute(0, 2, 3, 1)
dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = weight*(out - v_t)**2
pred_x0 = (x_t + out * sigma)
target_x0 = x
# scale by 0.5 so the inputs better match the range LPIPS expects ("lpips std" fix)
lpips = self.lpips(pred_x0 * 0.5, target_x0 * 0.5)
out = dict(
lpips_loss=lpips.mean(),
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + self.lpips_weight*lpips.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
return self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)