submit code
0  src/diffusion/__init__.py  Normal file
60  src/diffusion/base/guidance.py  Normal file
@@ -0,0 +1,60 @@
import torch


def simple_guidance_fn(out, cfg):
    uncondition, condition = out.chunk(2, dim=0)
    out = uncondition + cfg * (condition - uncondition)
    return out


def c3_guidance_fn(out, cfg):
    # guidance function in DiT/SiT; seems like a bug, not a feature?
    # only the first 3 channels are guided
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :3] = uncondition[:, :3] + cfg * (condition[:, :3] - uncondition[:, :3])
    return out


def c4_guidance_fn(out, cfg):
    # first 4 channels use cfg, the remaining channels a fixed scale of 1.05
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:])
    return out


def c4_p05_guidance_fn(out, cfg):
    # identical to c4_guidance_fn (fixed scale 1.05 on the remaining channels)
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:])
    return out


def c4_p10_guidance_fn(out, cfg):
    # fixed scale 1.10 on the remaining channels
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.10 * (condition[:, 4:] - uncondition[:, 4:])
    return out


def c4_p15_guidance_fn(out, cfg):
    # fixed scale 1.15 on the remaining channels
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.15 * (condition[:, 4:] - uncondition[:, 4:])
    return out


def c4_p20_guidance_fn(out, cfg):
    # fixed scale 1.20 on the remaining channels
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.20 * (condition[:, 4:] - uncondition[:, 4:])
    return out


def p4_guidance_fn(out, cfg):
    # complementary variant: only channels 4 and up are guided
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, 4:] = uncondition[:, 4:] + cfg * (condition[:, 4:] - uncondition[:, 4:])
    return out
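For context on the convention above: the samplers in this commit run the network on a doubled batch (unconditional half first, conditional half second), and the guidance functions recombine the two halves. A minimal sketch of that calling pattern, where the random `out` tensor stands in for a real network call and is not part of this commit:

import torch
from src.diffusion.base.guidance import simple_guidance_fn

x = torch.randn(4, 4, 32, 32)              # current latents
cfg_x = torch.cat([x, x], dim=0)           # duplicated for the uncond/cond halves
out = torch.randn(8, 4, 32, 32)            # stand-in for net(cfg_x, t, cfg_condition)
guided = simple_guidance_fn(out, cfg=4.0)  # uncond + cfg * (cond - uncond)
assert guided.shape == x.shape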
31  src/diffusion/base/sampling.py  Normal file
@@ -0,0 +1,31 @@
from typing import Callable, List, Union

import torch
import torch.nn as nn

from src.diffusion.base.scheduling import BaseScheduler


class BaseSampler(nn.Module):
    def __init__(self,
                 scheduler: BaseScheduler = None,
                 guidance_fn: Callable = None,
                 num_steps: int = 250,
                 guidance: Union[float, List[float]] = 1.0,
                 *args,
                 **kwargs
                 ):
        super().__init__()
        self.num_steps = num_steps
        self.guidance = guidance
        self.guidance_fn = guidance_fn
        self.scheduler = scheduler

    def _impl_sampling(self, net, noise, condition, uncondition):
        raise NotImplementedError

    def __call__(self, net, noise, condition, uncondition):
        denoised = self._impl_sampling(net, noise, condition, uncondition)
        return denoised
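BaseSampler is an abstract template: subclasses implement `_impl_sampling` and the instance is then called like a function. A minimal, purely hypothetical subclass showing the contract (real subclasses such as DDIMSampler and EulerSampler below integrate the reverse process):

from src.diffusion.base.sampling import BaseSampler

class OneStepSampler(BaseSampler):
    # toy example: "denoise" by returning the input unchanged
    def _impl_sampling(self, net, noise, condition, uncondition):
        return noise

sampler = OneStepSampler(num_steps=1)
# samples = sampler(net, noise, condition, uncondition)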
32  src/diffusion/base/scheduling.py  Normal file
@@ -0,0 +1,32 @@
import torch
from torch import Tensor


class BaseScheduler:
    def alpha(self, t) -> Tensor:
        ...

    def sigma(self, t) -> Tensor:
        ...

    def dalpha(self, t) -> Tensor:
        ...

    def dsigma(self, t) -> Tensor:
        ...

    def dalpha_over_alpha(self, t) -> Tensor:
        return self.dalpha(t) / self.alpha(t)

    def dsigma_mul_sigma(self, t) -> Tensor:
        return self.dsigma(t) * self.sigma(t)

    def drift_coefficient(self, t):
        alpha, dalpha = self.alpha(t), self.dalpha(t)
        return dalpha / (alpha + 1e-6)

    def diffuse_coefficient(self, t):
        alpha, sigma = self.alpha(t), self.sigma(t)
        dalpha, dsigma = self.dalpha(t), self.dsigma(t)
        return dsigma * sigma - dalpha / (alpha + 1e-6) * sigma ** 2

    def w(self, t):
        return self.sigma(t)
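These helpers encode the interpolation x_t = alpha(t) x0 + sigma(t) eps: the drift coefficient is alpha'/alpha and the diffusion term is sigma' sigma - (alpha'/alpha) sigma^2. A quick numeric check with a hand-rolled linear schedule (alpha = t, sigma = 1 - t); this toy scheduler is illustrative only and not part of the commit:

import torch
from src.diffusion.base.scheduling import BaseScheduler

class ToyLinear(BaseScheduler):
    def alpha(self, t): return t
    def sigma(self, t): return 1 - t
    def dalpha(self, t): return torch.ones_like(t)
    def dsigma(self, t): return -torch.ones_like(t)

s = ToyLinear()
t = torch.tensor([0.5])
print(s.drift_coefficient(t))    # ~2.0  (= 1 / 0.5)
print(s.diffuse_coefficient(t))  # ~-1.0 (= -1*0.5 - (1/0.5)*0.25)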
29  src/diffusion/base/training.py  Normal file
@@ -0,0 +1,29 @@
import torch
import torch.nn as nn


class BaseTrainer(nn.Module):
    def __init__(self,
                 null_condition_p=0.1,
                 log_var=False,
                 ):
        super().__init__()
        self.null_condition_p = null_condition_p
        self.log_var = log_var

    def preprocess(self, raw_images, x, condition, uncondition):
        bsz = x.shape[0]
        if self.null_condition_p > 0:
            # classifier-free-guidance dropout: replace the condition with the
            # null condition with probability null_condition_p
            mask = torch.rand((bsz,), device=condition.device) < self.null_condition_p
            mask = mask.expand_as(condition)
            condition[mask] = uncondition[mask]
        return raw_images, x, condition

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        raise NotImplementedError

    def __call__(self, net, ema_net, raw_images, x, condition, uncondition):
        raw_images, x, condition = self.preprocess(raw_images, x, condition, uncondition)
        return self._impl_trainstep(net, ema_net, raw_images, x, condition)
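The preprocess hook is what implements label dropout for classifier-free guidance. A small illustration with integer class labels; the label values and the null-class id 1000 are hypothetical:

import torch
from src.diffusion.base.training import BaseTrainer

condition = torch.tensor([3, 7, 1, 9])
uncondition = torch.tensor([1000, 1000, 1000, 1000])  # e.g. a dedicated null-class id
trainer = BaseTrainer(null_condition_p=0.5)
x = torch.randn(4, 4, 32, 32)
_, _, cond = trainer.preprocess(None, x, condition.clone(), uncondition)
print(cond)  # roughly half the labels replaced by 1000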
40  src/diffusion/ddpm/ddim_sampling.py  Normal file
@@ -0,0 +1,40 @@
import torch

from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

import logging
logger = logging.getLogger(__name__)


class DDIMSampler(BaseSampler):
    def __init__(
            self,
            train_num_steps=1000,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.train_num_steps = train_num_steps
        assert self.scheduler is not None

    def _impl_sampling(self, net, noise, condition, uncondition):
        batch_size = noise.shape[0]
        steps = torch.linspace(0.0, self.train_num_steps - 1, self.num_steps, device=noise.device)
        steps = torch.flip(steps, dims=[0])
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = x0 = noise
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            t_cur = t_cur.repeat(batch_size)
            t_next = t_next.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            alpha = self.scheduler.alpha(t_cur)
            sigma_next = self.scheduler.sigma(t_next)
            alpha_next = self.scheduler.alpha(t_next)
            cfg_x = torch.cat([x, x], dim=0)
            t = t_cur.repeat(2)
            out = net(cfg_x, t, cfg_condition)
            out = self.guidance_fn(out, self.guidance)
            # DDIM update: recover x0 from the eps prediction, then re-noise to t_next
            x0 = (x - sigma * out) / alpha
            x = alpha_next * x0 + sigma_next * out
        return x0
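The loop is the deterministic DDIM update: recover x0 = (x_t - sigma_t eps) / alpha_t, then jump to the next timestep via x_t' = alpha_t' x0 + sigma_t' eps. One sanity property is that when t' = t the update is the identity; a scalar check with arbitrary numbers:

alpha, sigma = 0.8, 0.6
x_t, eps = 1.0, 0.25
x0 = (x_t - sigma * eps) / alpha
assert abs((alpha * x0 + sigma * eps) - x_t) < 1e-9  # same-step round trip is exact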
102  src/diffusion/ddpm/scheduling.py  Normal file
@@ -0,0 +1,102 @@
import math
import torch
from src.diffusion.base.scheduling import *


class DDPMScheduler(BaseScheduler):
    def __init__(
            self,
            beta_min=0.0001,
            beta_max=0.02,
            num_steps=1000,
    ):
        super().__init__()
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.num_steps = num_steps

        # discrete DDPM tables; alphas_table holds the cumulative product alpha_bar
        self.betas_table = torch.linspace(self.beta_min, self.beta_max, self.num_steps, device="cuda")
        self.alphas_table = torch.cumprod(1 - self.betas_table, dim=0)
        self.sigmas_table = 1 - self.alphas_table

    def beta(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.betas_table[t].view(-1, 1, 1, 1)

    def alpha(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.alphas_table[t].view(-1, 1, 1, 1) ** 0.5

    def sigma(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.sigmas_table[t].view(-1, 1, 1, 1) ** 0.5

    def dalpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dsigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha_over_alpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dsigma_mul_sigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def drift_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def diffuse_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def w(self, t):
        raise NotImplementedError("wrong usage")


class VPScheduler(BaseScheduler):
    def __init__(
            self,
            beta_min=0.1,
            beta_max=20,
    ):
        super().__init__()
        self.beta_min = beta_min
        self.beta_d = beta_max - beta_min

    def beta(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1)

    def alpha(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t
        return torch.exp(-0.5 * inter_beta).view(-1, 1, 1, 1)

    def sigma(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t
        # note: the original used the in-place torch.exp_, which is not a valid call here
        return (1 - torch.exp(-inter_beta)).sqrt().view(-1, 1, 1, 1)

    def dalpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dsigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha_over_alpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dsigma_mul_sigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def drift_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def diffuse_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def w(self, t):
        return self.diffuse_coefficient(t)
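For the continuous VP schedule above, alpha(t) = exp(-0.5 * integral of beta) and sigma(t) = sqrt(1 - exp(-integral of beta)), so alpha(t)^2 + sigma(t)^2 = 1 at every t (variance preserving). A quick check:

import torch
from src.diffusion.ddpm.scheduling import VPScheduler

sch = VPScheduler()
t = torch.tensor([0.1, 0.5, 0.9])
total = sch.alpha(t) ** 2 + sch.sigma(t) ** 2
print(total.flatten())  # tensor([1., 1., 1.]) up to float error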
83  src/diffusion/ddpm/training.py  Normal file
@@ -0,0 +1,83 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2

def snr(alpha, sigma):
    return alpha / sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class VPTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            train_max_t=1000,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.train_max_t = train_max_t

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        sigma = self.scheduler.sigma(t)
        x_t = alpha * x + noise * sigma
        out = net(x_t, t * self.train_max_t, y)
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight * (out - noise) ** 2

        out = dict(
            loss=loss.mean(),
        )
        return out


class DDPMTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            train_max_t=1000,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.train_max_t = train_max_t

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        # discrete timesteps; keep them on the data device so the CUDA lookup
        # tables and the network receive matching tensors
        t = torch.randint(0, self.train_max_t, (batch_size,), device=x.device)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        sigma = self.scheduler.sigma(t)
        x_t = alpha * x + noise * sigma
        out = net(x_t, t, y)
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight * (out - noise) ** 2

        out = dict(
            loss=loss.mean(),
        )
        return out
59  src/diffusion/ddpm/vp_sampling.py  Normal file
@@ -0,0 +1,59 @@
import torch

from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *
from typing import Callable


def ode_step_fn(x, eps, beta, sigma, dt):
    return x + (-0.5 * beta * x + 0.5 * eps * beta / sigma) * dt

def sde_step_fn(x, eps, beta, sigma, dt):
    return x + (-0.5 * beta * x + eps * beta / sigma) * dt + torch.sqrt(dt.abs() * beta) * torch.randn_like(x)


import logging
logger = logging.getLogger(__name__)


class VPEulerSampler(BaseSampler):
    def __init__(
            self,
            train_max_t=1000,
            guidance_fn: Callable = None,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.guidance_fn = guidance_fn
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.train_max_t = train_max_t

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps
        assert self.last_step > 0.0
        assert self.scheduler is not None

    def _impl_sampling(self, net, noise, condition, uncondition):
        batch_size = noise.shape[0]
        steps = torch.linspace(1.0, self.last_step, self.num_steps, device=noise.device)
        steps = torch.cat([steps, torch.tensor([0.0], device=noise.device)], dim=0)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            beta = self.scheduler.beta(t_cur)
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            out = net(cfg_x, cfg_t * self.train_max_t, cfg_condition)
            eps = self.guidance_fn(out, self.guidance)
            if i < self.num_steps - 1:
                # running x0 estimate obtained by integrating straight to t=0
                x0 = self.last_step_fn(x, eps, beta, sigma, -t_cur[0])
                x = self.step_fn(x, eps, beta, sigma, dt)
            else:
                x = x0 = self.last_step_fn(x, eps, beta, sigma, -self.last_step)
        return x
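A note on where `ode_step_fn`'s drift comes from: the probability-flow ODE of a VP diffusion has drift -0.5 beta x - 0.5 beta * score, and with an eps-prediction network the score is approximated by -eps/sigma, which yields exactly the expression above; `sde_step_fn` doubles the score term and re-injects noise, matching the reverse SDE. As a comment-level sketch:

# probability-flow ODE (deterministic):
#   dx = (-0.5*beta*x - 0.5*beta*score) dt,  score ~ -eps/sigma
#      = (-0.5*beta*x + 0.5*beta*eps/sigma) dt        -> ode_step_fn
# reverse SDE (stochastic):
#   dx = (-0.5*beta*x - beta*score) dt + sqrt(beta*|dt|) * N(0, I)  -> sde_step_fn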
107  src/diffusion/flow_matching/adam_sampling.py  Normal file
@@ -0,0 +1,107 @@
import math
from src.diffusion.base.sampling import *
from src.diffusion.base.scheduling import *
from src.diffusion.pre_integral import *

from typing import Callable, List, Tuple


def ode_step_fn(x, v, dt, s, w):
    return x + v * dt


def t2snr(t):
    if isinstance(t, torch.Tensor):
        return t.clip(min=1e-8) / (1 - t + 1e-8)
    if isinstance(t, List) or isinstance(t, Tuple):
        return [t2snr(t) for t in t]
    t = max(t, 1e-8)
    return t / (1 - t + 1e-8)


def t2logsnr(t):
    if isinstance(t, torch.Tensor):
        return torch.log(t.clip(min=1e-3) / (1 - t + 1e-3))
    if isinstance(t, List) or isinstance(t, Tuple):
        return [t2logsnr(t) for t in t]
    t = max(t, 1e-3)
    return math.log(t / (1 - t + 1e-3))


def t2isnr(t):
    return 1 / t2snr(t)


def nop(t):
    return t


def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)


import logging
logger = logging.getLogger(__name__)


class AdamLMSampler(BaseSampler):
    def __init__(
            self,
            order: int = 2,
            timeshift: float = 1.0,
            lms_transform_fn: Callable = nop,
            w_scheduler: BaseScheduler = None,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.w_scheduler = w_scheduler

        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        self.order = order
        self.lms_transform_fn = lms_transform_fn

        # last_step was referenced but never defined in the original;
        # default to a uniform final step, mirroring EulerSampler
        self.last_step = last_step
        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps
        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, timeshift)
        # deltas must be taken on the shifted timesteps (the original mixed
        # the shifted and unshifted grids)
        self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
        self._reparameterize_coeffs()

    def _reparameterize_coeffs(self):
        solver_coeffs = [[] for _ in range(self.num_steps)]
        for i in range(0, self.num_steps):
            pre_vs = [1.0, ] * (i + 1)
            pre_ts = self.lms_transform_fn(self.timesteps[:i + 1])
            int_t_start = self.lms_transform_fn(self.timesteps[i])
            int_t_end = self.lms_transform_fn(self.timesteps[i + 1])

            order_annealing = self.order  # self.num_steps - i
            order = min(self.order, i + 1, order_annealing)

            _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end)
            solver_coeffs[i] = coeffs
        self.solver_coeffs = solver_coeffs

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        Sampling process of the Adams-style linear multistep sampler.
        """
        batch_size = noise.shape[0]
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = x0 = noise
        pred_trajectory = []
        t_cur = torch.zeros([batch_size, ]).to(noise.device, noise.dtype)
        timedeltas = self.timedeltas
        solver_coeffs = self.solver_coeffs
        for i in range(self.num_steps):
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            out = net(cfg_x, cfg_t, cfg_condition)
            # self.guidance is expected to be a per-step list here (see BaseSampler)
            out = self.guidance_fn(out, self.guidance[i])
            pred_trajectory.append(out)
            # linear multistep: combine the last `order` predictions
            out = torch.zeros_like(out)
            order = len(self.solver_coeffs[i])
            for j in range(order):
                out += solver_coeffs[i][j] * pred_trajectory[-order:][j]
            v = out
            dt = timedeltas[i]
            x0 = self.step_fn(x, v, 1 - t_cur[0], s=0, w=0)
            x = self.step_fn(x, v, dt, s=0, w=0)
            t_cur += dt
        return x
179  src/diffusion/flow_matching/sampling.py  Normal file
@@ -0,0 +1,179 @@
import torch

from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

from typing import Callable


def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)

def ode_step_fn(x, v, dt, s, w):
    return x + v * dt

def sde_mean_step_fn(x, v, dt, s, w):
    return x + v * dt + s * w * dt

def sde_step_fn(x, v, dt, s, w):
    return x + v * dt + s * w * dt + torch.sqrt(2 * w * dt) * torch.randn_like(x)

def sde_preserve_step_fn(x, v, dt, s, w):
    return x + v * dt + 0.5 * s * w * dt + torch.sqrt(w * dt) * torch.randn_like(x)


import logging
logger = logging.getLogger(__name__)


class EulerSampler(BaseSampler):
    def __init__(
            self,
            w_scheduler: BaseScheduler = None,
            timeshift=1.0,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
        self.timeshift = timeshift

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, self.timeshift)

        assert self.last_step > 0.0
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None:
            if self.step_fn == ode_step_fn:
                logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        Sampling process of the Euler sampler.
        """
        batch_size = noise.shape[0]
        steps = self.timesteps.to(noise.device)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur)
            dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
            if self.w_scheduler:
                w = self.w_scheduler.w(t_cur)
            else:
                w = 0.0

            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            out = net(cfg_x, cfg_t, cfg_condition)
            out = self.guidance_fn(out, self.guidance)
            v = out
            # score estimate recovered from the velocity prediction
            s = ((1 / dalpha_over_alpha) * v - x) / (sigma ** 2 - (1 / dalpha_over_alpha) * dsigma_mul_sigma)
            if i < self.num_steps - 1:
                x = self.step_fn(x, v, dt, s=s, w=w)
            else:
                x = self.last_step_fn(x, v, dt, s=s, w=w)
        return x


class HeunSampler(BaseSampler):
    def __init__(
            self,
            scheduler: BaseScheduler = None,
            w_scheduler: BaseScheduler = None,
            exact_henu=False,  # (sic: "heun") re-evaluate at t_cur every step instead of reusing the last correction
            timeshift=1.0,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.scheduler = scheduler
        self.exact_henu = exact_henu
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
        self.timeshift = timeshift

        # resolve the default last step before building the timestep grid
        # (the original built the grid first, which fails when last_step is None)
        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps
        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, self.timeshift)

        assert self.last_step > 0.0
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None:
            if self.step_fn == ode_step_fn:
                logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        Sampling process of the Heun sampler.
        """
        batch_size = noise.shape[0]
        steps = self.timesteps.to(noise.device)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        v_hat, s_hat = 0.0, 0.0
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            alpha_over_dalpha = 1 / self.scheduler.dalpha_over_alpha(t_cur)
            dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
            t_hat = t_next
            t_hat = t_hat.repeat(batch_size)
            sigma_hat = self.scheduler.sigma(t_hat)
            alpha_over_dalpha_hat = 1 / self.scheduler.dalpha_over_alpha(t_hat)
            dsigma_mul_sigma_hat = self.scheduler.dsigma_mul_sigma(t_hat)

            if self.w_scheduler:
                w = self.w_scheduler.w(t_cur)
            else:
                w = 0.0
            if i == 0 or self.exact_henu:
                cfg_x = torch.cat([x, x], dim=0)
                cfg_t_cur = t_cur.repeat(2)
                out = net(cfg_x, cfg_t_cur, cfg_condition)
                out = self.guidance_fn(out, self.guidance)
                v = out
                s = (alpha_over_dalpha * v - x) / (sigma ** 2 - alpha_over_dalpha * dsigma_mul_sigma)
            else:
                v = v_hat
                s = s_hat
            x_hat = self.step_fn(x, v, dt, s=s, w=w)
            # Heun correction: average the slopes at t_cur and t_next
            if i < self.num_steps - 1:
                cfg_x_hat = torch.cat([x_hat, x_hat], dim=0)
                cfg_t_hat = t_hat.repeat(2)
                out = net(cfg_x_hat, cfg_t_hat, cfg_condition)
                out = self.guidance_fn(out, self.guidance)
                v_hat = out
                s_hat = (alpha_over_dalpha_hat * v_hat - x_hat) / (sigma_hat ** 2 - alpha_over_dalpha_hat * dsigma_mul_sigma_hat)
                v = (v + v_hat) / 2
                s = (s + s_hat) / 2
                x = self.step_fn(x, v, dt, s=s, w=w)
            else:
                x = self.last_step_fn(x, v, dt, s=s, w=w)
        return x
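The `s = ...` line in both samplers converts the velocity prediction into a score estimate. From x_t = alpha x0 + sigma eps and v = alpha' x0 + sigma' eps one can eliminate x0 and solve for eps, giving score = -eps/sigma = ((alpha/alpha') v - x) / (sigma^2 - (alpha/alpha') sigma' sigma), which is the expression used above. A scalar check with the linear schedule (alpha = t, sigma = 1 - t) at t = 0.7:

x0, eps, t = 2.0, 0.5, 0.7
alpha, sigma, dalpha, dsigma = t, 1 - t, 1.0, -1.0
x = alpha * x0 + sigma * eps            # 1.55
v = dalpha * x0 + dsigma * eps          # 1.5
s = ((alpha / dalpha) * v - x) / (sigma ** 2 - (alpha / dalpha) * dsigma * sigma)
assert abs(s - (-eps / sigma)) < 1e-9   # both give -1.666...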
39  src/diffusion/flow_matching/scheduling.py  Normal file
@@ -0,0 +1,39 @@
import math
import torch
from src.diffusion.base.scheduling import *


class LinearScheduler(BaseScheduler):
    def alpha(self, t) -> Tensor:
        return t.view(-1, 1, 1, 1)

    def sigma(self, t) -> Tensor:
        return (1 - t).view(-1, 1, 1, 1)

    def dalpha(self, t) -> Tensor:
        return torch.full_like(t, 1.0).view(-1, 1, 1, 1)

    def dsigma(self, t) -> Tensor:
        return torch.full_like(t, -1.0).view(-1, 1, 1, 1)


# SoTA for ImageNet!
class GVPScheduler(BaseScheduler):
    def alpha(self, t) -> Tensor:
        return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)

    def sigma(self, t) -> Tensor:
        return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)

    # note: dalpha/dsigma omit the pi/2 chain-rule factor of d/dt cos(t*pi/2),
    # so the velocity target is uniformly rescaled by 2/pi relative to the true derivative
    def dalpha(self, t) -> Tensor:
        return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)

    def dsigma(self, t) -> Tensor:
        return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)

    def w(self, t):
        return torch.sin(t) ** 2


class ConstScheduler(BaseScheduler):
    def w(self, t):
        return torch.ones(1, 1, 1, 1).to(t.device, t.dtype)


from src.diffusion.ddpm.scheduling import VPScheduler

class VPBetaScheduler(VPScheduler):
    def w(self, t):
        return self.beta(t).view(-1, 1, 1, 1)
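GVP (generalized variance preserving) keeps alpha(t)^2 + sigma(t)^2 = cos^2 + sin^2 = 1 along the whole path, so the marginal variance of x_t stays constant for unit-variance data. A one-line check:

import torch
from src.diffusion.flow_matching.scheduling import GVPScheduler

sch = GVPScheduler()
t = torch.rand(8)
assert torch.allclose(sch.alpha(t) ** 2 + sch.sigma(t) ** 2, torch.ones(1))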
55  src/diffusion/flow_matching/training.py  Normal file
@@ -0,0 +1,55 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2

def snr(alpha, sigma):
    return alpha / sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class FlowMatchingTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise
        out = net(x_t, t, y)

        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight * (out - v_t) ** 2

        out = dict(
            loss=loss.mean(),
        )
        return out
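With the LinearScheduler (alpha = t, sigma = 1 - t, alpha' = 1, sigma' = -1) the regression target reduces to v_t = x - eps, the straight-line velocity from noise to data used in rectified-flow-style training. A tiny check:

import torch
from src.diffusion.flow_matching.scheduling import LinearScheduler

sch = LinearScheduler()
x = torch.randn(2, 4, 8, 8)
noise = torch.randn_like(x)
t = torch.rand(2)
v_t = sch.dalpha(t) * x + sch.dsigma(t) * noise
assert torch.allclose(v_t, x - noise)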
59  src/diffusion/flow_matching/training_cos.py  Normal file
@@ -0,0 +1,59 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2

def snr(alpha, sigma):
    return alpha / sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class COSTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise
        out = net(x_t, t, y)

        weight = self.loss_weight_fn(alpha, sigma)

        # flow-matching regression plus a channel-wise cosine-similarity term
        fm_loss = weight * (out - v_t) ** 2
        cos_sim = torch.nn.functional.cosine_similarity(out, v_t, dim=1)
        cos_loss = 1 - cos_sim

        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + cos_loss.mean(),
        )
        return out
68  src/diffusion/flow_matching/training_pyramid.py  Normal file
@@ -0,0 +1,68 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2

def snr(alpha, sigma):
    return alpha / sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class PyramidTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        # collect the multi-resolution decoder outputs via a forward hook
        output_pyramid = []
        def feature_hook(module, input, output):
            output_pyramid.extend(output)
        handle = net.decoder.register_forward_hook(feature_hook)
        net(x_t, t, y)
        handle.remove()

        # supervise each pyramid level against a correspondingly downsampled target
        loss = 0.0
        out_dict = dict()

        cur_v_t = v_t
        for i in range(len(output_pyramid)):
            cur_out = output_pyramid[i]
            loss_i = (cur_v_t - cur_out) ** 2
            loss += loss_i.mean()
            out_dict["loss_{}".format(i)] = loss_i.mean()
            cur_v_t = torch.nn.functional.interpolate(cur_v_t, scale_factor=0.5, mode='bilinear', align_corners=False)
        out_dict["loss"] = loss
        return out_dict
142  src/diffusion/flow_matching/training_repa.py  Normal file
@@ -0,0 +1,142 @@
import torch
import copy
import timm

from src.utils.no_grad import no_grad
from typing import Callable
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2

def snr(alpha, sigma):
    return alpha / sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super().__init__()
        # load DINOv2 from a local torch.hub checkout
        self.encoder = torch.hub.load(
            '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
            weight_path,
            source="local",
            skip_validation=True
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        # cache the resampled absolute position embedding per feature-map size
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224 * h / 256), int(224 * w / 256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature


class REPATrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            encoder_weight_path=None,
            align_layer=8,
            proj_denoiser_dim=256,
            proj_hidden_dim=256,
            proj_encoder_dim=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.encoder = DINOv2(encoder_weight_path)
        no_grad(self.encoder)

        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        # grab the denoiser's hidden states at the alignment layer
        src_feature = []
        def forward_hook(net, input, output):
            src_feature.append(output)
        handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)

        out = net(x_t, t, y)
        src_feature = self.proj(src_feature[0])
        handle.remove()

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        # REPA alignment: cosine distance between projected denoiser features
        # and frozen DINOv2 patch features
        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = weight * (out - v_t) ** 2

        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight * cos_loss.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # persist only the projection head (the original forgot the return)
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
152  src/diffusion/flow_matching/training_repa_mask.py  Normal file
@@ -0,0 +1,152 @@
import torch
import copy
import timm

from src.utils.no_grad import no_grad
from typing import Callable
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2

def snr(alpha, sigma):
    return alpha / sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class DINOv2(nn.Module):
    # identical to the DINOv2 wrapper in training_repa.py
    def __init__(self, weight_path: str):
        super().__init__()
        self.encoder = torch.hub.load(
            '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
            weight_path,
            source="local",
            skip_validation=True
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224 * h / 256), int(224 * w / 256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature


class REPATrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            mask_ratio=0.0,
            mask_patch_size=2,
            encoder_weight_path=None,
            align_layer=8,
            proj_denoiser_dim=256,
            proj_hidden_dim=256,
            proj_encoder_dim=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.mask_ratio = mask_ratio
        self.mask_patch_size = mask_patch_size
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.encoder = DINOv2(encoder_weight_path)
        no_grad(self.encoder)

        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        # MAE-style masking: zero out whole patches of the clean input before noising
        patch_mask = torch.rand((batch_size, 1, height // self.mask_patch_size, width // self.mask_patch_size), device=x.device)
        patch_mask = (patch_mask < self.mask_ratio).float()
        mask = torch.nn.functional.interpolate(patch_mask, size=(height, width), mode='nearest')
        masked_x = x * (1 - mask)  # + torch.randn_like(x)*(mask)

        x_t = alpha * masked_x + sigma * noise
        v_t = dalpha * x + dsigma * noise

        src_feature = []
        def forward_hook(net, input, output):
            src_feature.append(output)
        handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)

        v_t_out, x0_out = net(x_t, t, y)
        src_feature = self.proj(src_feature[0])
        handle.remove()

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        weight = self.loss_weight_fn(alpha, sigma)
        # flow-matching loss on visible pixels, x0 reconstruction loss on masked
        # pixels, each renormalized by its share of the image
        fm_loss = (1 - mask) * weight * (v_t_out - v_t) ** 2 / (1 - mask.mean())
        mask_loss = mask * weight * (x0_out - x) ** 2 / (mask.mean())
        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            mask_loss=mask_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight * cos_loss.mean() + mask_loss.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # persist only the projection head (the original forgot the return)
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
143  src/diffusion/pre_integral.py  Normal file
@@ -0,0 +1,143 @@
import torch


# Lagrange pre-integration: fit a Lagrange polynomial through the previous
# value-field samples, then integrate it in closed form over [int_t_start, int_t_end].

def lagrange_preint_o1(t1, v1, int_t_start, int_t_end):
    '''
    Lagrange interpolation of order 1.
    Args:
        t1: timestep
        v1: value field at t1
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value and the normalized per-node weights
    '''
    int1 = (int_t_end - int_t_start)
    return int1 * v1, (int1 / int1, )


def lagrange_preint_o2(t1, t2, v1, v2, int_t_start, int_t_end):
    '''
    Lagrange interpolation of order 2.
    Args:
        t1, t2: timesteps
        v1, v2: value field at t1, t2
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value and the normalized per-node weights
    '''
    int1 = 0.5 / (t1 - t2) * ((int_t_end - t2) ** 2 - (int_t_start - t2) ** 2)
    int2 = 0.5 / (t2 - t1) * ((int_t_end - t1) ** 2 - (int_t_start - t1) ** 2)
    int_sum = int1 + int2
    return int1 * v1 + int2 * v2, (int1 / int_sum, int2 / int_sum)


def lagrange_preint_o3(t1, t2, t3, v1, v2, v3, int_t_start, int_t_end):
    '''
    Lagrange interpolation of order 3.
    Args:
        t1, t2, t3: timesteps
        v1, v2, v3: value field at t1, t2, t3
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value and the normalized per-node weights
    '''
    int1_denom = (t1 - t2) * (t1 - t3)
    int1_end = 1/3 * (int_t_end) ** 3 - 1/2 * (t2 + t3) * (int_t_end) ** 2 + (t2 * t3) * int_t_end
    int1_start = 1/3 * (int_t_start) ** 3 - 1/2 * (t2 + t3) * (int_t_start) ** 2 + (t2 * t3) * int_t_start
    int1 = (int1_end - int1_start) / int1_denom
    int2_denom = (t2 - t1) * (t2 - t3)
    int2_end = 1/3 * (int_t_end) ** 3 - 1/2 * (t1 + t3) * (int_t_end) ** 2 + (t1 * t3) * int_t_end
    int2_start = 1/3 * (int_t_start) ** 3 - 1/2 * (t1 + t3) * (int_t_start) ** 2 + (t1 * t3) * int_t_start
    int2 = (int2_end - int2_start) / int2_denom
    int3_denom = (t3 - t1) * (t3 - t2)
    int3_end = 1/3 * (int_t_end) ** 3 - 1/2 * (t1 + t2) * (int_t_end) ** 2 + (t1 * t2) * int_t_end
    int3_start = 1/3 * (int_t_start) ** 3 - 1/2 * (t1 + t2) * (int_t_start) ** 2 + (t1 * t2) * int_t_start
    int3 = (int3_end - int3_start) / int3_denom
    int_sum = int1 + int2 + int3
    return int1 * v1 + int2 * v2 + int3 * v3, (int1 / int_sum, int2 / int_sum, int3 / int_sum)


def lagrange_preint_o4(t1, t2, t3, t4, v1, v2, v3, v4, int_t_start, int_t_end):
    '''
    Lagrange interpolation of order 4.
    Args:
        t1, t2, t3, t4: timesteps
        v1, v2, v3, v4: value field at t1, t2, t3, t4
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value and the normalized per-node weights
    '''
    int1_denom = (t1 - t2) * (t1 - t3) * (t1 - t4)
    int1_end = 1/4 * (int_t_end) ** 4 - 1/3 * (t2 + t3 + t4) * (int_t_end) ** 3 + 1/2 * (t3 * t4 + t2 * t3 + t2 * t4) * int_t_end ** 2 - t2 * t3 * t4 * int_t_end
    int1_start = 1/4 * (int_t_start) ** 4 - 1/3 * (t2 + t3 + t4) * (int_t_start) ** 3 + 1/2 * (t3 * t4 + t2 * t3 + t2 * t4) * int_t_start ** 2 - t2 * t3 * t4 * int_t_start
    int1 = (int1_end - int1_start) / int1_denom
    int2_denom = (t2 - t1) * (t2 - t3) * (t2 - t4)
    int2_end = 1/4 * (int_t_end) ** 4 - 1/3 * (t1 + t3 + t4) * (int_t_end) ** 3 + 1/2 * (t3 * t4 + t1 * t3 + t1 * t4) * int_t_end ** 2 - t1 * t3 * t4 * int_t_end
    int2_start = 1/4 * (int_t_start) ** 4 - 1/3 * (t1 + t3 + t4) * (int_t_start) ** 3 + 1/2 * (t3 * t4 + t1 * t3 + t1 * t4) * int_t_start ** 2 - t1 * t3 * t4 * int_t_start
    int2 = (int2_end - int2_start) / int2_denom
    int3_denom = (t3 - t1) * (t3 - t2) * (t3 - t4)
    int3_end = 1/4 * (int_t_end) ** 4 - 1/3 * (t1 + t2 + t4) * (int_t_end) ** 3 + 1/2 * (t4 * t2 + t1 * t2 + t1 * t4) * int_t_end ** 2 - t1 * t2 * t4 * int_t_end
    int3_start = 1/4 * (int_t_start) ** 4 - 1/3 * (t1 + t2 + t4) * (int_t_start) ** 3 + 1/2 * (t4 * t2 + t1 * t2 + t1 * t4) * int_t_start ** 2 - t1 * t2 * t4 * int_t_start
    int3 = (int3_end - int3_start) / int3_denom
    int4_denom = (t4 - t1) * (t4 - t2) * (t4 - t3)
    int4_end = 1/4 * (int_t_end) ** 4 - 1/3 * (t1 + t2 + t3) * (int_t_end) ** 3 + 1/2 * (t3 * t2 + t1 * t2 + t1 * t3) * int_t_end ** 2 - t1 * t2 * t3 * int_t_end
    int4_start = 1/4 * (int_t_start) ** 4 - 1/3 * (t1 + t2 + t3) * (int_t_start) ** 3 + 1/2 * (t3 * t2 + t1 * t2 + t1 * t3) * int_t_start ** 2 - t1 * t2 * t3 * int_t_start
    int4 = (int4_end - int4_start) / int4_denom
    int_sum = int1 + int2 + int3 + int4
    return int1 * v1 + int2 * v2 + int3 * v3 + int4 * v4, (int1 / int_sum, int2 / int_sum, int3 / int_sum, int4 / int_sum)


def lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end):
    '''
    Lagrange pre-integration dispatcher.
    Args:
        order: order of interpolation
        pre_vs: value field at pre_ts
        pre_ts: timesteps
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value and the normalized per-node weights
    '''
    order = min(order, len(pre_vs), len(pre_ts))
    if order == 1:
        return lagrange_preint_o1(pre_ts[-1], pre_vs[-1], int_t_start, int_t_end)
    elif order == 2:
        return lagrange_preint_o2(pre_ts[-2], pre_ts[-1], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
    elif order == 3:
        return lagrange_preint_o3(pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
    elif order == 4:
        return lagrange_preint_o4(pre_ts[-4], pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-4], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
    else:
        raise ValueError('Invalid order')


def polynomial_integral(coeffs, int_t_start, int_t_end):
    '''
    Polynomial integral.
    Args:
        coeffs: coefficients of the polynomial (ascending order)
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value
    '''
    orders = len(coeffs)
    int_val = 0
    for o in range(orders):
        int_val += coeffs[o] / (o + 1) * (int_t_end ** (o + 1) - int_t_start ** (o + 1))
    return int_val
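Because Lagrange interpolation of order k is exact for polynomials of degree below k, these helpers can be sanity-checked against closed-form integrals. For example, v(t) = 2t + 1 sampled at t = 0 and t = 1 should integrate to exactly 2 over [0, 1]:

from src.diffusion.pre_integral import lagrange_preint

pre_ts = [0.0, 1.0]
pre_vs = [1.0, 3.0]  # v(0)=1, v(1)=3 for v(t)=2t+1
val, coeffs = lagrange_preint(2, pre_vs, pre_ts, 0.0, 1.0)
print(val)     # 2.0, the exact integral of 2t+1 over [0, 1]
print(coeffs)  # (0.5, 0.5): trapezoid weights, as expected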
112  src/diffusion/stateful_flow_matching/adam_sampling.py  Normal file
@@ -0,0 +1,112 @@
import math
from src.diffusion.base.sampling import *
from src.diffusion.base.scheduling import *
from src.diffusion.pre_integral import *

from typing import Callable, List, Tuple


def ode_step_fn(x, v, dt, s, w):
    return x + v * dt


def t2snr(t):
    if isinstance(t, torch.Tensor):
        return t.clip(min=1e-8) / (1 - t + 1e-8)
    if isinstance(t, List) or isinstance(t, Tuple):
        return [t2snr(t) for t in t]
    t = max(t, 1e-8)
    return t / (1 - t + 1e-8)


def t2logsnr(t):
    if isinstance(t, torch.Tensor):
        return torch.log(t.clip(min=1e-3) / (1 - t + 1e-3))
    if isinstance(t, List) or isinstance(t, Tuple):
        return [t2logsnr(t) for t in t]
    t = max(t, 1e-3)
    return math.log(t / (1 - t + 1e-3))


def t2isnr(t):
    return 1 / t2snr(t)


def nop(t):
    return t


def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)


import logging
logger = logging.getLogger(__name__)


class AdamLMSampler(BaseSampler):
    def __init__(
            self,
            order: int = 2,
            timeshift: float = 1.0,
            state_refresh_rate: int = 1,
            lms_transform_fn: Callable = nop,
            w_scheduler: BaseScheduler = None,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.w_scheduler = w_scheduler
        self.state_refresh_rate = state_refresh_rate

        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        self.order = order
        self.lms_transform_fn = lms_transform_fn

        # last_step was referenced but never defined in the original;
        # default to a uniform final step, mirroring EulerSampler
        self.last_step = last_step
        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps
        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, timeshift)
        # deltas must be taken on the shifted timesteps (the original mixed
        # the shifted and unshifted grids)
        self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
        self._reparameterize_coeffs()

    def _reparameterize_coeffs(self):
        solver_coeffs = [[] for _ in range(self.num_steps)]
        for i in range(0, self.num_steps):
            pre_vs = [1.0, ] * (i + 1)
            pre_ts = self.lms_transform_fn(self.timesteps[:i + 1])
            int_t_start = self.lms_transform_fn(self.timesteps[i])
            int_t_end = self.lms_transform_fn(self.timesteps[i + 1])

            order_annealing = self.order  # self.num_steps - i
            order = min(self.order, i + 1, order_annealing)

            _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end)
            solver_coeffs[i] = coeffs
        self.solver_coeffs = solver_coeffs

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        Sampling process of the Adams-style linear multistep sampler with a
        reusable network state, refreshed every state_refresh_rate steps.
        """
        batch_size = noise.shape[0]
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = x0 = noise
        state = None
        pred_trajectory = []
        t_cur = torch.zeros([batch_size, ]).to(noise.device, noise.dtype)
        timedeltas = self.timedeltas
        solver_coeffs = self.solver_coeffs
        for i in range(self.num_steps):
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            if i % self.state_refresh_rate == 0:
                state = None
            out, state = net(cfg_x, cfg_t, cfg_condition, state)
            # self.guidance is expected to be a per-step list here (see BaseSampler)
            out = self.guidance_fn(out, self.guidance[i])
            pred_trajectory.append(out)
            # linear multistep: combine the last `order` predictions
            out = torch.zeros_like(out)
            order = len(self.solver_coeffs[i])
            for j in range(order):
                out += solver_coeffs[i][j] * pred_trajectory[-order:][j]
            v = out
            dt = timedeltas[i]
            x0 = self.step_fn(x, v, 1 - t_cur[0], s=0, w=0)
            x = self.step_fn(x, v, dt, s=0, w=0)
            t_cur += dt
        return x
122  src/diffusion/stateful_flow_matching/bak/training_adv.py  Normal file
@@ -0,0 +1,122 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips

def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2
def snr(alpha, sigma):
    return alpha / sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)
def constant(alpha, sigma):
    return 1

class Discriminator(nn.Module):
    def __init__(self, in_channels, hidden_size):
        super().__init__()
        self.head = nn.Sequential(
            nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1),  # 16x16 -> 8x8
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 8x8 -> 4x4
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 4x4 -> 2x2
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0),  # 1x1 -> 1x1
        )

    def forward(self, feature):
        B, L, C = feature.shape
        H = W = int(math.sqrt(L))
        feature = feature.permute(0, 2, 1)
        feature = feature.view(B, C, H, W)
        out = self.head(feature).sigmoid().clamp(0.01, 0.99)
        return out

class AdvTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            adv_weight=1.0,
            adv_encoder_layer=4,
            adv_in_channels=768,
            adv_hidden_size=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.adv_weight = adv_weight
        self.adv_encoder_layer = adv_encoder_layer

        self.dis_head = Discriminator(
            in_channels=adv_in_channels,
            hidden_size=adv_hidden_size,
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        adv_feature = []
        def forward_hook(net, input, output):
            adv_feature.append(output)
        handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook)

        out, _ = net(x_t, t, y)
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight * (out - v_t) ** 2

        pred_x0 = (x_t + out * sigma)
        pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma
        real_feature = adv_feature.pop()
        net(pred_xt, t, y, classify_layer=self.adv_encoder_layer)
        fake_feature = adv_feature.pop()
        handle.remove()

        real_score_gan = self.dis_head(real_feature.detach())
        fake_score_gan = self.dis_head(fake_feature.detach())
        fake_score = self.dis_head(fake_feature)

        # standard discriminator loss plus the non-saturating generator loss
        loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
        acc_real = (real_score_gan > 0.5).float()
        acc_fake = (fake_score_gan < 0.5).float()
        loss_adv = -torch.log(fake_score)
        loss_adv_hack = torch.log(fake_score_gan)

        out = dict(
            adv_loss=loss_adv.mean(),
            gan_loss=loss_gan.mean(),
            fm_loss=loss.mean(),
            loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean()) * self.adv_weight + loss_gan.mean(),
            acc_real=acc_real.mean(),
            acc_fake=acc_fake.mean(),
        )
        return out
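A quick shape check for the Discriminator head above (a sketch; 256 tokens of width 768 match the adv_in_channels=768 default on a 16x16 token grid):

import torch
disc = Discriminator(in_channels=768, hidden_size=256)
tokens = torch.randn(2, 256, 768)   # (B, L, C) transformer features
score = disc(tokens)                # (2, 1, 1, 1), clamped to (0.01, 0.99)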
127
src/diffusion/stateful_flow_matching/bak/training_adv_x0.py
Normal file
@@ -0,0 +1,127 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips

def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2
def snr(alpha, sigma):
    return alpha / sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)
def constant(alpha, sigma):
    return 1

class Discriminator(nn.Module):
    def __init__(self, in_channels, hidden_size):
        super().__init__()
        self.head = nn.Sequential(
            nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1),  # 16x16 -> 8x8
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 8x8 -> 4x4
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 4x4 -> 2x2
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0),  # 1x1 -> 1x1
        )

    def forward(self, feature):
        B, L, C = feature.shape
        H = W = int(math.sqrt(L))
        feature = feature.permute(0, 2, 1)
        feature = feature.view(B, C, H, W)
        out = self.head(feature).sigmoid().clamp(0.01, 0.99)
        return out

class AdvTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            adv_weight=1.0,
            lpips_weight=1.0,
            adv_encoder_layer=4,
            adv_in_channels=768,
            adv_hidden_size=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.adv_weight = adv_weight
        self.lpips_weight = lpips_weight
        self.adv_encoder_layer = adv_encoder_layer

        self.dis_head = Discriminator(
            in_channels=adv_in_channels,
            hidden_size=adv_hidden_size,
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype)

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        out, _ = net(x_t, t, y)
        pred_x0 = (x_t + out * sigma)

        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight * (out - v_t) ** 2
        # reference features are grad-free; the predicted branch keeps gradients
        with torch.no_grad():
            _, real_features = net(x, clean_t, y, classify_layer=self.adv_encoder_layer)
        _, fake_features = net(pred_x0, clean_t, y, classify_layer=self.adv_encoder_layer)

        real_score_gan = self.dis_head(real_features[-1].detach())
        fake_score_gan = self.dis_head(fake_features[-1].detach())
        fake_score = self.dis_head(fake_features[-1])

        loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
        acc_real = (real_score_gan > 0.5).float()
        acc_fake = (fake_score_gan < 0.5).float()
        loss_adv = -torch.log(fake_score)
        loss_adv_hack = torch.log(fake_score_gan)

        lpips_loss = []
        for r, f in zip(real_features, fake_features):
            r = torch.nn.functional.normalize(r, dim=-1)
            f = torch.nn.functional.normalize(f, dim=-1)
            lpips_loss.append(torch.sum((r - f) ** 2, dim=-1).mean())
        lpips_loss = sum(lpips_loss)

        out = dict(
            adv_loss=loss_adv.mean(),
            gan_loss=loss_gan.mean(),
            lpips_loss=lpips_loss.mean(),
            fm_loss=loss.mean(),
            loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean()) * self.adv_weight + loss_gan.mean() + self.lpips_weight * lpips_loss.mean(),
            acc_real=acc_real.mean(),
            acc_fake=acc_fake.mean(),
        )
        return out
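The feature-space lpips_loss above is a cosine distance in disguise: once r and f are L2-normalized, sum((r - f)**2, dim=-1) equals 2 - 2 * cosine_similarity(r, f). A quick check:

import torch
r = torch.nn.functional.normalize(torch.randn(2, 256, 768), dim=-1)
f = torch.nn.functional.normalize(torch.randn(2, 256, 768), dim=-1)
lhs = torch.sum((r - f) ** 2, dim=-1)
rhs = 2 - 2 * torch.nn.functional.cosine_similarity(r, f, dim=-1)
assert torch.allclose(lhs, rhs, atol=1e-5)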
159
src/diffusion/stateful_flow_matching/bak/training_mask_repa.py
Normal file
@@ -0,0 +1,159 @@
import random

import torch
import copy
import timm
from torch.nn import Parameter

from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler

def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2
def snr(alpha, sigma):
    return alpha / sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)
def constant(alpha, sigma):
    return 1


class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load(
            '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
            weight_path,
            source="local",
            skip_validation=True
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224 * h / 256), int(224 * w / 256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature


class MaskREPATrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            encoder_weight_path=None,
            mask_groups=4,
            align_layer=8,
            proj_denoiser_dim=256,
            proj_hidden_dim=256,
            proj_encoder_dim=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.mask_groups = mask_groups
        self.encoder = DINOv2(encoder_weight_path)
        no_grad(self.encoder)

        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def fetch_mask(self, length=256, groups=4, device=torch.device('cuda')):
        mask = torch.zeros(1, length, length, device=device, dtype=torch.bool)
        random_seq = torch.randperm(length, device=device)
        for i in range(groups):
            group_start = (length + groups - 1) // groups * i
            group_end = (length + groups - 1) // groups * (i + 1)
            group_random_seq = random_seq[group_start:group_end]
            y, x = torch.meshgrid(group_random_seq, group_random_seq, indexing='ij')
            mask[:, y, x] = True
        return mask

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        src_feature = []
        def forward_hook(net, input, output):
            src_feature.append(output)
        handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
        mask_groups = random.randint(1, self.mask_groups)
        mask = self.fetch_mask(length=256, groups=mask_groups, device=x.device)
        out, _ = net(x_t, t, y, mask=mask)
        src_feature = self.proj(src_feature[0])
        handle.remove()

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = weight * (out - v_t) ** 2

        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight * cos_loss.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # only the projection head carries state worth checkpointing
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
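fetch_mask builds a block-diagonal attention mask over a random permutation of token indices, so tokens attend only within their own group; a standalone CPU illustration of the same logic:

import torch
length, groups = 8, 2
mask = torch.zeros(1, length, length, dtype=torch.bool)
perm = torch.randperm(length)
per_group = (length + groups - 1) // groups
for i in range(groups):
    idx = perm[i * per_group:(i + 1) * per_group]
    y, x = torch.meshgrid(idx, idx, indexing='ij')
    mask[:, y, x] = True
print(mask[0].int())  # block-diagonal under the permutation: two 4-token groups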
179
src/diffusion/stateful_flow_matching/bak/training_patch_adv.py
Normal file
@@ -0,0 +1,179 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips

def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2
def snr(alpha, sigma):
    return alpha / sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)
def constant(alpha, sigma):
    return 1


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, mul=1000):
        t_freq = self.timestep_embedding(t * mul, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class BatchNormWithTimeEmbedding(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        # self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.bn = nn.GroupNorm(16, num_features, affine=False)
        # self.bn = nn.SyncBatchNorm(num_features, affine=False)
        self.embedder = TimestepEmbedder(num_features * 2)
        # nn.init.zeros_(self.embedder.mlp[-1].weight)
        nn.init.trunc_normal_(self.embedder.mlp[-1].weight, std=0.01)
        nn.init.zeros_(self.embedder.mlp[-1].bias)

    def forward(self, x, t):
        embed = self.embedder(t)
        embed = embed[:, :, None, None]
        gamma, beta = embed.chunk(2, dim=1)
        gamma = 1.0 + gamma
        normed = self.bn(x)
        out = normed * gamma + beta
        return out


class DisBlock(nn.Module):
    def __init__(self, in_channels, hidden_size):
        super().__init__()
        self.conv = nn.Conv2d(
            kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=4, padding=0
        )
        self.norm = BatchNormWithTimeEmbedding(hidden_size)
        self.act = nn.SiLU()

    def forward(self, x, t):
        x = self.conv(x)
        x = self.norm(x, t)
        x = self.act(x)
        return x


class Discriminator(nn.Module):
    def __init__(self, num_blocks, in_channels, hidden_size):
        super().__init__()
        self.blocks = nn.ModuleList()
        for i in range(num_blocks):
            self.blocks.append(
                DisBlock(
                    in_channels=in_channels,
                    hidden_size=hidden_size,
                )
            )
            in_channels = hidden_size
        self.classifier = nn.Conv2d(
            kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=1
        )

    def forward(self, feature, t):
        B, C, H, W = feature.shape
        for block in self.blocks:
            feature = block(feature, t)
        out = self.classifier(feature).view(B, -1)
        out = out.sigmoid().clamp(0.01, 0.99)
        return out


class AdvTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            adv_weight=1.0,
            adv_blocks=3,
            adv_in_channels=3,
            adv_hidden_size=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.adv_weight = adv_weight

        self.discriminator = Discriminator(
            num_blocks=adv_blocks,
            in_channels=adv_in_channels * 2,
            hidden_size=adv_hidden_size,
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        out, _ = net(x_t, t, y)
        pred_x0 = x_t + sigma * out
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight * (out - v_t) ** 2

        # patch discriminator sees (x_t, target) pairs channel-wise
        real_feature = torch.cat([x_t, x], dim=1)
        fake_feature = torch.cat([x_t, pred_x0], dim=1)

        real_score_gan = self.discriminator(real_feature.detach(), t)
        fake_score_gan = self.discriminator(fake_feature.detach(), t)
        fake_score = self.discriminator(fake_feature, t)

        loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
        acc_real = (real_score_gan > 0.5).float()
        acc_fake = (fake_score_gan < 0.5).float()
        loss_adv = -torch.log(fake_score)
        loss_adv_hack = torch.log(fake_score_gan)

        out = dict(
            adv_loss=loss_adv.mean(),
            gan_loss=loss_gan.mean(),
            fm_loss=loss.mean(),
            loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean()) * self.adv_weight + loss_gan.mean(),
            acc_real=acc_real.mean(),
            acc_fake=acc_fake.mean(),
        )
        return out
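BatchNormWithTimeEmbedding is FiLM-style modulation: an affine-free GroupNorm followed by a per-channel scale and shift predicted from the timestep. A minimal shape check (sizes here are hypothetical):

import torch
norm = BatchNormWithTimeEmbedding(num_features=32)
x = torch.randn(4, 32, 16, 16)
t = torch.rand(4)
y = norm(x, t)
print(y.shape)  # torch.Size([4, 32, 16, 16])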
154
src/diffusion/stateful_flow_matching/bak/training_repa_jit.py
Normal file
@@ -0,0 +1,154 @@
import torch
import copy
import timm
from torch.nn import Parameter

from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler

def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2
def snr(alpha, sigma):
    return alpha / sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)
def constant(alpha, sigma):
    return 1


class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load(
            '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
            weight_path,
            source="local",
            skip_validation=True
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224 * h / 256), int(224 * w / 256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature


class REPAJiTTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            jit_deltas=0.01,
            encoder_weight_path=None,
            align_layer=8,
            proj_denoiser_dim=256,
            proj_hidden_dim=256,
            proj_encoder_dim=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.jit_deltas = jit_deltas
        self.encoder = DINOv2(encoder_weight_path)
        no_grad(self.encoder)

        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)

        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        # second, jittered timestep sharing the same noise realization
        t2 = base_t + (torch.rand_like(base_t) - 0.5) * self.jit_deltas
        t2 = torch.clip(t2, 0, 1)
        alpha = self.scheduler.alpha(t2)
        dalpha = self.scheduler.dalpha(t2)
        sigma = self.scheduler.sigma(t2)
        dsigma = self.scheduler.dsigma(t2)
        x_t2 = alpha * x + noise * sigma
        v_t2 = dalpha * x + dsigma * noise

        src_feature = []
        def forward_hook(net, input, output):
            src_feature.append(output)
        handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)

        _, s = net(x_t, t, y, only_s=True)
        out, _ = net(x_t2, t2, y, s)
        src_feature = self.proj(src_feature[0])
        handle.remove()

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = weight * (out - v_t2) ** 2

        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight * cos_loss.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # only the projection head carries state worth checkpointing
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
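The jittered second timestep above moves the shared base_t by at most jit_deltas/2 in either direction before clipping to [0, 1]; with the default jit_deltas=0.01:

import torch
base_t = torch.rand(4)
t2 = torch.clip(base_t + (torch.rand_like(base_t) - 0.5) * 0.01, 0, 1)
assert (t2 - base_t).abs().max() <= 0.005 + 1e-8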
@@ -0,0 +1,90 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips

def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2
def snr(alpha, sigma):
    return alpha / sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)
def constant(alpha, sigma):
    return 1


class SelfConsistentTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            lpips_weight=1.0,
            lpips_encoder_layer=4,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.lpips_encoder_layer = lpips_encoder_layer
        self.lpips_weight = lpips_weight

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        real_features = []
        def forward_hook(net, input, output):
            real_features.append(output)
        handles = []
        for i in range(self.lpips_encoder_layer):
            handle = net.encoder.blocks[i].register_forward_hook(forward_hook)
            handles.append(handle)

        out, _ = net(x_t, t, y)

        for handle in handles:
            handle.remove()

        pred_x0 = (x_t + out * sigma)
        pred_xt = alpha * pred_x0 + noise * sigma
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight * (out - v_t) ** 2

        _, fake_features = net(pred_xt, t, y, classify_layer=self.lpips_encoder_layer)

        lpips_loss = []
        for r, f in zip(real_features, fake_features):
            r = torch.nn.functional.normalize(r, dim=-1)
            f = torch.nn.functional.normalize(f, dim=-1)
            lpips_loss.append(torch.sum((r - f) ** 2, dim=-1).mean())
        lpips_loss = sum(lpips_loss)

        out = dict(
            lpips_loss=lpips_loss.mean(),
            fm_loss=loss.mean(),
            loss=loss.mean() + self.lpips_weight * lpips_loss.mean(),
        )
        return out
@@ -0,0 +1,81 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips

def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2
def snr(alpha, sigma):
    return alpha / sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)
def constant(alpha, sigma):
    return 1


class SelfLPIPSTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            lpips_weight=1.0,
            lpips_encoder_layer=4,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.lpips_encoder_layer = lpips_encoder_layer
        self.lpips_weight = lpips_weight

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype)

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        out, _ = net(x_t, t, y)
        pred_x0 = (x_t + out * sigma)
        pred_xt = alpha * pred_x0 + noise * sigma
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight * (out - v_t) ** 2
        # reference features are grad-free; the predicted branch keeps gradients
        with torch.no_grad():
            _, real_features = net(x, clean_t, y, classify_layer=self.lpips_encoder_layer)
        _, fake_features = net(pred_x0, clean_t, y, classify_layer=self.lpips_encoder_layer)

        lpips_loss = []
        for r, f in zip(real_features, fake_features):
            r = torch.nn.functional.normalize(r, dim=-1)
            f = torch.nn.functional.normalize(f, dim=-1)
            lpips_loss.append(torch.sum((r - f) ** 2, dim=-1).mean())
        lpips_loss = sum(lpips_loss)

        out = dict(
            lpips_loss=lpips_loss.mean(),
            fm_loss=loss.mean(),
            loss=loss.mean() + self.lpips_weight * lpips_loss.mean(),
        )
        return out
78
src/diffusion/stateful_flow_matching/cm_sampling.py
Normal file
@@ -0,0 +1,78 @@
import torch

from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

from typing import Callable


def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)

def ode_step_fn(x, v, dt, s, w):
    return x + v * dt


import logging
logger = logging.getLogger(__name__)

class CMSampler(BaseSampler):
    def __init__(
            self,
            w_scheduler: BaseScheduler = None,
            timeshift=1.0,
            guidance_interval_min: float = 0.0,
            guidance_interval_max: float = 1.0,
            state_refresh_rate=1,
            last_step=None,
            step_fn=None,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.last_step = last_step
        self.timeshift = timeshift
        self.state_refresh_rate = state_refresh_rate
        self.guidance_interval_min = guidance_interval_min
        self.guidance_interval_max = guidance_interval_max

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, self.timeshift)

        assert self.last_step > 0.0
        assert self.scheduler is not None

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        Sampling process of the consistency-style sampler:
        estimate x0 from the velocity, then re-noise to the next timestep.
        """
        batch_size = noise.shape[0]
        steps = self.timesteps.to(noise.device)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        state = None
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            cfg_t = t_cur.repeat(batch_size * 2)
            cfg_x = torch.cat([x, x], dim=0)
            if i % self.state_refresh_rate == 0:
                state = None
            out, state = net(cfg_x, cfg_t, cfg_condition, state)
            if t_cur > self.guidance_interval_min and t_cur < self.guidance_interval_max:
                out = self.guidance_fn(out, self.guidance)
            else:
                out = self.guidance_fn(out, 1.0)
            v = out

            x0 = x + v * (1 - t_cur)
            alpha_next = self.scheduler.alpha(t_next)
            sigma_next = self.scheduler.sigma(t_next)
            x = alpha_next * x0 + sigma_next * torch.randn_like(x)
            # print(alpha_next, sigma_next)
        return x
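For intuition, a small sanity check of the re-noising step above, assuming the LinearScheduler from scheduling.py below (alpha(t)=t, sigma(t)=1-t): t_next=1 hands back the x0 estimate unchanged, while t_next=0 gives back pure noise:

import torch
x0 = torch.randn(2, 4, 8, 8)
for t_next in (0.0, 0.5, 1.0):
    alpha, sigma = t_next, 1.0 - t_next
    x = alpha * x0 + sigma * torch.randn_like(x0)
    print(t_next, x.std().item())   # interpolates between noise and x0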
103
src/diffusion/stateful_flow_matching/sampling.py
Normal file
@@ -0,0 +1,103 @@
import torch

from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

from typing import Callable


def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)

def ode_step_fn(x, v, dt, s, w):
    return x + v * dt

def sde_mean_step_fn(x, v, dt, s, w):
    return x + v * dt + s * w * dt

def sde_step_fn(x, v, dt, s, w):
    return x + v * dt + s * w * dt + torch.sqrt(2 * w * dt) * torch.randn_like(x)

def sde_preserve_step_fn(x, v, dt, s, w):
    return x + v * dt + 0.5 * s * w * dt + torch.sqrt(w * dt) * torch.randn_like(x)


import logging
logger = logging.getLogger(__name__)

class EulerSampler(BaseSampler):
    def __init__(
            self,
            w_scheduler: BaseScheduler = None,
            timeshift=1.0,
            guidance_interval_min: float = 0.0,
            guidance_interval_max: float = 1.0,
            state_refresh_rate=1,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
        self.timeshift = timeshift
        self.state_refresh_rate = state_refresh_rate
        self.guidance_interval_min = guidance_interval_min
        self.guidance_interval_max = guidance_interval_max

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, self.timeshift)

        assert self.last_step > 0.0
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None:
            if self.step_fn == ode_step_fn:
                logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        Sampling process of the Euler sampler.
        """
        batch_size = noise.shape[0]
        steps = self.timesteps.to(noise.device)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        state = None
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur)
            dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
            if self.w_scheduler:
                w = self.w_scheduler.w(t_cur)
            else:
                w = 0.0

            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            if i % self.state_refresh_rate == 0:
                state = None
            out, state = net(cfg_x, cfg_t, cfg_condition, state)
            if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
                out = self.guidance_fn(out, self.guidance)
            else:
                out = self.guidance_fn(out, 1.0)
            v = out
            # convert the velocity prediction into a score estimate for SDE steps
            s = ((1 / dalpha_over_alpha) * v - x) / (sigma ** 2 - (1 / dalpha_over_alpha) * dsigma_mul_sigma)
            if i < self.num_steps - 1:
                x = self.step_fn(x, v, dt, s=s, w=w)
            else:
                x = self.last_step_fn(x, v, dt, s=s, w=w)
        return x
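shift_respace_fn warps the uniform step grid toward t=0 for shift > 1; a quick worked example with the default shift of 3:

import torch
t = torch.linspace(0, 1, 5)
print(shift_respace_fn(t, shift=3.0))
# tensor([0.0000, 0.1000, 0.2500, 0.5000, 1.0000])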
39
src/diffusion/stateful_flow_matching/scheduling.py
Normal file
@@ -0,0 +1,39 @@
import math
import torch
from src.diffusion.base.scheduling import *


class LinearScheduler(BaseScheduler):
    def alpha(self, t) -> Tensor:
        return (t).view(-1, 1, 1, 1)
    def sigma(self, t) -> Tensor:
        return (1 - t).view(-1, 1, 1, 1)
    def dalpha(self, t) -> Tensor:
        return torch.full_like(t, 1.0).view(-1, 1, 1, 1)
    def dsigma(self, t) -> Tensor:
        return torch.full_like(t, -1.0).view(-1, 1, 1, 1)


# SoTA for ImageNet!
class GVPScheduler(BaseScheduler):
    def alpha(self, t) -> Tensor:
        return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def sigma(self, t) -> Tensor:
        return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def dalpha(self, t) -> Tensor:
        return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def dsigma(self, t) -> Tensor:
        return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def w(self, t):
        return torch.sin(t) ** 2


class ConstScheduler(BaseScheduler):
    def w(self, t):
        return torch.ones(1, 1, 1, 1).to(t.device, t.dtype)


from src.diffusion.ddpm.scheduling import VPScheduler
class VPBetaScheduler(VPScheduler):
    def w(self, t):
        return self.beta(t).view(-1, 1, 1, 1)
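A quick property check for the GVPScheduler above (assuming BaseScheduler needs no constructor arguments): the path is variance preserving, alpha(t)^2 + sigma(t)^2 = 1 for every t, unlike the linear path:

import torch
sched = GVPScheduler()
t = torch.rand(16)
assert torch.allclose(sched.alpha(t) ** 2 + sched.sigma(t) ** 2,
                      torch.ones(1), atol=1e-6)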
149
src/diffusion/stateful_flow_matching/sharing_sampling.py
Normal file
@@ -0,0 +1,149 @@
import torch

from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

from typing import Callable


def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)

def ode_step_fn(x, v, dt, s, w):
    return x + v * dt


import logging
logger = logging.getLogger(__name__)

class EulerSampler(BaseSampler):
    def __init__(
            self,
            w_scheduler: BaseScheduler = None,
            timeshift=1.0,
            guidance_interval_min: float = 0.0,
            guidance_interval_max: float = 1.0,
            state_refresh_rate=1,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
        self.timeshift = timeshift
        self.state_refresh_rate = state_refresh_rate
        self.guidance_interval_min = guidance_interval_min
        self.guidance_interval_max = guidance_interval_max

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, self.timeshift)

        assert self.last_step > 0.0
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None:
            if self.step_fn == ode_step_fn:
                logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")

        # init recompute
        self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
        self.recompute_timesteps = list(range(self.num_steps))

    def sharing_dp(self, net, noise, condition, uncondition):
        _, C, H, W = noise.shape
        B = 8
        template_noise = torch.randn((B, C, H, W), generator=torch.Generator("cuda").manual_seed(0), device=noise.device)
        template_condition = torch.randint(0, 1000, (B,), generator=torch.Generator("cuda").manual_seed(0), device=condition.device)
        template_uncondition = torch.full((B, ), 1000, device=condition.device)
        _, state_list = self._impl_sampling(net, template_noise, template_condition, template_uncondition)
        states = torch.stack(state_list)
        N, B, L, C = states.shape
        states = states.view(N, B * L, C)
        states = states.permute(1, 0, 2)
        states = torch.nn.functional.normalize(states, dim=-1)
        # compute the pairwise state similarity in float64 for numerical
        # stability (CUDA autocast does not support float64, so cast explicitly)
        sim = torch.bmm(states.double(), states.double().transpose(1, 2))
        sim = torch.mean(sim, dim=0).cpu()
        error_map = (1 - sim).tolist()

        # init cum-error
        for i in range(1, self.num_steps):
            for j in range(0, i):
                error_map[i][j] = error_map[i - 1][j] + error_map[i][j]

        # init dp and force 0 start
        C = [[0.0, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps + 1)]
        P = [[-1, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps + 1)]
        for i in range(1, self.num_steps + 1):
            C[1][i] = error_map[i - 1][0]
            P[1][i] = 0

        # dp state
        for step in range(2, self.num_recompute_timesteps + 1):
            for i in range(step, self.num_steps + 1):
                min_value = 99999
                min_index = -1
                for j in range(step - 1, i):
                    value = C[step - 1][j] + error_map[i - 1][j]
                    if value < min_value:
                        min_value = value
                        min_index = j
                C[step][i] = min_value
                P[step][i] = min_index

        # trace back
        timesteps = [self.num_steps, ]
        for i in range(self.num_recompute_timesteps, 0, -1):
            idx = timesteps[-1]
            timesteps.append(P[i][idx])
        timesteps.reverse()

        print("recompute timesteps solved by DP: ", timesteps)
        return timesteps[:-1]

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        Sampling process of the Euler sampler with shared (lazily recomputed) states.
        """
        batch_size = noise.shape[0]
        steps = self.timesteps.to(noise.device)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        state = None
        pooled_state_list = []
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            if i in self.recompute_timesteps:
                state = None
            out, state = net(cfg_x, cfg_t, cfg_condition, state)
            if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
                out = self.guidance_fn(out, self.guidance)
            else:
                out = self.guidance_fn(out, 1.0)
            v = out
            if i < self.num_steps - 1:
                x = self.step_fn(x, v, dt, s=0.0, w=0.0)
            else:
                x = self.last_step_fn(x, v, dt, s=0.0, w=0.0)
            pooled_state_list.append(state)
        return x, pooled_state_list

    def __call__(self, net, noise, condition, uncondition):
        if len(self.recompute_timesteps) != self.num_recompute_timesteps:
            self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
        denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
        return denoised
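A miniature, self-contained rerun of the same interval-partition DP on a synthetic error map (6 denoising steps, 3 recompute slots), where e[i][j] is the cumulative drift of step i measured against the most recent recompute at step j:

import random
random.seed(0)
N, K = 6, 3
e = [[random.random() for _ in range(N)] for _ in range(N)]
for i in range(1, N):
    for j in range(i):
        e[i][j] += e[i - 1][j]                      # cumulate the drift
C = [[0.0] * (N + 1) for _ in range(K + 1)]
P = [[-1] * (N + 1) for _ in range(K + 1)]
for i in range(1, N + 1):
    C[1][i], P[1][i] = e[i - 1][0], 0               # first segment starts at 0
for k in range(2, K + 1):
    for i in range(k, N + 1):
        C[k][i], P[k][i] = min(
            (C[k - 1][j] + e[i - 1][j], j) for j in range(k - 1, i)
        )
recompute = [N]
for k in range(K, 0, -1):
    recompute.append(P[k][recompute[-1]])
recompute.reverse()
print(recompute[:-1])                               # e.g. [0, 2, 4]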
55
src/diffusion/stateful_flow_matching/training.py
Normal file
@@ -0,0 +1,55 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler

def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2
def snr(alpha, sigma):
    return alpha / sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)
def constant(alpha, sigma):
    return 1

class FlowMatchingTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise
        out, _ = net(x_t, t, y)

        weight = self.loss_weight_fn(alpha, sigma)

        loss = weight * (out - v_t) ** 2

        out = dict(
            loss=loss.mean(),
        )
        return out
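With the LinearScheduler (alpha=t, sigma=1-t) the regression target above reduces to the straight noise-to-data direction v_t = x - noise; integrating it from any x_t to t=1 recovers x exactly:

import torch
x, noise = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
t = torch.rand(2).view(-1, 1, 1, 1)
x_t = t * x + (1 - t) * noise
v_t = x - noise
assert torch.allclose(x_t + (1 - t) * v_t, x, atol=1e-6)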
122
src/diffusion/stateful_flow_matching/training_adv.py
Normal file
@@ -0,0 +1,122 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips

def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2
def snr(alpha, sigma):
    return alpha / sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)
def constant(alpha, sigma):
    return 1

class Discriminator(nn.Module):
    def __init__(self, in_channels, hidden_size):
        super().__init__()
        self.head = nn.Sequential(
            nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1),  # 16x16 -> 8x8
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 8x8 -> 4x4
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 4x4 -> 2x2
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0),  # 1x1 -> 1x1
        )

    def forward(self, feature):
        B, L, C = feature.shape
        H = W = int(math.sqrt(L))
        feature = feature.permute(0, 2, 1)
        feature = feature.view(B, C, H, W)
        out = self.head(feature).sigmoid().clamp(0.01, 0.99)
        return out

class AdvTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            adv_weight=1.0,
            adv_encoder_layer=4,
            adv_in_channels=768,
            adv_hidden_size=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.adv_weight = adv_weight
        self.adv_encoder_layer = adv_encoder_layer

        self.dis_head = Discriminator(
            in_channels=adv_in_channels,
            hidden_size=adv_hidden_size,
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        adv_feature = []
        def forward_hook(net, input, output):
            adv_feature.append(output)
        handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook)

        out, _ = net(x_t, t, y)
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight * (out - v_t) ** 2

        pred_x0 = (x_t + out * sigma)
        pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma
        real_feature = adv_feature.pop()
        net(pred_xt, t, y, classify_layer=self.adv_encoder_layer)
        fake_feature = adv_feature.pop()
        handle.remove()

        real_score_gan = self.dis_head(real_feature.detach())
        fake_score_gan = self.dis_head(fake_feature.detach())
        fake_score = self.dis_head(fake_feature)

        # standard discriminator loss plus the non-saturating generator loss
        loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
        acc_real = (real_score_gan > 0.5).float()
        acc_fake = (fake_score_gan < 0.5).float()
        loss_adv = -torch.log(fake_score)
        loss_adv_hack = torch.log(fake_score_gan)

        out = dict(
            adv_loss=loss_adv.mean(),
            gan_loss=loss_gan.mean(),
            fm_loss=loss.mean(),
            loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean()) * self.adv_weight + loss_gan.mean(),
            acc_real=acc_real.mean(),
            acc_fake=acc_fake.mean(),
        )
        return out
141
src/diffusion/stateful_flow_matching/training_distill_dino.py
Normal file
@@ -0,0 +1,141 @@
import torch
import copy
import timm
from torch.nn import Parameter

from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler

def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2
def snr(alpha, sigma):
    return alpha / sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)
def constant(alpha, sigma):
    return 1


class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load(
            '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
            weight_path,
            source="local",
            skip_validation=True
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224 * h / 256), int(224 * w / 256)), mode='bilinear')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature


class DistillDINOTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            encoder_weight_path=None,
            align_layer=8,
            proj_denoiser_dim=256,
            proj_hidden_dim=256,
            proj_encoder_dim=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.encoder = DINOv2(encoder_weight_path)
        self.proj_encoder_dim = proj_encoder_dim
        no_grad(self.encoder)

        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        sigma = self.scheduler.sigma(t)

        x_t = alpha * x + noise * sigma

        _, s = net(x_t, t, y)
        src_feature = self.proj(s)

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        # resample the DINO token grid when its length differs from the denoiser's
        if dst_feature.shape[1] != src_feature.shape[1]:
            dst_length = dst_feature.shape[1]
            rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1]) ** 0.5
            dst_height = (dst_length) ** 0.5 * (height / width) ** 0.5
            dst_width = (dst_length) ** 0.5 * (width / height) ** 0.5
            dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
            dst_feature = dst_feature.permute(0, 3, 1, 2)
            dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
            dst_feature = dst_feature.permute(0, 2, 3, 1)
            dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)

        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        out = dict(
            cos_loss=cos_loss.mean(),
            loss=cos_loss.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # only the projection head carries state worth checkpointing
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
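A standalone sketch of the token-grid resampling branch above, with assumed sizes (16x16 DINO tokens up to a 32x32 denoiser grid, channel width 768):

import torch
dst = torch.randn(2, 256, 768)                       # (B, 16*16, C)
grid = dst.view(2, 16, 16, 768).permute(0, 3, 1, 2)
grid = torch.nn.functional.interpolate(grid, scale_factor=(1024 / 256) ** 0.5,
                                       mode='bilinear', align_corners=False)
dst = grid.permute(0, 2, 3, 1).reshape(2, -1, 768)
print(dst.shape)                                     # torch.Size([2, 1024, 768])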
71
src/diffusion/stateful_flow_matching/training_lpips.py
Normal file
@@ -0,0 +1,71 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips

def inverse_sigma(alpha, sigma):
    return 1/sigma**2
def snr(alpha, sigma):
    return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
    return 1

class LPIPSTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            lpips_weight=1.0,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.lpips_weight = lpips_weight
        self.lpips = _NoTrainLpips(net="vgg")
        self.lpips = self.lpips.to(torch.bfloat16)
        # self.lpips = torch.compile(self.lpips)
        no_grad(self.lpips)

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise
        out, _ = net(x_t, t, y)
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight*(out - v_t)**2

        pred_x0 = (x_t + out*sigma)
        target_x0 = x
        # bugfix: rescale inputs to the std LPIPS expects
        lpips = self.lpips(pred_x0*0.5, target_x0*0.5)

        out = dict(
            lpips_loss=lpips.mean(),
            fm_loss=loss.mean(),
            loss=loss.mean() + lpips.mean()*self.lpips_weight,
        )
        return out

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        # deliberately skip checkpointing: LPIPS weights are frozen and reloadable
        return

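# Sanity check (a sketch, assuming a linear-interpolation schedule with data
# at t=1: alpha(t)=t, sigma(t)=1-t, hence dalpha=1, dsigma=-1). Under that
# assumption the pred_x0 = x_t + out*sigma reconstruction recovers x exactly:
import torch

x, eps, t = torch.randn(8), torch.randn(8), 0.3
alpha, sigma = t, 1 - t
x_t = alpha * x + sigma * eps
v = 1 * x + (-1) * eps           # dalpha * x + dsigma * eps
assert torch.allclose(x_t + sigma * v, x, atol=1e-6)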
@@ -0,0 +1,74 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips

def inverse_sigma(alpha, sigma):
    return 1/sigma**2
def snr(alpha, sigma):
    return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
    return 1

class LPIPSTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            lpips_weight=1.0,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.lpips_weight = lpips_weight
        self.lpips = _NoTrainLpips(net="vgg")
        self.lpips = self.lpips.to(torch.bfloat16)
        # self.lpips = torch.compile(self.lpips)
        no_grad(self.lpips)

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise
        out, _ = net(x_t, t, y)

        # hand-tuned time-dependent weights: t*(1-t)^2 has its maximum at t=1/3,
        # while the LPIPS weight grows linearly in t
        fm_weight = t*(1-t)**2/0.25
        lpips_weight = t

        loss = (out - v_t)**2 * fm_weight[:, None, None, None]

        pred_x0 = (x_t + out*sigma)
        target_x0 = x
        # bugfix: rescale inputs to the std LPIPS expects
        lpips = self.lpips(pred_x0*0.5, target_x0*0.5)*lpips_weight[:, None, None, None]

        out = dict(
            lpips_loss=lpips.mean(),
            fm_loss=loss.mean(),
            loss=loss.mean() + lpips.mean()*self.lpips_weight,
        )
        return out

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        # deliberately skip checkpointing: LPIPS weights are frozen and reloadable
        return

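# The hand-tuned schedules above, evaluated at a few timesteps (a sketch;
# the values follow directly from the formulas, and the /0.25 rescale keeps
# the flow-matching weight near O(1) around its t=1/3 maximum):
import torch

t = torch.tensor([0.1, 1/3, 0.5, 0.9])
print(t * (1 - t)**2 / 0.25)   # tensor([0.3240, 0.5926, 0.5000, 0.0360])
print(t)                       # LPIPS weight grows linearly toward t=1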
157
src/diffusion/stateful_flow_matching/training_repa.py
Normal file
@@ -0,0 +1,157 @@
import torch
import copy
import timm
from torch.nn import Parameter

from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler

def inverse_sigma(alpha, sigma):
    return 1/sigma**2
def snr(alpha, sigma):
    return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
    return 1

class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load(
            '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
            weight_path,
            source="local",
            skip_validation=True
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature

class REPATrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            encoder_weight_path=None,
            align_layer=8,
            proj_denoiser_dim=256,
            proj_hidden_dim=256,
            proj_encoder_dim=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.encoder = DINOv2(encoder_weight_path)
        self.proj_encoder_dim = proj_encoder_dim
        no_grad(self.encoder)

        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size,), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size,), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        # grab the intermediate feature at align_layer via a forward hook
        src_feature = []
        def forward_hook(net, input, output):
            src_feature.append(output)

        if getattr(net, "blocks", None) is not None:
            handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
        else:
            handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)

        out, _ = net(x_t, t, y)
        src_feature = self.proj(src_feature[0])
        handle.remove()

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        # resample the DINOv2 token grid when its length differs from the denoiser's
        if dst_feature.shape[1] != src_feature.shape[1]:
            dst_length = dst_feature.shape[1]
            rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
            dst_height = dst_length**0.5 * (height/width)**0.5
            dst_width = dst_length**0.5 * (width/height)**0.5
            dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
            dst_feature = dst_feature.permute(0, 3, 1, 2)
            dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
            dst_feature = dst_feature.permute(0, 2, 3, 1)
            # reshape (not view): the permute above breaks contiguity
            dst_feature = dst_feature.reshape(batch_size, -1, self.proj_encoder_dim)

        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = weight*(out - v_t)**2

        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # only the projection head is checkpointed; the frozen encoder is not
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)

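# The hook mechanics used in _impl_trainstep, shown on a toy module (a sketch;
# module and shapes hypothetical):
import torch
import torch.nn as nn

toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
captured = []
handle = toy[0].register_forward_hook(lambda mod, inp, out: captured.append(out))
_ = toy(torch.randn(2, 8))
handle.remove()                  # detach the hook once the feature is harvested
print(captured[0].shape)         # torch.Size([2, 8])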
170
src/diffusion/stateful_flow_matching/training_repa_lpips.py
Normal file
@@ -0,0 +1,170 @@
import torch
import copy
import timm
from torch.nn import Parameter

from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from torchmetrics.image.lpip import _NoTrainLpips

def inverse_sigma(alpha, sigma):
    return 1/sigma**2
def snr(alpha, sigma):
    return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
    return 1

class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load(
            '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
            weight_path,
            source="local",
            skip_validation=True
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature

class REPALPIPSTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            lpips_weight=1.0,
            encoder_weight_path=None,
            align_layer=8,
            proj_denoiser_dim=256,
            proj_hidden_dim=256,
            proj_encoder_dim=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.encoder = DINOv2(encoder_weight_path)
        self.proj_encoder_dim = proj_encoder_dim
        no_grad(self.encoder)

        self.lpips_weight = lpips_weight
        self.lpips = _NoTrainLpips(net="vgg")
        self.lpips = self.lpips.to(torch.bfloat16)
        no_grad(self.lpips)

        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size,), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size,), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        # grab the intermediate feature at align_layer via a forward hook
        src_feature = []
        def forward_hook(net, input, output):
            src_feature.append(output)

        if getattr(net, "blocks", None) is not None:
            handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
        else:
            handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)

        out, _ = net(x_t, t, y)
        src_feature = self.proj(src_feature[0])
        handle.remove()

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        # resample the DINOv2 token grid when its length differs from the denoiser's
        if dst_feature.shape[1] != src_feature.shape[1]:
            dst_length = dst_feature.shape[1]
            rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
            dst_height = dst_length**0.5 * (height/width)**0.5
            dst_width = dst_length**0.5 * (width/height)**0.5
            dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
            dst_feature = dst_feature.permute(0, 3, 1, 2)
            dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
            dst_feature = dst_feature.permute(0, 2, 3, 1)
            # reshape (not view): the permute above breaks contiguity
            dst_feature = dst_feature.reshape(batch_size, -1, self.proj_encoder_dim)

        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = weight*(out - v_t)**2

        pred_x0 = (x_t + out * sigma)
        target_x0 = x
        # bugfix: rescale inputs to the std LPIPS expects
        lpips = self.lpips(pred_x0 * 0.5, target_x0 * 0.5)

        out = dict(
            lpips_loss=lpips.mean(),
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + self.lpips_weight*lpips.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # only the projection head is checkpointed; the frozen encoder and LPIPS are not
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)

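# The grid-resampling branch shared by both REPA trainers, walked through with
# concrete numbers (a sketch; shapes hypothetical, square grids assumed):
import torch

batch, dim = 2, 768
dst = torch.randn(batch, 16 * 16, dim)       # 256 encoder tokens (16x16 grid)
src_len = 32 * 32                            # 1024 denoiser tokens (32x32 grid)

ratio = (src_len / dst.shape[1]) ** 0.5      # 2.0
grid = int(dst.shape[1] ** 0.5)              # 16
dst = dst.view(batch, grid, grid, dim).permute(0, 3, 1, 2)   # [B, D, 16, 16]
dst = torch.nn.functional.interpolate(
    dst, scale_factor=ratio, mode='bilinear', align_corners=False)
dst = dst.permute(0, 2, 3, 1).reshape(batch, -1, dim)        # [B, 1024, D]
print(dst.shape)                             # torch.Size([2, 1024, 768])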