submit code
src/diffusion/ddpm/ddim_sampling.py (new file, 40 lines)
@@ -0,0 +1,40 @@
import torch
import logging
from typing import Callable

from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

logger = logging.getLogger(__name__)


class DDIMSampler(BaseSampler):
    def __init__(
        self,
        train_num_steps=1000,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.train_num_steps = train_num_steps
        assert self.scheduler is not None

    def _impl_sampling(self, net, noise, condition, uncondition):
        batch_size = noise.shape[0]
        # Discretize the training horizon into num_steps timesteps, traversed from high noise to low noise.
        steps = torch.linspace(0.0, self.train_num_steps - 1, self.num_steps, device=noise.device)
        steps = torch.flip(steps, dims=[0])
        # Stack unconditional and conditional batches for classifier-free guidance.
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = x0 = noise
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            t_cur = t_cur.repeat(batch_size)
            t_next = t_next.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            alpha = self.scheduler.alpha(t_cur)
            sigma_next = self.scheduler.sigma(t_next)
            alpha_next = self.scheduler.alpha(t_next)
            cfg_x = torch.cat([x, x], dim=0)
            t = t_cur.repeat(2)
            out = net(cfg_x, t, cfg_condition)
            out = self.guidance_fn(out, self.guidance)
            # DDIM update: recover the x0 estimate from the predicted noise, then re-noise to the next level.
            x0 = (x - sigma * out) / alpha
            x = alpha_next * x0 + sigma_next * out
        return x0
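Note: the sampler relies on attributes provided by BaseSampler (num_steps, guidance, guidance_fn, scheduler), which are not part of this diff. A minimal usage sketch, assuming BaseSampler accepts these as keyword arguments and that a guidance_fn takes the doubled-batch prediction and a scale:

# Hedged usage sketch; BaseSampler's constructor signature is assumed, not shown in this diff.
def cfg_mix(out, scale):
    # Split the doubled batch back into unconditional / conditional halves and blend them.
    uncond, cond = out.chunk(2, dim=0)
    return uncond + scale * (cond - uncond)

sampler = DDIMSampler(
    train_num_steps=1000,
    num_steps=50,
    scheduler=DDPMScheduler(),   # discrete-time scheduler defined in scheduling.py below
    guidance=4.0,
    guidance_fn=cfg_mix,
)
noise = torch.randn(8, 4, 32, 32, device="cuda")
samples = sampler._impl_sampling(net, noise, condition, uncondition)  # net, condition, uncondition supplied by the caller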
src/diffusion/ddpm/scheduling.py (new file, 102 lines)
@@ -0,0 +1,102 @@
import math
import torch
from src.diffusion.base.scheduling import *


class DDPMScheduler(BaseScheduler):
    def __init__(
        self,
        beta_min=0.0001,
        beta_max=0.02,
        num_steps=1000,
    ):
        super().__init__()
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.num_steps = num_steps

        # Discrete DDPM tables: linear beta schedule, cumulative-product alpha_bar, and 1 - alpha_bar.
        self.betas_table = torch.linspace(self.beta_min, self.beta_max, self.num_steps, device="cuda")
        self.alphas_table = torch.cumprod(1 - self.betas_table, dim=0)
        self.sigmas_table = 1 - self.alphas_table

    def beta(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.betas_table[t].view(-1, 1, 1, 1)

    def alpha(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.alphas_table[t].view(-1, 1, 1, 1) ** 0.5

    def sigma(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.sigmas_table[t].view(-1, 1, 1, 1) ** 0.5

    def dsigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha_over_alpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dsigma_mul_sigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def drift_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def diffuse_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def w(self, t):
        raise NotImplementedError("wrong usage")

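Note: the lookup methods map an integer timestep index to (B, 1, 1, 1) coefficients so they broadcast directly against (B, C, H, W) image batches. A quick sketch of what the lookups return (assumes a CUDA device is available, since the tables are created with device="cuda"):

# Sketch: integer timesteps index the precomputed tables; results broadcast over image batches.
sched = DDPMScheduler()
t = torch.randint(0, 1000, (8,), device="cuda")
print(sched.alpha(t).shape)   # torch.Size([8, 1, 1, 1]) -- sqrt of alpha_bar_t = prod_{s<=t}(1 - beta_s)
print(sched.sigma(t).shape)   # torch.Size([8, 1, 1, 1]) -- sqrt(1 - alpha_bar_t)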
class VPScheduler(BaseScheduler):
    def __init__(
        self,
        beta_min=0.1,
        beta_max=20,
    ):
        super().__init__()
        self.beta_min = beta_min
        self.beta_d = beta_max - beta_min

    def beta(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        return (self.beta_min + self.beta_d * t).view(-1, 1, 1, 1)

    def sigma(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        # Integral of beta(s) from 0 to t for the linear schedule.
        inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t
        return (1 - torch.exp(-inter_beta)).sqrt().view(-1, 1, 1, 1)

    def dsigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha_over_alpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dsigma_mul_sigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def alpha(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t
        return torch.exp(-0.5 * inter_beta).view(-1, 1, 1, 1)

    def drift_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def diffuse_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def w(self, t):
        return self.diffuse_coefficient(t)


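Note: for the VP schedule the identity alpha(t)^2 + sigma(t)^2 = 1 holds by construction, since alpha(t) = exp(-0.5 * inter_beta) and sigma(t)^2 = 1 - exp(-inter_beta). A short sanity-check sketch:

# Sketch: verify the variance-preserving identity on random times.
vp = VPScheduler()
t = torch.rand(16)
total = vp.alpha(t) ** 2 + vp.sigma(t) ** 2
assert torch.allclose(total, torch.ones_like(total), atol=1e-5)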
src/diffusion/ddpm/training.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


# Loss-weighting functions applied to the per-sample epsilon-prediction loss.
def inverse_sigma(alpha, sigma):
    return 1 / sigma ** 2

def snr(alpha, sigma):
    return alpha / sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha / sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class VPTrainer(BaseTrainer):
    def __init__(
        self,
        scheduler: BaseScheduler,
        loss_weight_fn: Callable = constant,
        train_max_t=1000,
        lognorm_t=False,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.train_max_t = train_max_t

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        # Sample a continuous time in (0, 1): logit-normal if lognorm_t, else uniform.
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        sigma = self.scheduler.sigma(t)
        x_t = alpha * x + noise * sigma
        # The network receives the time rescaled to the training horizon [0, train_max_t].
        out = net(x_t, t * self.train_max_t, y)
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight * (out - noise) ** 2

        out = dict(
            loss=loss.mean(),
        )
        return out


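Note: lognorm_t draws times from a logit-normal distribution (a standard normal pushed through the sigmoid), which concentrates training around intermediate noise levels instead of sampling uniformly over (0, 1). A small sketch of the difference:

# Sketch: uniform vs. logit-normal time sampling as used by VPTrainer.
t_uniform = torch.rand(10_000)
t_lognorm = torch.randn(10_000).sigmoid()
print(t_uniform.mean(), t_uniform.std())   # ~0.5, std ~0.29
print(t_lognorm.mean(), t_lognorm.std())   # ~0.5, smaller std: mass concentrated near 0.5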
class DDPMTrainer(BaseTrainer):
    def __init__(
        self,
        scheduler: BaseScheduler,
        loss_weight_fn: Callable = constant,
        train_max_t=1000,
        lognorm_t=False,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.train_max_t = train_max_t

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        # Sample a discrete timestep index; keep it on the same device as the data so the
        # scheduler's lookup tables and the network receive a matching-device tensor.
        t = torch.randint(0, self.train_max_t, (batch_size,), device=x.device)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        sigma = self.scheduler.sigma(t)
        x_t = alpha * x + noise * sigma
        out = net(x_t, t, y)
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight * (out - noise) ** 2

        out = dict(
            loss=loss.mean(),
        )
        return out
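Note: both trainers optimize the same epsilon-prediction objective; VPTrainer samples continuous t in (0, 1) and rescales it for the network, while DDPMTrainer samples integer steps directly. A minimal sketch of one DDPMTrainer-style step with a stand-in prediction (the real BaseTrainer wiring and network are not shown in this diff):

# Sketch: one epsilon-prediction step, mirroring DDPMTrainer._impl_trainstep.
sched = DDPMScheduler()                          # from scheduling.py above (tables live on "cuda")
x = torch.randn(4, 3, 32, 32, device="cuda")     # clean training images
t = torch.randint(0, 1000, (4,), device="cuda")
noise = torch.randn_like(x)
x_t = sched.alpha(t) * x + sched.sigma(t) * noise
pred = noise + 0.01 * torch.randn_like(noise)    # stand-in for net(x_t, t, y)
loss = ((pred - noise) ** 2).mean()              # constant loss weighting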
src/diffusion/ddpm/vp_sampling.py (new file, 59 lines)
@@ -0,0 +1,59 @@
import torch
import logging
from typing import Callable

from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

logger = logging.getLogger(__name__)


def ode_step_fn(x, eps, beta, sigma, dt):
    # Euler step of the probability-flow ODE for the VP SDE.
    return x + (-0.5 * beta * x + 0.5 * eps * beta / sigma) * dt


def sde_step_fn(x, eps, beta, sigma, dt):
    # Euler-Maruyama step of the reverse-time VP SDE (injects fresh noise each step).
    return x + (-0.5 * beta * x + eps * beta / sigma) * dt + torch.sqrt(dt.abs() * beta) * torch.randn_like(x)


class VPEulerSampler(BaseSampler):
    def __init__(
        self,
        train_max_t=1000,
        guidance_fn: Callable = None,
        step_fn: Callable = ode_step_fn,
        last_step=None,
        last_step_fn: Callable = ode_step_fn,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.guidance_fn = guidance_fn
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.train_max_t = train_max_t

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps
        assert self.last_step > 0.0
        assert self.scheduler is not None

    def _impl_sampling(self, net, noise, condition, uncondition):
        batch_size = noise.shape[0]
        # Integrate from t = 1 down to t = last_step, then take one final step to t = 0.
        steps = torch.linspace(1.0, self.last_step, self.num_steps, device=noise.device)
        steps = torch.cat([steps, torch.tensor([0.0], device=noise.device)], dim=0)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            beta = self.scheduler.beta(t_cur)
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            out = net(cfg_x, cfg_t * self.train_max_t, cfg_condition)
            eps = self.guidance_fn(out, self.guidance)
            if i < self.num_steps - 1:
                # Track an x0 estimate by jumping straight to t = 0, then take the regular Euler step.
                x0 = self.last_step_fn(x, eps, beta, sigma, -t_cur[0])
                x = self.step_fn(x, eps, beta, sigma, dt)
            else:
                # Final step: jump from t = last_step directly to t = 0.
                x = x0 = self.last_step_fn(x, eps, beta, sigma, -self.last_step)
        return x
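Note: the sampler is generic over the step function. ode_step_fn integrates the deterministic probability-flow ODE, while sde_step_fn adds the larger score drift plus injected noise of the reverse-time SDE. A small sketch of a single step on dummy data (a stand-in eps prediction, not a trained network):

# Sketch: one reverse step with each step function, using VPScheduler coefficients at t = 0.5.
vp = VPScheduler()
x = torch.randn(2, 3, 32, 32)
t = torch.full((2,), 0.5)
beta, sigma = vp.beta(t), vp.sigma(t)
eps = torch.randn_like(x)                     # stand-in for the network's noise prediction
dt = torch.tensor(-0.02)                      # negative: integrating from t = 1 toward t = 0
x_ode = ode_step_fn(x, eps, beta, sigma, dt)  # deterministic update
x_sde = sde_step_fn(x, eps, beta, sigma, dt)  # stochastic update with injected noise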