submit code
This commit is contained in:
60
src/diffusion/base/guidance.py
Normal file
60
src/diffusion/base/guidance.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
|
||||
def simple_guidance_fn(out, cfg):
    """Standard classifier-free guidance.

    Args:
        out: tensor of shape (2*B, ...) holding the unconditional batch
            stacked before the conditional batch along dim 0.
        cfg: guidance scale.

    Returns:
        Tensor of shape (B, ...): uncond + cfg * (cond - uncond).
    """
    unconditional, conditional = out.chunk(2, dim=0)
    return unconditional + cfg * (conditional - unconditional)
|
||||
|
||||
def c3_guidance_fn(out, cfg):
    """Classifier-free guidance applied only to the first 3 channels.

    Mirrors the DiT/SiT behaviour of guiding only the first 3 (image)
    channels while the remaining channels pass through as the conditional
    prediction.  NOTE(review): per the original comment this looks like a
    bug in DiT/SiT kept for reproducibility rather than a feature.

    Args:
        out: tensor of shape (2*B, C, ...), C >= 3, with the unconditional
            batch stacked before the conditional batch along dim 0.
        cfg: guidance scale for the first 3 channels.

    Returns:
        Tensor of shape (B, C, ...).
    """
    uncondition, condition = out.chunk(2, dim=0)
    # Fix: clone instead of writing through the chunk view, which silently
    # mutated the caller's `out` tensor in the original implementation.
    guided = condition.clone()
    guided[:, :3] = uncondition[:, :3] + cfg * (condition[:, :3] - uncondition[:, :3])
    return guided
|
||||
|
||||
def c4_guidance_fn(out, cfg):
    """Classifier-free guidance: `cfg` on the first 4 channels, 1.05 on the rest.

    NOTE(review): per the original comment this channel-split guidance looks
    like a DiT/SiT bug kept for reproducibility rather than a feature.

    Args:
        out: tensor of shape (2*B, C, ...), C >= 4, with the unconditional
            batch stacked before the conditional batch along dim 0.
        cfg: guidance scale for the first 4 channels.

    Returns:
        Tensor of shape (B, C, ...).
    """
    uncondition, condition = out.chunk(2, dim=0)
    # Fix: clone instead of writing through the chunk view, which silently
    # mutated the caller's `out` tensor in the original implementation.
    guided = condition.clone()
    guided[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    # Remaining channels use a fixed, mild guidance scale of 1.05.
    guided[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:])
    return guided
|
||||
|
||||
def c4_p05_guidance_fn(out, cfg):
    """Classifier-free guidance: `cfg` on the first 4 channels, 1.05 on the rest.

    NOTE(review): per the original comment this channel-split guidance looks
    like a DiT/SiT bug kept for reproducibility rather than a feature.

    Args:
        out: tensor of shape (2*B, C, ...), C >= 4, with the unconditional
            batch stacked before the conditional batch along dim 0.
        cfg: guidance scale for the first 4 channels.

    Returns:
        Tensor of shape (B, C, ...).
    """
    uncondition, condition = out.chunk(2, dim=0)
    # Fix: clone instead of writing through the chunk view, which silently
    # mutated the caller's `out` tensor in the original implementation.
    guided = condition.clone()
    guided[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    # "+0.05" variant: remaining channels use a fixed scale of 1.05.
    guided[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:])
    return guided
|
||||
|
||||
def c4_p10_guidance_fn(out, cfg):
    """Classifier-free guidance: `cfg` on the first 4 channels, 1.10 on the rest.

    NOTE(review): per the original comment this channel-split guidance looks
    like a DiT/SiT bug kept for reproducibility rather than a feature.

    Args:
        out: tensor of shape (2*B, C, ...), C >= 4, with the unconditional
            batch stacked before the conditional batch along dim 0.
        cfg: guidance scale for the first 4 channels.

    Returns:
        Tensor of shape (B, C, ...).
    """
    uncondition, condition = out.chunk(2, dim=0)
    # Fix: clone instead of writing through the chunk view, which silently
    # mutated the caller's `out` tensor in the original implementation.
    guided = condition.clone()
    guided[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    # "+0.10" variant: remaining channels use a fixed scale of 1.10.
    guided[:, 4:] = uncondition[:, 4:] + 1.10 * (condition[:, 4:] - uncondition[:, 4:])
    return guided
|
||||
|
||||
def c4_p15_guidance_fn(out, cfg):
    """Classifier-free guidance: `cfg` on the first 4 channels, 1.15 on the rest.

    NOTE(review): per the original comment this channel-split guidance looks
    like a DiT/SiT bug kept for reproducibility rather than a feature.

    Args:
        out: tensor of shape (2*B, C, ...), C >= 4, with the unconditional
            batch stacked before the conditional batch along dim 0.
        cfg: guidance scale for the first 4 channels.

    Returns:
        Tensor of shape (B, C, ...).
    """
    uncondition, condition = out.chunk(2, dim=0)
    # Fix: clone instead of writing through the chunk view, which silently
    # mutated the caller's `out` tensor in the original implementation.
    guided = condition.clone()
    guided[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    # "+0.15" variant: remaining channels use a fixed scale of 1.15.
    guided[:, 4:] = uncondition[:, 4:] + 1.15 * (condition[:, 4:] - uncondition[:, 4:])
    return guided
|
||||
|
||||
def c4_p20_guidance_fn(out, cfg):
    """Classifier-free guidance: `cfg` on the first 4 channels, 1.20 on the rest.

    NOTE(review): per the original comment this channel-split guidance looks
    like a DiT/SiT bug kept for reproducibility rather than a feature.

    Args:
        out: tensor of shape (2*B, C, ...), C >= 4, with the unconditional
            batch stacked before the conditional batch along dim 0.
        cfg: guidance scale for the first 4 channels.

    Returns:
        Tensor of shape (B, C, ...).
    """
    uncondition, condition = out.chunk(2, dim=0)
    # Fix: clone instead of writing through the chunk view, which silently
    # mutated the caller's `out` tensor in the original implementation.
    guided = condition.clone()
    guided[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    # "+0.20" variant: remaining channels use a fixed scale of 1.20.
    guided[:, 4:] = uncondition[:, 4:] + 1.20 * (condition[:, 4:] - uncondition[:, 4:])
    return guided
|
||||
|
||||
def p4_guidance_fn(out, cfg):
    """Classifier-free guidance applied only to channels 4 and onward.

    The first 4 channels pass through as the conditional prediction.
    NOTE(review): per the original comment this looks like a DiT/SiT bug
    kept for reproducibility rather than a feature.

    Args:
        out: tensor of shape (2*B, C, ...), C >= 4, with the unconditional
            batch stacked before the conditional batch along dim 0.
        cfg: guidance scale for channels 4 and onward.

    Returns:
        Tensor of shape (B, C, ...).
    """
    uncondition, condition = out.chunk(2, dim=0)
    # Fix: clone instead of writing through the chunk view, which silently
    # mutated the caller's `out` tensor in the original implementation.
    guided = condition.clone()
    guided[:, 4:] = uncondition[:, 4:] + cfg * (condition[:, 4:] - uncondition[:, 4:])
    return guided
|
||||
31
src/diffusion/base/sampling.py
Normal file
31
src/diffusion/base/sampling.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from typing import Union, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Callable
|
||||
from src.diffusion.base.scheduling import BaseScheduler
|
||||
|
||||
class BaseSampler(nn.Module):
    """Base class for diffusion samplers.

    Holds the scheduler, guidance function, step count and guidance
    scale(s); subclasses provide the concrete sampling loop in
    `_impl_sampling`, and calling the sampler delegates to it.
    """

    def __init__(self,
                 scheduler: BaseScheduler = None,
                 guidance_fn: Callable = None,
                 num_steps: int = 250,
                 guidance: Union[float, List[float]] = 1.0,
                 *args,
                 **kwargs
                 ):
        super().__init__()
        self.scheduler = scheduler
        self.guidance_fn = guidance_fn
        self.num_steps = num_steps
        self.guidance = guidance

    def _impl_sampling(self, net, noise, condition, uncondition):
        """Subclasses implement the actual sampling procedure here."""
        raise NotImplementedError

    def __call__(self, net, noise, condition, uncondition):
        # Thin dispatch: all work happens in the subclass implementation.
        return self._impl_sampling(net, noise, condition, uncondition)
|
||||
|
||||
|
||||
32
src/diffusion/base/scheduling.py
Normal file
32
src/diffusion/base/scheduling.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
class BaseScheduler:
    """Interface for a diffusion noise schedule x_t = alpha(t)*x0 + sigma(t)*eps.

    Subclasses must implement `alpha`, `sigma` and their time derivatives
    `dalpha`, `dsigma`; the remaining methods are derived quantities.
    """

    def alpha(self, t) -> Tensor:
        # Signal coefficient alpha(t); abstract stub.
        ...

    def sigma(self, t) -> Tensor:
        # Noise coefficient sigma(t); abstract stub.
        ...

    def dalpha(self, t) -> Tensor:
        # d alpha / dt; abstract stub.
        ...

    def dsigma(self, t) -> Tensor:
        # d sigma / dt; abstract stub.
        ...

    def dalpha_over_alpha(self, t) -> Tensor:
        """Log-derivative of alpha: dalpha(t) / alpha(t) (no epsilon guard)."""
        return self.dalpha(t) / self.alpha(t)

    def dsigma_mul_sigma(self, t) -> Tensor:
        """dsigma(t) * sigma(t), i.e. half the derivative of sigma(t)**2."""
        return self.dsigma(t) * self.sigma(t)

    def drift_coefficient(self, t):
        """Drift coefficient dalpha/alpha, stabilised with a 1e-6 epsilon."""
        # Fix: the original also computed sigma/dsigma here but never used them.
        return self.dalpha(t) / (self.alpha(t) + 1e-6)

    def diffuse_coefficient(self, t):
        """Diffusion coefficient dsigma*sigma - (dalpha/alpha)*sigma**2 (eps-stabilised)."""
        alpha, sigma = self.alpha(t), self.sigma(t)
        dalpha, dsigma = self.dalpha(t), self.dsigma(t)
        return dsigma*sigma - dalpha/(alpha + 1e-6)*sigma**2

    def w(self, t):
        """Default weighting function: sigma(t)."""
        return self.sigma(t)
|
||||
29
src/diffusion/base/training.py
Normal file
29
src/diffusion/base/training.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class BaseTrainer(nn.Module):
    """Base class for diffusion training steps.

    Randomly replaces `condition` with `uncondition` for a fraction
    `null_condition_p` of the batch (classifier-free-guidance dropout),
    then delegates to the subclass's `_impl_trainstep`.
    """

    def __init__(self,
                 null_condition_p=0.1,  # per-sample probability of dropping the condition
                 log_var=False,         # flag stored for subclasses; unused here
                 ):
        super(BaseTrainer, self).__init__()
        self.null_condition_p = null_condition_p
        self.log_var = log_var

    def preproprocess(self, raw_iamges, x, condition, uncondition):
        # NOTE(review): the misspelled names "preproprocess"/"raw_iamges" are
        # kept verbatim for backward compatibility with existing callers.
        bsz = x.shape[0]
        if self.null_condition_p > 0:
            # Per-sample Bernoulli(null_condition_p) drop mask, shape (bsz,).
            mask = torch.rand((bsz,), device=condition.device) < self.null_condition_p
            # Fix: broadcast the mask across any trailing condition dims.
            # The original `mask.expand_as(condition)` only works when
            # condition is 1-D (e.g. class labels); for (bsz, d) it errors.
            mask = mask.view(bsz, *([1] * (condition.dim() - 1))).expand_as(condition)
            # Fix: build a new tensor instead of mutating the caller's
            # `condition` in place.
            condition = torch.where(mask, uncondition, condition)
        return raw_iamges, x, condition

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        # Subclasses implement the actual loss computation.
        raise NotImplementedError

    def __call__(self, net, ema_net, raw_images, x, condition, uncondition):
        raw_images, x, condition = self.preproprocess(raw_images, x, condition, uncondition)
        return self._impl_trainstep(net, ema_net, raw_images, x, condition)
|
||||
|
||||
Reference in New Issue
Block a user