Compare commits
10 Commits
3693640ca3
...
cf45afe325
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cf45afe325 | ||
|
|
1d1b4d2913 | ||
|
|
ae7df43ecc | ||
|
|
2c16b8f423 | ||
|
|
8038c16bee | ||
|
|
598f7b40a2 | ||
|
|
3093a65151 | ||
|
|
910b764275 | ||
|
|
26c8e54edb | ||
|
|
99d92c94e7 |
@@ -13,6 +13,7 @@ We decouple diffusion transformer into encoder-decoder design, and surprisingly
|
||||
* We achieves **1.26 FID** on ImageNet256x256 Benchmark with DDT-XL/2(22en6de).
|
||||
* We achieves **1.28 FID** on ImageNet512x512 Benchmark with DDT-XL/2(22en6de).
|
||||
* As a byproduct, our DDT can reuse encoder among adjacent steps to accelerate inference.
|
||||
|
||||
## Visualizations
|
||||

|
||||
## Checkpoints
|
||||
|
||||
116
configs/ddt_butterflies_b2_256.yaml
Normal file
116
configs/ddt_butterflies_b2_256.yaml
Normal file
@@ -0,0 +1,116 @@
|
||||
seed_everything: true
|
||||
tags:
|
||||
exp: &exp ddt_butterflies_b2_256
|
||||
torch_hub_dir: null
|
||||
huggingface_cache_dir: null
|
||||
trainer:
|
||||
default_root_dir: workdirs
|
||||
accelerator: auto
|
||||
strategy: auto
|
||||
devices: auto
|
||||
num_nodes: 1
|
||||
precision: bf16-mixed
|
||||
logger:
|
||||
class_path: lightning.pytorch.loggers.WandbLogger
|
||||
init_args:
|
||||
project: ddt_butterflies
|
||||
name: *exp
|
||||
save_dir: workdirs
|
||||
num_sanity_val_steps: 0
|
||||
max_steps: 200000
|
||||
val_check_interval: 2000
|
||||
check_val_every_n_epoch: null
|
||||
log_every_n_steps: 50
|
||||
deterministic: null
|
||||
inference_mode: true
|
||||
use_distributed_sampler: false
|
||||
callbacks:
|
||||
- class_path: src.callbacks.model_checkpoint.CheckpointHook
|
||||
init_args:
|
||||
every_n_train_steps: 10000
|
||||
save_top_k: -1
|
||||
save_last: true
|
||||
- class_path: src.callbacks.save_images.SaveImagesHook
|
||||
init_args:
|
||||
save_dir: val
|
||||
max_save_num: 64
|
||||
model:
|
||||
vae:
|
||||
class_path: src.models.vae.LatentVAE
|
||||
init_args:
|
||||
precompute: false
|
||||
weight_path: stabilityai/sd-vae-ft-ema
|
||||
denoiser:
|
||||
class_path: src.models.denoiser.decoupled_improved_dit.DDT
|
||||
init_args:
|
||||
in_channels: 4
|
||||
patch_size: 2
|
||||
num_groups: 12
|
||||
hidden_size: &hidden_dim 768
|
||||
num_blocks: 12
|
||||
num_encoder_blocks: 8
|
||||
num_classes: 1
|
||||
conditioner:
|
||||
class_path: src.models.conditioner.LabelConditioner
|
||||
init_args:
|
||||
null_class: 1
|
||||
diffusion_trainer:
|
||||
class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
|
||||
init_args:
|
||||
lognorm_t: true
|
||||
encoder_weight_path: dinov2_vitb14
|
||||
align_layer: 4
|
||||
proj_denoiser_dim: *hidden_dim
|
||||
proj_hidden_dim: *hidden_dim
|
||||
proj_encoder_dim: 768
|
||||
scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
|
||||
diffusion_sampler:
|
||||
class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
|
||||
init_args:
|
||||
num_steps: 250
|
||||
guidance: 1.0
|
||||
timeshift: 1.0
|
||||
state_refresh_rate: 1
|
||||
guidance_interval_min: 0.3
|
||||
guidance_interval_max: 1.0
|
||||
scheduler: *scheduler
|
||||
w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
|
||||
guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
|
||||
last_step: 0.04
|
||||
step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
|
||||
ema_tracker:
|
||||
class_path: src.callbacks.simple_ema.SimpleEMA
|
||||
init_args:
|
||||
decay: 0.9999
|
||||
optimizer:
|
||||
class_path: torch.optim.AdamW
|
||||
init_args:
|
||||
lr: 1e-3
|
||||
betas:
|
||||
- 0.9
|
||||
- 0.95
|
||||
weight_decay: 0.0
|
||||
data:
|
||||
train_dataset: hf_image
|
||||
train_root: ./data/butterflies
|
||||
test_nature_root: null
|
||||
test_gen_root: null
|
||||
train_image_size: 256
|
||||
train_batch_size: 128
|
||||
train_num_workers: 8
|
||||
train_prefetch_factor: 2
|
||||
train_hf_name: huggan/smithsonian_butterflies_subset
|
||||
train_hf_split: train
|
||||
train_hf_image_column: image
|
||||
train_hf_label_column: null
|
||||
train_hf_cache_dir: null
|
||||
eval_max_num_instances: 256
|
||||
pred_batch_size: 32
|
||||
pred_num_workers: 4
|
||||
pred_seeds: null
|
||||
pred_selected_classes: null
|
||||
num_classes: 1
|
||||
latent_shape:
|
||||
- 4
|
||||
- 32
|
||||
- 32
|
||||
@@ -6,4 +6,5 @@ jsonargparse[signatures]>=4.27.7
|
||||
torchvision
|
||||
timm
|
||||
accelerate
|
||||
gradio
|
||||
gradio
|
||||
datasets
|
||||
|
||||
@@ -41,23 +41,25 @@ class AdamLMSampler(BaseSampler):
|
||||
self,
|
||||
order: int = 2,
|
||||
timeshift: float = 1.0,
|
||||
guidance_interval_min: float = 0.0,
|
||||
guidance_interval_max: float = 1.0,
|
||||
lms_transform_fn: Callable = nop,
|
||||
w_scheduler: BaseScheduler = None,
|
||||
step_fn: Callable = ode_step_fn,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.step_fn = step_fn
|
||||
self.w_scheduler = w_scheduler
|
||||
|
||||
assert self.scheduler is not None
|
||||
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
|
||||
assert self.step_fn in [ode_step_fn, ]
|
||||
self.order = order
|
||||
self.lms_transform_fn = lms_transform_fn
|
||||
|
||||
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
||||
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
||||
self.guidance_interval_min = guidance_interval_min
|
||||
self.guidance_interval_max = guidance_interval_max
|
||||
self.timesteps = shift_respace_fn(timesteps, timeshift)
|
||||
self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
|
||||
self._reparameterize_coeffs()
|
||||
@@ -93,7 +95,11 @@ class AdamLMSampler(BaseSampler):
|
||||
cfg_x = torch.cat([x, x], dim=0)
|
||||
cfg_t = t_cur.repeat(2)
|
||||
out = net(cfg_x, cfg_t, cfg_condition)
|
||||
out = self.guidance_fn(out, self.guidance)
|
||||
if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
|
||||
guidance = self.guidance
|
||||
out = self.guidance_fn(out, guidance)
|
||||
else:
|
||||
out = self.guidance_fn(out, 1.0)
|
||||
pred_trajectory.append(out)
|
||||
out = torch.zeros_like(out)
|
||||
order = len(self.solver_coeffs[i])
|
||||
|
||||
104
src/diffusion/flow_matching/training_disperse.py
Normal file
104
src/diffusion/flow_matching/training_disperse.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import torch
|
||||
import copy
|
||||
import timm
|
||||
from torch.nn import Parameter
|
||||
|
||||
from src.utils.no_grad import no_grad
|
||||
from typing import Callable, Iterator, Tuple
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from torchvision.transforms import Normalize
|
||||
from src.diffusion.base.training import *
|
||||
from src.diffusion.base.scheduling import BaseScheduler
|
||||
|
||||
def inverse_sigma(alpha, sigma):
|
||||
return 1/sigma**2
|
||||
def snr(alpha, sigma):
|
||||
return alpha/sigma
|
||||
def minsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, min=threshold)
|
||||
def maxsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, max=threshold)
|
||||
def constant(alpha, sigma):
|
||||
return 1
|
||||
|
||||
|
||||
def time_shift_fn(t, timeshift=1.0):
|
||||
return t/(t+(1-t)*timeshift)
|
||||
|
||||
|
||||
class DisperseTrainer(BaseTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: BaseScheduler,
|
||||
loss_weight_fn:Callable=constant,
|
||||
feat_loss_weight: float=0.5,
|
||||
lognorm_t=False,
|
||||
timeshift=1.0,
|
||||
align_layer=8,
|
||||
temperature=1.0,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lognorm_t = lognorm_t
|
||||
self.scheduler = scheduler
|
||||
self.timeshift = timeshift
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
self.feat_loss_weight = feat_loss_weight
|
||||
self.align_layer = align_layer
|
||||
self.temperature = temperature
|
||||
|
||||
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||
batch_size, c, height, width = x.shape
|
||||
if self.lognorm_t:
|
||||
base_t = torch.randn((batch_size), device=x.device, dtype=torch.float32).sigmoid()
|
||||
else:
|
||||
base_t = torch.rand((batch_size), device=x.device, dtype=torch.float32)
|
||||
t = time_shift_fn(base_t, self.timeshift).to(x.dtype)
|
||||
noise = torch.randn_like(x)
|
||||
alpha = self.scheduler.alpha(t)
|
||||
dalpha = self.scheduler.dalpha(t)
|
||||
sigma = self.scheduler.sigma(t)
|
||||
dsigma = self.scheduler.dsigma(t)
|
||||
|
||||
x_t = alpha * x + noise * sigma
|
||||
v_t = dalpha * x + dsigma * noise
|
||||
|
||||
src_feature = []
|
||||
def forward_hook(net, input, output):
|
||||
feature = output
|
||||
if isinstance(feature, tuple):
|
||||
feature = feature[0] # mmdit
|
||||
src_feature.append(feature)
|
||||
|
||||
if getattr(net, "encoder", None) is not None:
|
||||
handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
|
||||
else:
|
||||
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
|
||||
|
||||
out = net(x_t, t, y)
|
||||
handle.remove()
|
||||
disperse_loss = 0.0
|
||||
world_size = torch.distributed.get_world_size()
|
||||
for sf in src_feature:
|
||||
gathered_sf = [torch.zeros_like(sf) for _ in range(world_size)]
|
||||
torch.distributed.all_gather(gathered_sf, sf)
|
||||
gathered_sf = torch.cat(gathered_sf, dim=0)
|
||||
sf = gathered_sf.view(batch_size * world_size, -1)
|
||||
sf = sf.view(batch_size * world_size, -1)
|
||||
# normalize sf
|
||||
sf = sf / torch.norm(sf, dim=1, keepdim=True)
|
||||
distance = torch.nn.functional.pdist(sf, p=2) ** 2
|
||||
sf_disperse_distance = torch.exp(-distance / self.temperature) + 1e-5
|
||||
disperse_loss += sf_disperse_distance.mean().log()
|
||||
|
||||
|
||||
weight = self.loss_weight_fn(alpha, sigma)
|
||||
fm_loss = weight*(out - v_t)**2
|
||||
|
||||
out = dict(
|
||||
fm_loss=fm_loss.mean(),
|
||||
disperse_loss=disperse_loss.mean(),
|
||||
loss=fm_loss.mean() + self.feat_loss_weight*disperse_loss.mean(),
|
||||
)
|
||||
return out
|
||||
@@ -41,22 +41,24 @@ class AdamLMSampler(BaseSampler):
|
||||
self,
|
||||
order: int = 2,
|
||||
timeshift: float = 1.0,
|
||||
guidance_interval_min: float = 0.0,
|
||||
guidance_interval_max: float = 1.0,
|
||||
state_refresh_rate: int = 1,
|
||||
lms_transform_fn: Callable = nop,
|
||||
w_scheduler: BaseScheduler = None,
|
||||
step_fn: Callable = ode_step_fn,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.step_fn = step_fn
|
||||
self.w_scheduler = w_scheduler
|
||||
self.state_refresh_rate = state_refresh_rate
|
||||
|
||||
assert self.scheduler is not None
|
||||
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
|
||||
assert self.step_fn in [ode_step_fn, ]
|
||||
self.order = order
|
||||
self.lms_transform_fn = lms_transform_fn
|
||||
self.guidance_interval_min = guidance_interval_min
|
||||
self.guidance_interval_max = guidance_interval_max
|
||||
|
||||
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
||||
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
||||
@@ -98,7 +100,11 @@ class AdamLMSampler(BaseSampler):
|
||||
if i % self.state_refresh_rate == 0:
|
||||
state = None
|
||||
out, state = net(cfg_x, cfg_t, cfg_condition, state)
|
||||
out = self.guidance_fn(out, self.guidance)
|
||||
if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
|
||||
guidance = self.guidance
|
||||
out = self.guidance_fn(out, guidance)
|
||||
else:
|
||||
out = self.guidance_fn(out, 1.0)
|
||||
pred_trajectory.append(out)
|
||||
out = torch.zeros_like(out)
|
||||
order = len(self.solver_coeffs[i])
|
||||
|
||||
@@ -29,6 +29,11 @@ class DataModule(pl.LightningDataModule):
|
||||
var_transform_engine: VARTransformEngine = None,
|
||||
train_prefetch_factor=2,
|
||||
train_dataset: str = None,
|
||||
train_hf_name: str = None,
|
||||
train_hf_split: str = "train",
|
||||
train_hf_image_column: str = "image",
|
||||
train_hf_label_column: str = None,
|
||||
train_hf_cache_dir: str = None,
|
||||
eval_batch_size=32,
|
||||
eval_num_workers=4,
|
||||
eval_max_num_instances=50000,
|
||||
@@ -45,6 +50,11 @@ class DataModule(pl.LightningDataModule):
|
||||
self.train_root = train_root
|
||||
self.train_image_size = train_image_size
|
||||
self.train_dataset = train_dataset
|
||||
self.train_hf_name = train_hf_name
|
||||
self.train_hf_split = train_hf_split
|
||||
self.train_hf_image_column = train_hf_image_column
|
||||
self.train_hf_label_column = train_hf_label_column
|
||||
self.train_hf_cache_dir = train_hf_cache_dir
|
||||
# stupid data_convert override, just to make nebular happy
|
||||
self.train_batch_size = train_batch_size
|
||||
self.train_num_workers = train_num_workers
|
||||
@@ -101,6 +111,18 @@ class DataModule(pl.LightningDataModule):
|
||||
self.train_dataset = PixImageNet512(
|
||||
root=self.train_root,
|
||||
)
|
||||
elif self.train_dataset == "hf_image":
|
||||
from src.data.dataset.hf_image import HuggingFaceImageDataset
|
||||
if self.train_hf_name is None:
|
||||
raise ValueError("train_hf_name must be set when train_dataset=hf_image")
|
||||
self.train_dataset = HuggingFaceImageDataset(
|
||||
name=self.train_hf_name,
|
||||
split=self.train_hf_split,
|
||||
image_column=self.train_hf_image_column,
|
||||
label_column=self.train_hf_label_column,
|
||||
resolution=self.train_image_size,
|
||||
cache_dir=self.train_hf_cache_dir,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("no such dataset")
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ class LightningModel(pl.LightningModule):
|
||||
ema_tracker: Optional[EMACallable] = None,
|
||||
optimizer: OptimizerCallable = None,
|
||||
lr_scheduler: LRSchedulerCallable = None,
|
||||
compile: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
self.vae = vae
|
||||
@@ -48,6 +49,7 @@ class LightningModel(pl.LightningModule):
|
||||
# self.model_loader = ModelLoader()
|
||||
|
||||
self._strict_loading = False
|
||||
self._compile = compile
|
||||
|
||||
def configure_model(self) -> None:
|
||||
self.trainer.strategy.barrier()
|
||||
@@ -61,6 +63,11 @@ class LightningModel(pl.LightningModule):
|
||||
no_grad(self.diffusion_sampler)
|
||||
no_grad(self.ema_denoiser)
|
||||
|
||||
# add compile to speed up
|
||||
if self._compile:
|
||||
self.denoiser = torch.compile(self.denoiser)
|
||||
self.ema_denoiser = torch.compile(self.ema_denoiser)
|
||||
|
||||
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
|
||||
ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
|
||||
return [ema_tracker]
|
||||
|
||||
@@ -251,11 +251,11 @@ class DiT(nn.Module):
|
||||
self.initialize_weights()
|
||||
self.precompute_pos = dict()
|
||||
|
||||
def fetch_pos(self, height, width, device, dtype):
|
||||
def fetch_pos(self, height, width, device):
|
||||
if (height, width) in self.precompute_pos:
|
||||
return self.precompute_pos[(height, width)].to(device, dtype)
|
||||
return self.precompute_pos[(height, width)].to(device)
|
||||
else:
|
||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
|
||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
|
||||
self.precompute_pos[(height, width)] = pos
|
||||
return pos
|
||||
|
||||
@@ -289,7 +289,7 @@ class DiT(nn.Module):
|
||||
B, _, H, W = x.shape
|
||||
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||
x = self.x_embedder(x)
|
||||
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
|
||||
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
|
||||
B, L, C = x.shape
|
||||
t = self.t_embedder(t.view(-1)).view(B, -1, C)
|
||||
y = self.y_embedder(y).view(B, 1, C)
|
||||
@@ -298,4 +298,4 @@ class DiT(nn.Module):
|
||||
x = block(x, condition, pos, masks[i])
|
||||
x = self.final_layer(x, condition)
|
||||
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
||||
return x
|
||||
return x
|
||||
|
||||
BIN
src/utils/__pycache__/patch_bugs.cpython-312.pyc
Normal file
BIN
src/utils/__pycache__/patch_bugs.cpython-312.pyc
Normal file
Binary file not shown.
Reference in New Issue
Block a user