submit code
This commit is contained in:
102
src/diffusion/ddpm/scheduling.py
Normal file
102
src/diffusion/ddpm/scheduling.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import math
|
||||
import torch
|
||||
from src.diffusion.base.scheduling import *
|
||||
|
||||
|
||||
class DDPMScheduler(BaseScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
beta_min=0.0001,
|
||||
beta_max=0.02,
|
||||
num_steps=1000,
|
||||
):
|
||||
super().__init__()
|
||||
self.beta_min = beta_min
|
||||
self.beta_max = beta_max
|
||||
self.num_steps = num_steps
|
||||
|
||||
self.betas_table = torch.linspace(self.beta_min, self.beta_max, self.num_steps, device="cuda")
|
||||
self.alphas_table = torch.cumprod(1-self.betas_table, dim=0)
|
||||
self.sigmas_table = 1-self.alphas_table
|
||||
|
||||
|
||||
def beta(self, t) -> Tensor:
|
||||
t = t.to(torch.long)
|
||||
return self.betas_table[t].view(-1, 1, 1, 1)
|
||||
|
||||
def alpha(self, t) -> Tensor:
|
||||
t = t.to(torch.long)
|
||||
return self.alphas_table[t].view(-1, 1, 1, 1)**0.5
|
||||
|
||||
def sigma(self, t) -> Tensor:
|
||||
t = t.to(torch.long)
|
||||
return self.sigmas_table[t].view(-1, 1, 1, 1)**0.5
|
||||
|
||||
def dsigma(self, t) -> Tensor:
|
||||
raise NotImplementedError("wrong usage")
|
||||
|
||||
def dalpha_over_alpha(self, t) ->Tensor:
|
||||
raise NotImplementedError("wrong usage")
|
||||
|
||||
def dsigma_mul_sigma(self, t) ->Tensor:
|
||||
raise NotImplementedError("wrong usage")
|
||||
|
||||
def dalpha(self, t) -> Tensor:
|
||||
raise NotImplementedError("wrong usage")
|
||||
|
||||
def drift_coefficient(self, t):
|
||||
raise NotImplementedError("wrong usage")
|
||||
|
||||
def diffuse_coefficient(self, t):
|
||||
raise NotImplementedError("wrong usage")
|
||||
|
||||
def w(self, t):
|
||||
raise NotImplementedError("wrong usage")
|
||||
|
||||
|
||||
class VPScheduler(BaseScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
beta_min=0.1,
|
||||
beta_max=20,
|
||||
):
|
||||
super().__init__()
|
||||
self.beta_min = beta_min
|
||||
self.beta_d = beta_max - beta_min
|
||||
def beta(self, t) -> Tensor:
|
||||
t = torch.clamp(t, min=1e-3, max=1)
|
||||
return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1)
|
||||
|
||||
def sigma(self, t) -> Tensor:
|
||||
t = torch.clamp(t, min=1e-3, max=1)
|
||||
inter_beta:Tensor = 0.5*self.beta_d*t**2 + self.beta_min* t
|
||||
return (1-torch.exp_(-inter_beta)).sqrt().view(-1, 1, 1, 1)
|
||||
|
||||
def dsigma(self, t) -> Tensor:
|
||||
raise NotImplementedError("wrong usage")
|
||||
|
||||
def dalpha_over_alpha(self, t) ->Tensor:
|
||||
raise NotImplementedError("wrong usage")
|
||||
|
||||
def dsigma_mul_sigma(self, t) ->Tensor:
|
||||
raise NotImplementedError("wrong usage")
|
||||
|
||||
def dalpha(self, t) -> Tensor:
|
||||
raise NotImplementedError("wrong usage")
|
||||
|
||||
def alpha(self, t) -> Tensor:
|
||||
t = torch.clamp(t, min=1e-3, max=1)
|
||||
inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t
|
||||
return torch.exp(-0.5*inter_beta).view(-1, 1, 1, 1)
|
||||
|
||||
def drift_coefficient(self, t):
|
||||
raise NotImplementedError("wrong usage")
|
||||
|
||||
def diffuse_coefficient(self, t):
|
||||
raise NotImplementedError("wrong usage")
|
||||
|
||||
def w(self, t):
|
||||
return self.diffuse_coefficient(t)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user