Compare commits

..

10 Commits

Author SHA1 Message Date
game-loader
cf45afe325 feat(data): add support for Hugging Face datasets 2026-01-16 14:25:32 +08:00
wang shuai
1d1b4d2913 Update README.md 2025-08-22 10:49:18 +08:00
wangshuai6
ae7df43ecc torch.compile 2025-07-03 20:23:59 +08:00
wang shuai
2c16b8f423 Update improved_dit.py 2025-07-03 20:21:04 +08:00
wangshuai6
8038c16bee disperse loss 2025-06-14 12:14:01 +08:00
wangshuai6
598f7b40a2 disperse loss 2025-06-13 10:41:31 +08:00
wangshuai6
3093a65151 disperse loss 2025-06-12 22:18:15 +08:00
wangshuai6
910b764275 pixelddt-t2i update 2025-06-05 09:56:03 +08:00
wangshuai6
26c8e54edb pixelddt-t2i update 2025-06-05 09:55:43 +08:00
wangshuai6
99d92c94e7 fix bugs(admas timedeltas) 2025-05-20 12:30:40 +08:00
10 changed files with 277 additions and 14 deletions

View File

@@ -13,6 +13,7 @@ We decouple diffusion transformer into encoder-decoder design, and surprisingly
* We achieve **1.26 FID** on the ImageNet 256x256 benchmark with DDT-XL/2 (22en6de).
* We achieve **1.28 FID** on the ImageNet 512x512 benchmark with DDT-XL/2 (22en6de).
* As a byproduct, our DDT can reuse the encoder across adjacent steps to accelerate inference.
## Visualizations
![](./figs/teaser.png)
## Checkpoints

View File

@@ -0,0 +1,116 @@
seed_everything: true
tags:
  exp: &exp ddt_butterflies_b2_256
  torch_hub_dir: null
  huggingface_cache_dir: null
trainer:
  default_root_dir: workdirs
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: bf16-mixed
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: ddt_butterflies
      name: *exp
      save_dir: workdirs
  num_sanity_val_steps: 0
  max_steps: 200000
  val_check_interval: 2000
  check_val_every_n_epoch: null
  log_every_n_steps: 50
  deterministic: null
  inference_mode: true
  use_distributed_sampler: false
  callbacks:
    - class_path: src.callbacks.model_checkpoint.CheckpointHook
      init_args:
        every_n_train_steps: 10000
        save_top_k: -1
        save_last: true
    - class_path: src.callbacks.save_images.SaveImagesHook
      init_args:
        save_dir: val
        max_save_num: 64
model:
  vae:
    class_path: src.models.vae.LatentVAE
    init_args:
      precompute: false
      weight_path: stabilityai/sd-vae-ft-ema
  denoiser:
    class_path: src.models.denoiser.decoupled_improved_dit.DDT
    init_args:
      in_channels: 4
      patch_size: 2
      num_groups: 12
      hidden_size: &hidden_dim 768
      num_blocks: 12
      num_encoder_blocks: 8
      num_classes: 1
  conditioner:
    class_path: src.models.conditioner.LabelConditioner
    init_args:
      null_class: 1
  diffusion_trainer:
    class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
    init_args:
      lognorm_t: true
      encoder_weight_path: dinov2_vitb14
      align_layer: 4
      proj_denoiser_dim: *hidden_dim
      proj_hidden_dim: *hidden_dim
      proj_encoder_dim: 768
      scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
  diffusion_sampler:
    class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
    init_args:
      num_steps: 250
      guidance: 1.0
      timeshift: 1.0
      state_refresh_rate: 1
      guidance_interval_min: 0.3
      guidance_interval_max: 1.0
      scheduler: *scheduler
      w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
      guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
      last_step: 0.04
      step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
  ema_tracker:
    class_path: src.callbacks.simple_ema.SimpleEMA
    init_args:
      decay: 0.9999
  optimizer:
    class_path: torch.optim.AdamW
    init_args:
      lr: 1e-3
      betas:
        - 0.9
        - 0.95
      weight_decay: 0.0
data:
  train_dataset: hf_image
  train_root: ./data/butterflies
  test_nature_root: null
  test_gen_root: null
  train_image_size: 256
  train_batch_size: 128
  train_num_workers: 8
  train_prefetch_factor: 2
  train_hf_name: huggan/smithsonian_butterflies_subset
  train_hf_split: train
  train_hf_image_column: image
  train_hf_label_column: null
  train_hf_cache_dir: null
  eval_max_num_instances: 256
  pred_batch_size: 32
  pred_num_workers: 4
  pred_seeds: null
  pred_selected_classes: null
  num_classes: 1
  latent_shape:
    - 4
    - 32
    - 32
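
A quick sanity check of the shape arithmetic implied by this config (assuming the standard 8x spatial downsampling of the stabilityai/sd-vae-ft-ema VAE, which is not spelled out in this diff):

# Sketch only: checks that train_image_size, the assumed 8x VAE downsampling,
# and patch_size are consistent with latent_shape [4, 32, 32].
image_size = 256            # data.train_image_size
vae_downsample = 8          # assumed SD-VAE factor (4 latent channels)
patch_size = 2              # model.denoiser.init_args.patch_size

latent_hw = image_size // vae_downsample        # 32, matching latent_shape
tokens = (latent_hw // patch_size) ** 2         # 16 * 16 = 256 tokens per image
print(latent_hw, tokens)                        # 32 256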

View File

@@ -7,3 +7,4 @@ torchvision
timm
accelerate
gradio
datasets
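
The new `datasets` dependency backs the `hf_image` training path configured above. A minimal sketch of the underlying fetch, using only standard `datasets` API rather than this repository's loader:

# Minimal sketch of pulling the configured Hugging Face dataset; standard
# `datasets` usage, not code from this repository.
from datasets import load_dataset

ds = load_dataset("huggan/smithsonian_butterflies_subset", split="train")
print(len(ds))          # number of training images
print(ds[0]["image"])   # a PIL image, per the `image` column used in the config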

View File

@@ -41,23 +41,25 @@ class AdamLMSampler(BaseSampler):
        self,
        order: int = 2,
        timeshift: float = 1.0,
        guidance_interval_min: float = 0.0,
        guidance_interval_max: float = 1.0,
        lms_transform_fn: Callable = nop,
        w_scheduler: BaseScheduler = None,
        step_fn: Callable = ode_step_fn,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.w_scheduler = w_scheduler
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        assert self.step_fn in [ode_step_fn, ]
        self.order = order
        self.lms_transform_fn = lms_transform_fn
        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.guidance_interval_min = guidance_interval_min
        self.guidance_interval_max = guidance_interval_max
        self.timesteps = shift_respace_fn(timesteps, timeshift)
        self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
        self._reparameterize_coeffs()
@@ -93,7 +95,11 @@ class AdamLMSampler(BaseSampler):
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            out = net(cfg_x, cfg_t, cfg_condition)
            out = self.guidance_fn(out, self.guidance)
            if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
                guidance = self.guidance
                out = self.guidance_fn(out, guidance)
            else:
                out = self.guidance_fn(out, 1.0)
            pred_trajectory.append(out)
            out = torch.zeros_like(out)
            order = len(self.solver_coeffs[i])
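
The change above gates classifier-free guidance by timestep: the configured guidance weight is applied only while the current timestep lies strictly inside (guidance_interval_min, guidance_interval_max); outside that window the weight falls back to 1.0, which reduces to the conditional prediction. A standalone sketch of the gating (the exact form of simple_guidance_fn is assumed to be the usual CFG combine):

# Standalone illustration of interval-gated guidance; the function name and CFG
# combine are assumptions for illustration, not the repo's guidance_fn.
def interval_guidance(uncond, cond, t, guidance=4.0, t_min=0.3, t_max=1.0):
    w = guidance if (t_min < t < t_max) else 1.0   # gate by current timestep
    return uncond + w * (cond - uncond)            # w = 1.0 yields the conditional output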

View File

@@ -0,0 +1,104 @@
import torch
import copy
import timm

from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1/sigma**2

def snr(alpha, sigma):
    return alpha/sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)

def constant(alpha, sigma):
    return 1

def time_shift_fn(t, timeshift=1.0):
    return t/(t+(1-t)*timeshift)


class DisperseTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            timeshift=1.0,
            align_layer=8,
            temperature=1.0,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.timeshift = timeshift
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.temperature = temperature

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size), device=x.device, dtype=torch.float32).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=torch.float32)
        t = time_shift_fn(base_t, self.timeshift).to(x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        src_feature = []
        def forward_hook(net, input, output):
            feature = output
            if isinstance(feature, tuple):
                feature = feature[0]  # mmdit
            src_feature.append(feature)

        if getattr(net, "encoder", None) is not None:
            handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
        else:
            handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)

        out = net(x_t, t, y)
        handle.remove()

        disperse_loss = 0.0
        world_size = torch.distributed.get_world_size()
        for sf in src_feature:
            gathered_sf = [torch.zeros_like(sf) for _ in range(world_size)]
            torch.distributed.all_gather(gathered_sf, sf)
            gathered_sf = torch.cat(gathered_sf, dim=0)
            sf = gathered_sf.view(batch_size * world_size, -1)
            sf = sf.view(batch_size * world_size, -1)
            # normalize sf
            sf = sf / torch.norm(sf, dim=1, keepdim=True)
            distance = torch.nn.functional.pdist(sf, p=2) ** 2
            sf_disperse_distance = torch.exp(-distance / self.temperature) + 1e-5
            disperse_loss += sf_disperse_distance.mean().log()

        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = weight*(out - v_t)**2

        out = dict(
            fm_loss=fm_loss.mean(),
            disperse_loss=disperse_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight*disperse_loss.mean(),
        )
        return out
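
The dispersive term above regularizes intermediate features from the aligned block: activations are gathered across ranks, flattened, L2-normalized, and the loss is the log of the mean Gaussian kernel over pairwise squared distances, so spreading features apart drives it more negative. A single-process toy version of the same computation (no all_gather):

# Toy, single-process version of the dispersive term computed in the trainer above.
import torch

def dispersive_loss(features, temperature=1.0, eps=1e-5):
    f = features.reshape(features.shape[0], -1)              # flatten per sample
    f = f / torch.norm(f, dim=1, keepdim=True)               # L2-normalize
    sq_dist = torch.nn.functional.pdist(f, p=2) ** 2         # pairwise squared distances
    return (torch.exp(-sq_dist / temperature) + eps).mean().log()

print(dispersive_loss(torch.randn(8, 16)))   # decreases as features disperse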

View File

@@ -41,22 +41,24 @@ class AdamLMSampler(BaseSampler):
        self,
        order: int = 2,
        timeshift: float = 1.0,
        guidance_interval_min: float = 0.0,
        guidance_interval_max: float = 1.0,
        state_refresh_rate: int = 1,
        lms_transform_fn: Callable = nop,
        w_scheduler: BaseScheduler = None,
        step_fn: Callable = ode_step_fn,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.w_scheduler = w_scheduler
        self.state_refresh_rate = state_refresh_rate
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        assert self.step_fn in [ode_step_fn, ]
        self.order = order
        self.lms_transform_fn = lms_transform_fn
        self.guidance_interval_min = guidance_interval_min
        self.guidance_interval_max = guidance_interval_max
        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
@@ -98,7 +100,11 @@ class AdamLMSampler(BaseSampler):
            if i % self.state_refresh_rate == 0:
                state = None
            out, state = net(cfg_x, cfg_t, cfg_condition, state)
            out = self.guidance_fn(out, self.guidance)
            if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
                guidance = self.guidance
                out = self.guidance_fn(out, guidance)
            else:
                out = self.guidance_fn(out, 1.0)
            pred_trajectory.append(out)
            out = torch.zeros_like(out)
            order = len(self.solver_coeffs[i])

View File

@@ -29,6 +29,11 @@ class DataModule(pl.LightningDataModule):
        var_transform_engine: VARTransformEngine = None,
        train_prefetch_factor=2,
        train_dataset: str = None,
        train_hf_name: str = None,
        train_hf_split: str = "train",
        train_hf_image_column: str = "image",
        train_hf_label_column: str = None,
        train_hf_cache_dir: str = None,
        eval_batch_size=32,
        eval_num_workers=4,
        eval_max_num_instances=50000,
@@ -45,6 +50,11 @@ class DataModule(pl.LightningDataModule):
        self.train_root = train_root
        self.train_image_size = train_image_size
        self.train_dataset = train_dataset
        self.train_hf_name = train_hf_name
        self.train_hf_split = train_hf_split
        self.train_hf_image_column = train_hf_image_column
        self.train_hf_label_column = train_hf_label_column
        self.train_hf_cache_dir = train_hf_cache_dir
        # stupid data_convert override, just to make nebular happy
        self.train_batch_size = train_batch_size
        self.train_num_workers = train_num_workers
@@ -101,6 +111,18 @@ class DataModule(pl.LightningDataModule):
            self.train_dataset = PixImageNet512(
                root=self.train_root,
            )
        elif self.train_dataset == "hf_image":
            from src.data.dataset.hf_image import HuggingFaceImageDataset
            if self.train_hf_name is None:
                raise ValueError("train_hf_name must be set when train_dataset=hf_image")
            self.train_dataset = HuggingFaceImageDataset(
                name=self.train_hf_name,
                split=self.train_hf_split,
                image_column=self.train_hf_image_column,
                label_column=self.train_hf_label_column,
                resolution=self.train_image_size,
                cache_dir=self.train_hf_cache_dir,
            )
        else:
            raise NotImplementedError("no such dataset")
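
The referenced src.data.dataset.hf_image module is not part of this compare. A rough sketch of a dataset matching the constructor signature used above; column handling, transforms, and the return format are guesses, not the actual implementation:

# Hypothetical sketch of HuggingFaceImageDataset; the real src/data/dataset/hf_image.py
# is not shown in this diff, so its details are assumptions.
from datasets import load_dataset
from torch.utils.data import Dataset
from torchvision import transforms

class HuggingFaceImageDataset(Dataset):
    def __init__(self, name, split="train", image_column="image",
                 label_column=None, resolution=256, cache_dir=None):
        self.ds = load_dataset(name, split=split, cache_dir=cache_dir)
        self.image_column = image_column
        self.label_column = label_column
        self.transform = transforms.Compose([
            transforms.Resize(resolution),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        row = self.ds[idx]
        image = self.transform(row[self.image_column].convert("RGB"))
        label = row[self.label_column] if self.label_column is not None else 0
        return image, label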

View File

@@ -34,6 +34,7 @@ class LightningModel(pl.LightningModule):
        ema_tracker: Optional[EMACallable] = None,
        optimizer: OptimizerCallable = None,
        lr_scheduler: LRSchedulerCallable = None,
        compile: bool = False
    ):
        super().__init__()
        self.vae = vae
@@ -48,6 +49,7 @@ class LightningModel(pl.LightningModule):
        # self.model_loader = ModelLoader()
        self._strict_loading = False
        self._compile = compile

    def configure_model(self) -> None:
        self.trainer.strategy.barrier()
@@ -61,6 +63,11 @@ class LightningModel(pl.LightningModule):
        no_grad(self.diffusion_sampler)
        no_grad(self.ema_denoiser)
        # add compile to speed up
        if self._compile:
            self.denoiser = torch.compile(self.denoiser)
            self.ema_denoiser = torch.compile(self.ema_denoiser)

    def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
        ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
        return [ema_tracker]
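
The new `compile` flag simply wraps the denoiser and its EMA copy with `torch.compile` once the model is materialized. A minimal, standalone illustration of the same call (requires PyTorch >= 2.0):

# Minimal illustration of what the compile flag enables: torch.compile returns an
# optimized wrapper whose first forward pass triggers tracing/compilation.
import torch
import torch.nn as nn

denoiser = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
compiled = torch.compile(denoiser)
out = compiled(torch.randn(4, 16))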

View File

@@ -251,11 +251,11 @@ class DiT(nn.Module):
        self.initialize_weights()
        self.precompute_pos = dict()

    def fetch_pos(self, height, width, device, dtype):
    def fetch_pos(self, height, width, device):
        if (height, width) in self.precompute_pos:
            return self.precompute_pos[(height, width)].to(device, dtype)
            return self.precompute_pos[(height, width)].to(device)
        else:
            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
            self.precompute_pos[(height, width)] = pos
            return pos
@@ -289,7 +289,7 @@ class DiT(nn.Module):
        B, _, H, W = x.shape
        x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        x = self.x_embedder(x)
        pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
        pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
        B, L, C = x.shape
        t = self.t_embedder(t.view(-1)).view(B, -1, C)
        y = self.y_embedder(y).view(B, 1, C)
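
fetch_pos now moves the cached positional frequencies only to the target device and no longer casts them to the activation dtype; the diff does not state the motivation, but the effect is that the cached values keep their original precision. A toy version of the per-resolution cache pattern (a placeholder tensor stands in for precompute_freqs_cis_2d):

# Toy per-resolution position cache mirroring fetch_pos; the placeholder tensor
# stands in for precompute_freqs_cis_2d, which is not reproduced here.
import torch

_pos_cache = {}

def fetch_pos(height, width, device):
    key = (height, width)
    if key in _pos_cache:
        return _pos_cache[key].to(device)
    pos = torch.randn(height * width, 8).to(device)   # placeholder frequencies
    _pos_cache[key] = pos
    return pos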

Binary file not shown.