decoupled dit src code

This commit is contained in:
wangshuai6
2025-04-09 11:07:29 +08:00
parent 06499f1caa
commit d1b6da1f0a
44 changed files with 14 additions and 8633 deletions

View File

@@ -1,68 +0,0 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class PyramidTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
output_pyramid = []
def feature_hook(module, input, output):
output_pyramid.extend(output)
handle = net.decoder.register_forward_hook(feature_hook)
net(x_t, t, y)
handle.remove()
loss = 0.0
out_dict = dict()
cur_v_t = v_t
for i in range(len(output_pyramid)):
cur_out = output_pyramid[i]
loss_i = (cur_v_t - cur_out) ** 2
loss += loss_i.mean()
out_dict["loss_{}".format(i)] = loss_i.mean()
cur_v_t = torch.nn.functional.interpolate(cur_v_t, scale_factor=0.5, mode='bilinear', align_corners=False)
out_dict["loss"] = loss
return out_dict

View File

@@ -1,152 +0,0 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class REPATrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
mask_ratio=0.0,
mask_patch_size=2,
encoder_weight_path=None,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.mask_ratio = mask_ratio
self.mask_patch_size = mask_patch_size
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.encoder = DINOv2(encoder_weight_path)
no_grad(self.encoder)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
patch_mask = torch.rand((batch_size, 1, height//self.mask_patch_size, width//self.mask_patch_size), device=x.device)
patch_mask = (patch_mask < self.mask_ratio).float()
mask = torch.nn.functional.interpolate(patch_mask, size=(height, width), mode='nearest')
masked_x = x*(1-mask)# + torch.randn_like(x)*(mask)
x_t = alpha*masked_x + sigma*noise
v_t = dalpha*x + dsigma*noise
src_feature = []
def forward_hook(net, input, output):
src_feature.append(output)
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
v_t_out, x0_out = net(x_t, t, y)
src_feature = self.proj(src_feature[0])
handle.remove()
with torch.no_grad():
dst_feature = self.encoder(raw_images)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = (1-mask)*weight*(v_t_out - v_t)**2/(1-mask.mean())
mask_loss = mask*weight*(x0_out - x)**2/(mask.mean())
out = dict(
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
mask_loss=mask_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + mask_loss.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)