submit code

107  src/diffusion/flow_matching/adam_sampling.py  Normal file
@@ -0,0 +1,107 @@
import math
import logging

import torch

from src.diffusion.base.sampling import *
from src.diffusion.base.scheduling import *
from src.diffusion.pre_integral import *

from typing import Callable, List, Tuple

logger = logging.getLogger(__name__)


def ode_step_fn(x, v, dt, s, w):
    return x + v * dt


def t2snr(t):
    if isinstance(t, torch.Tensor):
        return t.clip(min=1e-8) / (1 - t + 1e-8)
    if isinstance(t, (List, Tuple)):
        return [t2snr(_t) for _t in t]
    t = max(t, 1e-8)
    return t / (1 - t + 1e-8)


def t2logsnr(t):
    if isinstance(t, torch.Tensor):
        return torch.log(t.clip(min=1e-3) / (1 - t + 1e-3))
    if isinstance(t, (List, Tuple)):
        return [t2logsnr(_t) for _t in t]
    t = max(t, 1e-3)
    return math.log(t / (1 - t + 1e-3))


def t2isnr(t):
    return 1 / t2snr(t)


def nop(t):
    return t


def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)


class AdamLMSampler(BaseSampler):
    def __init__(
            self,
            order: int = 2,
            timeshift: float = 1.0,
            lms_transform_fn: Callable = nop,
            w_scheduler: BaseScheduler = None,
            step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.w_scheduler = w_scheduler

        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        self.order = order
        self.lms_transform_fn = lms_transform_fn

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, timeshift)
        # deltas must be taken on the shifted grid, since t_cur is advanced by them
        self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
        self._reparameterize_coeffs()

    def _reparameterize_coeffs(self):
        solver_coeffs = [[] for _ in range(self.num_steps)]
        for i in range(0, self.num_steps):
            pre_vs = [1.0, ] * (i + 1)
            pre_ts = self.lms_transform_fn(self.timesteps[:i + 1])
            int_t_start = self.lms_transform_fn(self.timesteps[i])
            int_t_end = self.lms_transform_fn(self.timesteps[i + 1])

            order_annealing = self.order  # alternative: self.num_steps - i
            order = min(self.order, i + 1, order_annealing)

            _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end)
            solver_coeffs[i] = coeffs
        self.solver_coeffs = solver_coeffs

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        Sampling process of the Adams-type linear multistep sampler.
        """
        batch_size = noise.shape[0]
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = x0 = noise
        pred_trajectory = []
        t_cur = torch.zeros([batch_size, ]).to(noise.device, noise.dtype)
        timedeltas = self.timedeltas
        solver_coeffs = self.solver_coeffs
        for i in range(self.num_steps):
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            out = net(cfg_x, cfg_t, cfg_condition)
            out = self.guidance_fn(out, self.guidances[i])
            pred_trajectory.append(out)
            # combine the last `order` velocity predictions with the
            # precomputed linear-multistep coefficients
            out = torch.zeros_like(out)
            order = len(self.solver_coeffs[i])
            for j in range(order):
                out += solver_coeffs[i][j] * pred_trajectory[-order:][j]
            v = out
            dt = timedeltas[i]
            x0 = self.step_fn(x, v, 1 - t_cur[0], s=0, w=0)
            x = self.step_fn(x, v, dt, s=0, w=0)
            t_cur += dt
        return x
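Editor's note: the coefficients above come from `lagrange_preint`, which this commit imports from `src.diffusion.pre_integral` but does not include. As a reading aid, here is a minimal sketch of what such a pre-integration could look like: interpolate the last `order` velocity evaluations with a Lagrange polynomial and integrate each basis function over the step. The function name, signature, and quadrature below are assumptions for illustration, not the actual module.

# Hedged sketch of the assumed `lagrange_preint` semantics; illustrative only.
import torch

def lagrange_preint_sketch(order, pre_vs, pre_ts, t_start, t_end, n_quad=64):
    """Integrate the Lagrange basis of the last `order` nodes over [t_start, t_end]."""
    pre_ts = [float(t) for t in pre_ts][-order:]
    pre_vs = pre_vs[-order:]
    ts = torch.linspace(float(t_start), float(t_end), n_quad)
    coeffs = []
    for j, tj in enumerate(pre_ts):
        # j-th Lagrange basis: l_j(t) = prod_{k != j} (t - t_k) / (t_j - t_k)
        basis = torch.ones_like(ts)
        for k, tk in enumerate(pre_ts):
            if k != j:
                basis = basis * (ts - tk) / (tj - tk)
        coeffs.append(torch.trapz(basis, ts))  # integral of l_j over the step
    integral = sum(c * v for c, v in zip(coeffs, pre_vs))
    return integral, coeffs

# Sanity check: for constant velocities the coefficients sum to the step size.
_, cs = lagrange_preint_sketch(2, [1.0, 1.0], [0.0, 0.5], 0.5, 1.0)
print(float(sum(cs)))  # ~0.5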
179  src/diffusion/flow_matching/sampling.py  Normal file
@@ -0,0 +1,179 @@
import logging
from typing import Callable

import torch

from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

logger = logging.getLogger(__name__)


def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)


def ode_step_fn(x, v, dt, s, w):
    return x + v * dt


def sde_mean_step_fn(x, v, dt, s, w):
    return x + v * dt + s * w * dt


def sde_step_fn(x, v, dt, s, w):
    return x + v * dt + s * w * dt + torch.sqrt(2 * w * dt) * torch.randn_like(x)


def sde_preserve_step_fn(x, v, dt, s, w):
    return x + v * dt + 0.5 * s * w * dt + torch.sqrt(w * dt) * torch.randn_like(x)


class EulerSampler(BaseSampler):
    def __init__(
            self,
            w_scheduler: BaseScheduler = None,
            timeshift=1.0,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
        self.timeshift = timeshift

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, self.timeshift)

        assert self.last_step > 0.0
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None:
            if self.step_fn == ode_step_fn:
                logger.warning("current sampler is an ODE sampler, but w_scheduler is enabled")

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        Sampling process of the Euler sampler.
        """
        batch_size = noise.shape[0]
        steps = self.timesteps.to(noise.device)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur)
            dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
            if self.w_scheduler:
                w = self.w_scheduler.w(t_cur)
            else:
                w = 0.0

            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            out = net(cfg_x, cfg_t, cfg_condition)
            out = self.guidance_fn(out, self.guidance)
            v = out
            # convert the velocity prediction into a score estimate for the
            # SDE step functions; the pure ODE step ignores s
            s = ((1 / dalpha_over_alpha) * v - x) / (sigma ** 2 - (1 / dalpha_over_alpha) * dsigma_mul_sigma)
            if i < self.num_steps - 1:
                x = self.step_fn(x, v, dt, s=s, w=w)
            else:
                x = self.last_step_fn(x, v, dt, s=s, w=w)
        return x
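Editor's note: a quick numerical sanity check (illustrative only, not part of the commit) of the velocity-to-score conversion used in `_impl_sampling`. Under the linear schedule alpha=t, sigma=1-t, the expression reduces to the familiar score -eps/sigma when v is the exact velocity.

import torch

t = torch.rand(4).clamp(0.05, 0.95).view(-1, 1, 1, 1)
x0 = torch.randn(4, 3, 8, 8)
eps = torch.randn_like(x0)
alpha, sigma = t, 1 - t
dalpha, dsigma = torch.ones_like(t), -torch.ones_like(t)

x_t = alpha * x0 + sigma * eps
v = dalpha * x0 + dsigma * eps            # exact velocity: x0 - eps
s = ((alpha / dalpha) * v - x_t) / (sigma ** 2 - (alpha / dalpha) * dsigma * sigma)

print(torch.allclose(s, -eps / sigma, atol=1e-5))  # True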


class HeunSampler(BaseSampler):
    def __init__(
            self,
            scheduler: BaseScheduler = None,
            w_scheduler: BaseScheduler = None,
            exact_henu=False,
            timeshift=1.0,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.scheduler = scheduler
        self.exact_henu = exact_henu
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
        self.timeshift = timeshift

        # resolve the default last step before it is used below
        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, self.timeshift)

        assert self.last_step > 0.0
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None:
            if self.step_fn == ode_step_fn:
                logger.warning("current sampler is an ODE sampler, but w_scheduler is enabled")

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        Sampling process of the Heun sampler.
        """
        batch_size = noise.shape[0]
        steps = self.timesteps.to(noise.device)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        v_hat, s_hat = 0.0, 0.0
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            alpha_over_dalpha = 1 / self.scheduler.dalpha_over_alpha(t_cur)
            dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
            t_hat = t_next
            t_hat = t_hat.repeat(batch_size)
            sigma_hat = self.scheduler.sigma(t_hat)
            alpha_over_dalpha_hat = 1 / self.scheduler.dalpha_over_alpha(t_hat)
            dsigma_mul_sigma_hat = self.scheduler.dsigma_mul_sigma(t_hat)

            if self.w_scheduler:
                w = self.w_scheduler.w(t_cur)
            else:
                w = 0.0
            # reuse the corrector evaluation from the previous step unless
            # exact Heun is requested (which doubles the network evaluations)
            if i == 0 or self.exact_henu:
                cfg_x = torch.cat([x, x], dim=0)
                cfg_t_cur = t_cur.repeat(2)
                out = net(cfg_x, cfg_t_cur, cfg_condition)
                out = self.guidance_fn(out, self.guidance)
                v = out
                s = (alpha_over_dalpha * v - x) / (sigma ** 2 - alpha_over_dalpha * dsigma_mul_sigma)
            else:
                v = v_hat
                s = s_hat
            x_hat = self.step_fn(x, v, dt, s=s, w=w)
            # Heun correction
            if i < self.num_steps - 1:
                cfg_x_hat = torch.cat([x_hat, x_hat], dim=0)
                cfg_t_hat = t_hat.repeat(2)
                out = net(cfg_x_hat, cfg_t_hat, cfg_condition)
                out = self.guidance_fn(out, self.guidance)
                v_hat = out
                s_hat = (alpha_over_dalpha_hat * v_hat - x_hat) / (sigma_hat ** 2 - alpha_over_dalpha_hat * dsigma_mul_sigma_hat)
                v = (v + v_hat) / 2
                s = (s + s_hat) / 2
                x = self.step_fn(x, v, dt, s=s, w=w)
            else:
                x = self.last_step_fn(x, v, dt, s=s, w=w)
        return x
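Editor's note: the averaging performed by `HeunSampler` is the classic Heun predictor-corrector step. A minimal, self-contained illustration on a toy ODE (names here are illustrative, not from this commit):

import torch

def heun_step(f, x, t, dt):
    v = f(x, t)                 # predictor slope at t
    x_hat = x + v * dt          # Euler predictor
    v_hat = f(x_hat, t + dt)    # corrector slope at t + dt
    return x + 0.5 * (v + v_hat) * dt

f = lambda x, t: -x             # dx/dt = -x, exact solution x0 * exp(-t)
x = torch.tensor(1.0)
for i in range(100):
    x = heun_step(f, x, i * 0.01, 0.01)
print(float(x), float(torch.exp(torch.tensor(-1.0))))  # both ~0.3679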
39  src/diffusion/flow_matching/scheduling.py  Normal file
@@ -0,0 +1,39 @@
import math
import torch
from src.diffusion.base.scheduling import *
from src.diffusion.ddpm.scheduling import VPScheduler


class LinearScheduler(BaseScheduler):
    def alpha(self, t) -> Tensor:
        return t.view(-1, 1, 1, 1)
    def sigma(self, t) -> Tensor:
        return (1 - t).view(-1, 1, 1, 1)
    def dalpha(self, t) -> Tensor:
        return torch.full_like(t, 1.0).view(-1, 1, 1, 1)
    def dsigma(self, t) -> Tensor:
        return torch.full_like(t, -1.0).view(-1, 1, 1, 1)


# SoTA for ImageNet!
class GVPScheduler(BaseScheduler):
    def alpha(self, t) -> Tensor:
        return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def sigma(self, t) -> Tensor:
        return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def dalpha(self, t) -> Tensor:
        # chain rule: d/dt cos(pi*t/2) = -(pi/2) * sin(pi*t/2)
        return -(math.pi / 2) * torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def dsigma(self, t) -> Tensor:
        # chain rule: d/dt sin(pi*t/2) = (pi/2) * cos(pi*t/2)
        return (math.pi / 2) * torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
    def w(self, t):
        # reshape to broadcast against [B, C, H, W], like the other schedulers
        return (torch.sin(t) ** 2).view(-1, 1, 1, 1)


class ConstScheduler(BaseScheduler):
    def w(self, t):
        return torch.ones(1, 1, 1, 1).to(t.device, t.dtype)


class VPBetaScheduler(VPScheduler):
    def w(self, t):
        return self.beta(t).view(-1, 1, 1, 1)
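Editor's note: two quick checks on `GVPScheduler` (illustrative, not part of the commit): the schedule is variance preserving, and the derivatives match autograd once the pi/2 chain-rule factor is included.

import math
import torch

t = torch.rand(8, requires_grad=True)
alpha = torch.cos(t * (math.pi / 2))
sigma = torch.sin(t * (math.pi / 2))
print(torch.allclose(alpha ** 2 + sigma ** 2, torch.ones_like(t)))  # True

(ga,) = torch.autograd.grad(alpha.sum(), t)
print(torch.allclose(ga, -(math.pi / 2) * torch.sin(t * (math.pi / 2))))  # True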
55  src/diffusion/flow_matching/training.py  Normal file
@@ -0,0 +1,55 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2

def snr(alpha, sigma):
    return alpha / sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class FlowMatchingTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        # logit-normal time sampling concentrates t away from the endpoints
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise
        out = net(x_t, t, y)

        weight = self.loss_weight_fn(alpha, sigma)

        loss = weight * (out - v_t) ** 2

        out = dict(
            loss=loss.mean(),
        )
        return out
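Editor's note: a worked example (illustrative, not part of the commit): with `LinearScheduler`, alpha=t and sigma=1-t, so the regression target above collapses to v_t = x - noise.

import torch

x = torch.randn(2, 3, 4, 4)
noise = torch.randn_like(x)
t = torch.rand(2).view(-1, 1, 1, 1)
x_t = t * x + (1 - t) * noise       # interpolant
v_t = 1.0 * x + (-1.0) * noise      # dalpha * x + dsigma * noise
print(torch.allclose(v_t, x - noise))  # True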
59  src/diffusion/flow_matching/training_cos.py  Normal file
@@ -0,0 +1,59 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2

def snr(alpha, sigma):
    return alpha / sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class COSTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise
        out = net(x_t, t, y)

        weight = self.loss_weight_fn(alpha, sigma)

        # flow-matching regression plus a direction-only cosine penalty
        # over the channel dimension
        fm_loss = weight * (out - v_t) ** 2
        cos_sim = torch.nn.functional.cosine_similarity(out, v_t, dim=1)
        cos_loss = 1 - cos_sim

        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + cos_loss.mean(),
        )
        return out
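Editor's note: a small illustration (not part of the commit) of what the extra term measures: `cosine_similarity` over the channel dimension compares only the direction of the prediction, so the cosine loss is invariant to its magnitude.

import torch

out = torch.randn(2, 4, 8, 8)
v_t = 3.0 * out                          # same direction, different scale
cos = torch.nn.functional.cosine_similarity(out, v_t, dim=1)
print(torch.allclose(cos, torch.ones_like(cos), atol=1e-5))  # cos_loss ~ 0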
68  src/diffusion/flow_matching/training_pyramid.py  Normal file
@@ -0,0 +1,68 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2

def snr(alpha, sigma):
    return alpha / sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class PyramidTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        # grab the per-level decoder outputs with a forward hook; the decoder
        # is expected to emit an iterable of feature maps, finest level first
        output_pyramid = []
        def feature_hook(module, input, output):
            output_pyramid.extend(output)
        handle = net.decoder.register_forward_hook(feature_hook)
        net(x_t, t, y)
        handle.remove()

        loss = 0.0
        out_dict = dict()

        # match each level against a progressively downsampled velocity target
        cur_v_t = v_t
        for i in range(len(output_pyramid)):
            cur_out = output_pyramid[i]
            loss_i = (cur_v_t - cur_out) ** 2
            loss += loss_i.mean()
            out_dict["loss_{}".format(i)] = loss_i.mean()
            cur_v_t = torch.nn.functional.interpolate(cur_v_t, scale_factor=0.5, mode='bilinear', align_corners=False)
        out_dict["loss"] = loss
        return out_dict
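Editor's note: an illustrative sketch (not part of the commit) of the target pyramid built in the loop above; the velocity target is halved in spatial resolution once per decoder level.

import torch

v_t = torch.randn(1, 3, 32, 32)
shapes = []
for _ in range(3):
    shapes.append(tuple(v_t.shape[-2:]))
    v_t = torch.nn.functional.interpolate(v_t, scale_factor=0.5, mode='bilinear', align_corners=False)
print(shapes)  # [(32, 32), (16, 16), (8, 8)]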
142  src/diffusion/flow_matching/training_repa.py  Normal file
@@ -0,0 +1,142 @@
import torch
import copy
import timm
from torch.nn import Parameter

from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2

def snr(alpha, sigma):
    return alpha / sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load(
            '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
            weight_path,
            source="local",
            skip_validation=True
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        # cache resampled positional embeddings per feature-grid size
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224 * h / 256), int(224 * w / 256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature


class REPATrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            encoder_weight_path=None,
            align_layer=8,
            proj_denoiser_dim=256,
            proj_hidden_dim=256,
            proj_encoder_dim=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.encoder = DINOv2(encoder_weight_path)
        no_grad(self.encoder)

        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        # capture the denoiser activation at the alignment layer
        src_feature = []
        def forward_hook(net, input, output):
            src_feature.append(output)
        handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)

        out = net(x_t, t, y)
        src_feature = self.proj(src_feature[0])
        handle.remove()

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = weight * (out - v_t) ** 2

        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight * cos_loss.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # only the projection head carries trainable state worth saving here;
        # return its state dict instead of discarding it
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
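Editor's note: the feature alignment relies on a standard forward hook to capture an intermediate activation. A minimal, self-contained version of the pattern (illustrative, not from this commit):

import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
grabbed = []
handle = net[0].register_forward_hook(lambda mod, inp, out: grabbed.append(out))
net(torch.randn(3, 4))
handle.remove()
print(grabbed[0].shape)  # torch.Size([3, 8])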
152  src/diffusion/flow_matching/training_repa_mask.py  Normal file
@@ -0,0 +1,152 @@
import torch
import copy
import timm
from torch.nn import Parameter

from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2

def snr(alpha, sigma):
    return alpha / sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load(
            '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
            weight_path,
            source="local",
            skip_validation=True
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        # cache resampled positional embeddings per feature-grid size
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224 * h / 256), int(224 * w / 256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature


class REPATrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            mask_ratio=0.0,
            mask_patch_size=2,
            encoder_weight_path=None,
            align_layer=8,
            proj_denoiser_dim=256,
            proj_hidden_dim=256,
            proj_encoder_dim=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.mask_ratio = mask_ratio
        self.mask_patch_size = mask_patch_size
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.encoder = DINOv2(encoder_weight_path)
        no_grad(self.encoder)

        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        # patch-aligned binary mask: threshold low-resolution noise, then
        # upsample with nearest-neighbor so masking acts on whole patches
        patch_mask = torch.rand((batch_size, 1, height // self.mask_patch_size, width // self.mask_patch_size), device=x.device)
        patch_mask = (patch_mask < self.mask_ratio).float()
        mask = torch.nn.functional.interpolate(patch_mask, size=(height, width), mode='nearest')
        masked_x = x * (1 - mask)  # + torch.randn_like(x)*(mask)

        x_t = alpha * masked_x + sigma * noise
        v_t = dalpha * x + dsigma * noise

        # capture the denoiser activation at the alignment layer
        src_feature = []
        def forward_hook(net, input, output):
            src_feature.append(output)
        handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)

        v_t_out, x0_out = net(x_t, t, y)
        src_feature = self.proj(src_feature[0])
        handle.remove()

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        weight = self.loss_weight_fn(alpha, sigma)
        # normalize each term by its active area; the epsilon guards against
        # division by zero when mask_ratio is 0 or 1
        fm_loss = (1 - mask) * weight * (v_t_out - v_t) ** 2 / (1 - mask.mean() + 1e-8)
        mask_loss = mask * weight * (x0_out - x) ** 2 / (mask.mean() + 1e-8)
        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            mask_loss=mask_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight * cos_loss.mean() + mask_loss.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # only the projection head carries trainable state worth saving here;
        # return its state dict instead of discarding it
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
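Editor's note: an illustrative sketch (not part of the commit) of the patch-aligned mask construction in `_impl_trainstep`: threshold low-resolution noise, then upsample with nearest-neighbor so masked pixels come in patch-sized blocks.

import torch

h = w = 8
patch, ratio = 2, 0.25
pm = (torch.rand(1, 1, h // patch, w // patch) < ratio).float()
mask = torch.nn.functional.interpolate(pm, size=(h, w), mode='nearest')
print(mask.mean().item())  # roughly `ratio`, in patch-sized blocks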