submit code

wangshuai6 committed on 2025-04-09 11:01:16 +08:00
commit 06499f1caa (parent 4fbcf9bd87)
145 changed files with 14,400 additions and 0 deletions

.idea/.gitignore (generated, vendored): new file, +8 lines

@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

.idea/DDT.iml (generated): new file, +8 lines

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>


@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

.idea/modules.xml (generated): new file, +8 lines

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/DDT.iml" filepath="$PROJECT_DIR$/.idea/DDT.iml" />
</modules>
</component>
</project>

.idea/vcs.xml (generated): new file, +6 lines

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>


@@ -0,0 +1,108 @@
# lightning.pytorch==2.4.0
seed_everything: true
tags:
exp: &exp repa_flatten_condit22_dit6_fixt_xl
torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
huggingface_cache_dir: null
trainer:
default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: bf16-mixed
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: universal_flow
name: *exp
num_sanity_val_steps: 0
max_steps: 4000000
val_check_interval: 4000000
check_val_every_n_epoch: null
log_every_n_steps: 50
deterministic: null
inference_mode: true
use_distributed_sampler: false
callbacks:
- class_path: src.callbacks.model_checkpoint.CheckpointHook
init_args:
every_n_train_steps: 10000
save_top_k: -1
save_last: true
- class_path: src.callbacks.save_images.SaveImagesHook
init_args:
save_dir: val
plugins:
- src.plugins.bd_env.BDEnvironment
model:
vae:
class_path: src.models.vae.LatentVAE
init_args:
precompute: true
weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
denoiser:
class_path: src.models.denoiser.flatten_condit_dit_fixt.FlattenConDiT
init_args:
in_channels: 4
patch_size: 2
num_groups: 16
hidden_size: &hidden_dim 1152
num_blocks: 28
num_cond_blocks: 22
num_classes: 1000
conditioner:
class_path: src.models.conditioner.LabelConditioner
init_args:
null_class: 1000
diffusion_trainer:
class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
init_args:
lognorm_t: true
encoder_weight_path: dinov2_vitb14
align_layer: 8
proj_denoiser_dim: *hidden_dim
proj_hidden_dim: *hidden_dim
proj_encoder_dim: 768
scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
diffusion_sampler:
class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
init_args:
num_steps: 250
guidance: 2.0
timeshift: 1.5
state_refresh_rate: 1
guidance_interval_min: 0.3
guidance_interval_max: 1.0
scheduler: *scheduler
w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
last_step: 0.04
step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
ema_tracker:
class_path: src.callbacks.simple_ema.SimpleEMA
init_args:
decay: 0.9999
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 1e-4
betas:
- 0.9
- 0.95
weight_decay: 0.0
data:
train_dataset: imagenet256
train_root: /mnt/bn/wangshuai6/data/ImageNet/train
train_image_size: 256
train_batch_size: 16
eval_max_num_instances: 50000
pred_batch_size: 64
pred_num_workers: 4
pred_seeds: null
pred_selected_classes: null
num_classes: 1000
latent_shape:
- 4
- 32
- 32
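
A note on the timeshift values above (1.5 for the 256x256 config, 1.0 for the 512x512 one): the samplers later in this commit respace their Euler timesteps with shift_respace_fn. The short sketch below, which assumes only that respacing formula, shows that a shift greater than 1 pulls the grid toward t = 0, the noise end of the linear schedule, while shift = 1 leaves it uniform.

import torch

def shift_respace_fn(t, shift=3.0):
    # same respacing formula used by the EulerSampler in this commit
    return t / (t + (1 - t) * shift)

steps = torch.linspace(0.0, 1.0, 6)
print(shift_respace_fn(steps, shift=1.0))  # unchanged uniform grid
print(shift_respace_fn(steps, shift=1.5))  # values shrink: denser near t = 0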


@@ -0,0 +1,108 @@
# lightning.pytorch==2.4.0
seed_everything: true
tags:
exp: &exp res512_fromscratch_repa_flatten_condit22_dit6_fixt_xl
torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
huggingface_cache_dir: null
trainer:
default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: bf16-mixed
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: universal_flow
name: *exp
num_sanity_val_steps: 0
max_steps: 4000000
val_check_interval: 4000000
check_val_every_n_epoch: null
log_every_n_steps: 50
deterministic: null
inference_mode: true
use_distributed_sampler: false
callbacks:
- class_path: src.callbacks.model_checkpoint.CheckpointHook
init_args:
every_n_train_steps: 10000
save_top_k: -1
save_last: true
- class_path: src.callbacks.save_images.SaveImagesHook
init_args:
save_dir: val
plugins:
- src.plugins.bd_env.BDEnvironment
model:
vae:
class_path: src.models.vae.LatentVAE
init_args:
precompute: true
weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
denoiser:
class_path: src.models.denoiser.flatten_condit_dit_fixt.FlattenConDiT
init_args:
in_channels: 4
patch_size: 2
num_groups: 16
hidden_size: &hidden_dim 1152
num_blocks: 28
num_cond_blocks: 22
num_classes: 1000
conditioner:
class_path: src.models.conditioner.LabelConditioner
init_args:
null_class: 1000
diffusion_trainer:
class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
init_args:
lognorm_t: true
encoder_weight_path: dinov2_vitb14
align_layer: 8
proj_denoiser_dim: *hidden_dim
proj_hidden_dim: *hidden_dim
proj_encoder_dim: 768
scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
diffusion_sampler:
class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
init_args:
num_steps: 250
guidance: 3.0
state_refresh_rate: 1
guidance_interval_min: 0.3
guidance_interval_max: 1.0
timeshift: 1.0
last_step: 0.04
scheduler: *scheduler
w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
ema_tracker:
class_path: src.callbacks.simple_ema.SimpleEMA
init_args:
decay: 0.9999
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 1e-4
betas:
- 0.9
- 0.95
weight_decay: 0.0
data:
train_dataset: imagenet512
train_root: /mnt/bn/wangshuai6/data/ImageNet/train
train_image_size: 512
train_batch_size: 16
eval_max_num_instances: 50000
pred_batch_size: 32
pred_num_workers: 4
pred_seeds: null
pred_selected_classes: null
num_classes: 1000
latent_shape:
- 4
- 64
- 64


@@ -0,0 +1,99 @@
# lightning.pytorch==2.4.0
seed_everything: true
tags:
exp: &exp repa_flatten_dit_fixt_large
torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
huggingface_cache_dir: null
trainer:
default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: bf16-mixed
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: universal_flow
name: *exp
num_sanity_val_steps: 0
max_steps: 400000
val_check_interval: 100000
check_val_every_n_epoch: null
log_every_n_steps: 50
deterministic: null
inference_mode: true
use_distributed_sampler: false
callbacks:
- class_path: src.callbacks.model_checkpoint.CheckpointHook
init_args:
every_n_train_steps: 10000
save_top_k: -1
save_last: true
- class_path: src.callbacks.save_images.SaveImagesHook
init_args:
save_dir: val
plugins:
- src.plugins.bd_env.BDEnvironment
model:
vae:
class_path: src.models.vae.LatentVAE
init_args:
precompute: true
weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
denoiser:
class_path: src.models.denoiser.flatten_dit_fixt.FlattenDiT
init_args:
in_channels: 4
patch_size: 2
num_groups: 16
hidden_size: &hidden_dim 1024
num_blocks: 24
num_classes: 1000
conditioner:
class_path: src.models.conditioner.LabelConditioner
init_args:
null_class: 1000
diffusion_trainer:
class_path: src.diffusion.flow_matching.training_repa.REPATrainer
init_args:
lognorm_t: true
encoder_weight_path: dinov2_vitb14
align_layer: 8
proj_denoiser_dim: *hidden_dim
proj_hidden_dim: *hidden_dim
proj_encoder_dim: 768
scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler
diffusion_sampler:
class_path: src.diffusion.flow_matching.sampling.EulerSampler
init_args:
num_steps: 250
guidance: 1.00
scheduler: *scheduler
w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler
guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn
ema_tracker:
class_path: src.callbacks.simple_ema.SimpleEMA
init_args:
decay: 0.9999
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 1e-4
weight_decay: 0.0
data:
train_dataset: imagenet256
train_root: /mnt/bn/wangshuai6/data/ImageNet/train
train_image_size: 256
train_batch_size: 32
eval_max_num_instances: 50000
pred_batch_size: 64
pred_num_workers: 4
pred_seeds: null
pred_selected_classes: null
num_classes: 1000
latent_shape:
- 4
- 32
- 32


@@ -0,0 +1,99 @@
# lightning.pytorch==2.4.0
seed_everything: true
tags:
exp: &exp repa_flatten_dit_fixt_xl
torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
huggingface_cache_dir: null
trainer:
default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: bf16-mixed
logger:
class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: universal_flow
name: *exp
num_sanity_val_steps: 0
max_steps: 400000
val_check_interval: 100000
check_val_every_n_epoch: null
log_every_n_steps: 50
deterministic: null
inference_mode: true
use_distributed_sampler: false
callbacks:
- class_path: src.callbacks.model_checkpoint.CheckpointHook
init_args:
every_n_train_steps: 10000
save_top_k: -1
save_last: true
- class_path: src.callbacks.save_images.SaveImagesHook
init_args:
save_dir: val
plugins:
- src.plugins.bd_env.BDEnvironment
model:
vae:
class_path: src.models.vae.LatentVAE
init_args:
precompute: true
weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
denoiser:
class_path: src.models.denoiser.flatten_dit_fixt.FlattenDiT
init_args:
in_channels: 4
patch_size: 2
num_groups: 16
hidden_size: &hidden_dim 1152
num_blocks: 28
num_classes: 1000
conditioner:
class_path: src.models.conditioner.LabelConditioner
init_args:
null_class: 1000
diffusion_trainer:
class_path: src.diffusion.flow_matching.training_repa.REPATrainer
init_args:
lognorm_t: true
encoder_weight_path: dinov2_vitb14
align_layer: 8
proj_denoiser_dim: *hidden_dim
proj_hidden_dim: *hidden_dim
proj_encoder_dim: 768
scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler
diffusion_sampler:
class_path: src.diffusion.flow_matching.sampling.EulerSampler
init_args:
num_steps: 250
guidance: 1.00
scheduler: *scheduler
w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler
guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn
ema_tracker:
class_path: src.callbacks.simple_ema.SimpleEMA
init_args:
decay: 0.9999
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 1e-4
weight_decay: 0.0
data:
train_dataset: imagenet256
train_root: /mnt/bn/wangshuai6/data/ImageNet/train
train_image_size: 256
train_batch_size: 32
eval_max_num_instances: 50000
pred_batch_size: 64
pred_num_workers: 4
pred_seeds: null
pred_selected_classes: null
num_classes: 1000
latent_shape:
- 4
- 32
- 32

main.py: new file, +86 lines

@@ -0,0 +1,86 @@
import time
from typing import Any, Union
import pylab as pl
from src.utils.patch_bugs import *
import os
import torch
from lightning import Trainer, LightningModule
from src.lightning_data import DataModule
from src.lightning_model import LightningModel
from lightning.pytorch.cli import LightningCLI, LightningArgumentParser, SaveConfigCallback
import logging
logger = logging.getLogger("lightning.pytorch")
# log_path = os.path.join( f"log.txt")
# logger.addHandler(logging.FileHandler(log_path))
class ReWriteRootSaveConfigCallback(SaveConfigCallback):
def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
stamp = time.strftime('%y%m%d%H%M')
file_path = os.path.join(trainer.default_root_dir, f"config-{stage}-{stamp}.yaml")
self.parser.save(
self.config, file_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
)
class ReWriteRootDirCli(LightningCLI):
def before_instantiate_classes(self) -> None:
super().before_instantiate_classes()
config_trainer = self._get(self.config, "trainer", default={})
# predict path & logger check
if self.subcommand == "predict":
config_trainer.logger = None
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
class TagsClass:
def __init__(self, exp:str):
...
parser.add_class_arguments(TagsClass, nested_key="tags")
def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
super().add_default_arguments_to_parser(parser)
parser.add_argument("--torch_hub_dir", type=str, default=None, help=("torch hub dir"),)
parser.add_argument("--huggingface_cache_dir", type=str, default=None, help=("huggingface hub dir"),)
def instantiate_trainer(self, **kwargs: Any) -> Trainer:
config_trainer = self._get(self.config_init, "trainer", default={})
default_root_dir = config_trainer.get("default_root_dir", None)
if default_root_dir is None:
default_root_dir = os.path.join(os.getcwd(), "workdirs")
dirname = ""
for k, v in self._get(self.config, "tags", default={}).items():
dirname += f"{k}_{v}"
default_root_dir = os.path.join(default_root_dir, dirname)
is_resume = self._get(self.config_init, "ckpt_path", default=None)
if os.path.exists(default_root_dir) and "debug" not in default_root_dir:
if os.listdir(default_root_dir) and self.subcommand != "predict" and not is_resume:
raise FileExistsError(f"{default_root_dir} already exists")
config_trainer.default_root_dir = default_root_dir
trainer = super().instantiate_trainer(**kwargs)
if trainer.is_global_zero:
os.makedirs(default_root_dir, exist_ok=True)
return trainer
def instantiate_classes(self) -> None:
torch_hub_dir = self._get(self.config, "torch_hub_dir")
huggingface_cache_dir = self._get(self.config, "huggingface_cache_dir")
if huggingface_cache_dir is not None:
os.environ["HUGGINGFACE_HUB_CACHE"] = huggingface_cache_dir
if torch_hub_dir is not None:
os.environ["TORCH_HOME"] = torch_hub_dir
torch.hub.set_dir(torch_hub_dir)
super().instantiate_classes()
if __name__ == "__main__":
cli = ReWriteRootDirCli(LightningModel, DataModule,
auto_configure_optimizers=False,
save_config_callback=ReWriteRootSaveConfigCallback,
save_config_kwargs={"overwrite": True})
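
A minimal sketch of how instantiate_trainer derives the working directory: the tags section of the config is flattened into a directory name and appended to default_root_dir, so every experiment gets its own folder. The values below are illustrative, taken from the first config in this commit, not computed by the repo itself.

import os

tags = {"exp": "repa_flatten_condit22_dit6_fixt_xl"}  # the config's `tags` block
default_root_dir = "/mnt/bn/wangshuai6/universal_flow_workdirs"

dirname = "".join(f"{k}_{v}" for k, v in tags.items())
workdir = os.path.join(default_root_dir, dirname)
print(workdir)  # .../universal_flow_workdirs/exp_repa_flatten_condit22_dit6_fixt_xl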

main.sh: new file, +10 lines

@@ -0,0 +1,10 @@
#!/bin/bash
export NCCL_HOSTID=${MY_POD_NAME}
export MASTER_ADDR=${ARNOLD_WORKER_0_HOST}
export MASTER_PORT=${ARNOLD_WORKER_0_PORT}
export NODE_RANK=${ARNOLD_ID}
export NUM_NODES=${ARNOLD_WORKER_NUM}
python3 main.py fit -c $1 --trainer.num_nodes $NUM_NODES
# for pid in $(ps -ef | grep "yaml" | grep -v "grep" | awk '{print $2}'); do kill -9 $pid; done

requirements.txt: new file, +3 lines

@@ -0,0 +1,3 @@
lightning==2.5.0.post0
omegaconf==2.3.0
jsonargparse[signatures]>=4.27.7

src/__init__.py: new file, empty


src/callbacks/grad.py: new file, +22 lines

@@ -0,0 +1,22 @@
import torch
import lightning.pytorch as pl
from lightning.pytorch.utilities import grad_norm
from torch.optim import Optimizer
class GradientMonitor(pl.Callback):
"""Logs the gradient norm"""
def __init__(self, norm_type: int = 2):
super().__init__()
norm_type = float(norm_type)
if norm_type <= 0:
raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
self.norm_type = norm_type
def on_before_optimizer_step(
self, trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
optimizer: Optimizer
) -> None:
norms = grad_norm(pl_module, norm_type=self.norm_type)
max_grad = torch.tensor([v for k, v in norms.items() if k != f"grad_{self.norm_type}_norm_total"]).max()
pl_module.log_dict({'train/grad/max': max_grad, 'train/grad/total': norms[f"grad_{self.norm_type}_norm_total"]})


@@ -0,0 +1,25 @@
import os.path
from typing import Optional, Dict, Any
import lightning.pytorch as pl
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
class CheckpointHook(ModelCheckpoint):
"""Save checkpoint with only the incremental part of the model"""
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
self.dirpath = trainer.default_root_dir
self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt")
pl_module.strict_loading = False
def on_save_checkpoint(
self, trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
checkpoint: Dict[str, Any]
) -> None:
del checkpoint["callbacks"]
# def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
# if not "debug" in self.exception_ckpt_path:
# trainer.save_checkpoint(self.exception_ckpt_path)


@@ -0,0 +1,105 @@
import lightning.pytorch as pl
from lightning.pytorch import Callback
import os.path
import numpy
from PIL import Image
from typing import Sequence, Any, Dict
from concurrent.futures import ThreadPoolExecutor
from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning_utilities.core.rank_zero import rank_zero_info
def process_fn(image, path):
Image.fromarray(image).save(path)
class SaveImagesHook(Callback):
def __init__(self, save_dir="val", max_save_num=0, compressed=True):
self.save_dir = save_dir
self.max_save_num = max_save_num
self.compressed = compressed
def save_start(self, target_dir):
self.target_dir = target_dir
self.executor_pool = ThreadPoolExecutor(max_workers=8)
if not os.path.exists(self.target_dir):
os.makedirs(self.target_dir, exist_ok=True)
else:
if os.listdir(target_dir) and "debug" not in str(target_dir):
raise FileExistsError(f'{self.target_dir} already exists and not empty!')
self.samples = []
self._have_saved_num = 0
rank_zero_info(f"Save images to {self.target_dir}")
def save_image(self, images, filenames):
images = images.permute(0, 2, 3, 1).cpu().numpy()
for sample, filename in zip(images, filenames):
if isinstance(filename, Sequence):
filename = filename[0]
path = f'{self.target_dir}/{filename}'
if self._have_saved_num >= self.max_save_num:
break
self.executor_pool.submit(process_fn, sample, path)
self._have_saved_num += 1
def process_batch(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
samples: STEP_OUTPUT,
batch: Any,
) -> None:
b, c, h, w = samples.shape
xT, y, metadata = batch
all_samples = pl_module.all_gather(samples).view(-1, c, h, w)
self.save_image(samples, metadata)
if trainer.is_global_zero:
all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy()
self.samples.append(all_samples)
def save_end(self):
if self.compressed and len(self.samples) > 0:
samples = numpy.concatenate(self.samples)
numpy.savez(f'{self.target_dir}/output.npz', arr_0=samples)
self.executor_pool.shutdown(wait=True)
self.samples = []
self.target_dir = None
self._have_saved_num = 0
self.executor_pool = None
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
self.save_start(target_dir)
def on_validation_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
return self.process_batch(trainer, pl_module, outputs, batch)
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.save_end()
def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict")
self.save_start(target_dir)
def on_predict_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
samples: Any,
batch: Any,
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
return self.process_batch(trainer, pl_module, samples, batch)
def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.save_end()


@@ -0,0 +1,79 @@
from typing import Any, Dict
import torch
import torch.nn as nn
import threading
import lightning.pytorch as pl
from lightning.pytorch import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT
from src.utils.copy import swap_tensors
class SimpleEMA(Callback):
def __init__(self, net:nn.Module, ema_net:nn.Module,
decay: float = 0.9999,
every_n_steps: int = 1,
eval_original_model:bool = False
):
super().__init__()
self.decay = decay
self.every_n_steps = every_n_steps
self.eval_original_model = eval_original_model
self._stream = torch.cuda.Stream()
self.net_params = list(net.parameters())
self.ema_params = list(ema_net.parameters())
def swap_model(self):
for ema_p, p in zip(self.ema_params, self.net_params):
swap_tensors(ema_p, p)
def ema_step(self):
@torch.no_grad()
def ema_update(ema_model_tuple, current_model_tuple, decay):
torch._foreach_mul_(ema_model_tuple, decay)
torch._foreach_add_(
ema_model_tuple, current_model_tuple, alpha=(1.0 - decay),
)
if self._stream is not None:
self._stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._stream):
ema_update(self.ema_params, self.net_params, self.decay)
def on_train_batch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None:
if trainer.global_step % self.every_n_steps == 0:
self.ema_step()
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not self.eval_original_model:
self.swap_model()
def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not self.eval_original_model:
self.swap_model()
def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not self.eval_original_model:
self.swap_model()
def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not self.eval_original_model:
self.swap_model()
def state_dict(self) -> Dict[str, Any]:
return {
"decay": self.decay,
"every_n_steps": self.every_n_steps,
"eval_original_model": self.eval_original_model,
}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.decay = state_dict["decay"]
self.every_n_steps = state_dict["every_n_steps"]
self.eval_original_model = state_dict["eval_original_model"]
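
The EMA update above is a standard exponential moving average written with the fused torch._foreach ops on a side CUDA stream. The sketch below shows the equivalent per-parameter form on toy modules; the module and variable names are illustrative, not from the repo.

import copy
import torch
import torch.nn as nn

net = nn.Linear(4, 4)
ema_net = copy.deepcopy(net)
decay = 0.9999

with torch.no_grad():
    for ema_p, p in zip(ema_net.parameters(), net.parameters()):
        # ema <- decay * ema + (1 - decay) * current
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)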

src/data/__init__.py: new file, +1 line

@@ -0,0 +1 @@


@@ -0,0 +1,11 @@
from typing import Callable
from torchvision.datasets import CelebA
class LocalDataset(CelebA):
def __init__(self, root:str, ):
super(LocalDataset, self).__init__(root, "train")
def __getitem__(self, idx):
data = super().__getitem__(idx)
return data


@@ -0,0 +1,82 @@
import torch
from PIL import Image
from torchvision.datasets import ImageFolder
from torchvision.transforms.functional import to_tensor
from torchvision.transforms import Normalize
from src.data.dataset.metric_dataset import CenterCrop
class LocalCachedDataset(ImageFolder):
def __init__(self, root, resolution=256):
super().__init__(root)
self.transform = CenterCrop(resolution)
self.cache_root = None
def load_latent(self, latent_path):
pk_data = torch.load(latent_path)
mean = pk_data['mean'].to(torch.float32)
logvar = pk_data['logvar'].to(torch.float32)
logvar = torch.clamp(logvar, -30.0, 20.0)
std = torch.exp(0.5 * logvar)
latent = mean + torch.randn_like(mean) * std
return latent
def __getitem__(self, idx: int):
image_path, target = self.samples[idx]
latent_path = image_path.replace(self.root, self.cache_root) + ".pt"
raw_image = Image.open(image_path).convert('RGB')
raw_image = self.transform(raw_image)
raw_image = to_tensor(raw_image)
if self.cache_root is not None:
latent = self.load_latent(latent_path)
else:
latent = raw_image
return raw_image, latent, target
class ImageNet256(LocalCachedDataset):
def __init__(self, root, ):
super().__init__(root, 256)
self.cache_root = root + "_256_latent"
class ImageNet512(LocalCachedDataset):
def __init__(self, root, ):
super().__init__(root, 512)
self.cache_root = root + "_512_latent"
class PixImageNet(ImageFolder):
def __init__(self, root, resolution=256):
super().__init__(root)
self.transform = CenterCrop(resolution)
self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
def __getitem__(self, idx: int):
image_path, target = self.samples[idx]
raw_image = Image.open(image_path).convert('RGB')
raw_image = self.transform(raw_image)
raw_image = to_tensor(raw_image)
normalized_image = self.normalize(raw_image)
return raw_image, normalized_image, target
class PixImageNet64(PixImageNet):
def __init__(self, root, ):
super().__init__(root, 64)
class PixImageNet128(PixImageNet):
def __init__(self, root, ):
super().__init__(root, 128)
class PixImageNet256(PixImageNet):
def __init__(self, root, ):
super().__init__(root, 256)
class PixImageNet512(PixImageNet):
def __init__(self, root, ):
super().__init__(root, 512)
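
load_latent draws a latent from the cached VAE posterior with the usual reparameterization trick: clamp logvar, turn it into a standard deviation, then add scaled Gaussian noise to the mean. A standalone sketch with a dummy in-memory record standing in for a real cached .pt file:

import torch

pk_data = {"mean": torch.zeros(4, 32, 32), "logvar": torch.zeros(4, 32, 32)}  # dummy stand-in
mean = pk_data["mean"].to(torch.float32)
logvar = torch.clamp(pk_data["logvar"].to(torch.float32), -30.0, 20.0)
std = torch.exp(0.5 * logvar)
latent = mean + torch.randn_like(mean) * std  # one sample from N(mean, std^2)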


@@ -0,0 +1,82 @@
import pathlib
import random
import numpy as np
import torch
import torchvision.transforms as tvtf
from PIL import Image
from torch.utils.data import Dataset
from torchvision.io.image import read_image
class CenterCrop:
def __init__(self, size):
self.size = size
def __call__(self, image):
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
return center_crop_arr(image, self.size)
IMG_EXTENSIONS = (
"*.png",
"*.JPEG",
"*.jpeg",
"*.jpg"
)
def test_collate(batch):
return torch.stack(batch)
class ImageDataset(Dataset):
def __init__(self, root, image_size=(224, 224)):
self.root = pathlib.Path(root)
images = []
for ext in IMG_EXTENSIONS:
images.extend(self.root.rglob(ext))
random.shuffle(images)
self.images = list(map(lambda x: str(x), images))
self.transform = tvtf.Compose(
[
CenterCrop(image_size[0]),
tvtf.ToTensor(),
tvtf.Lambda(lambda x: (x*255).to(torch.uint8)),
tvtf.Lambda(lambda x: x.expand(3, -1, -1))
]
)
self.size = image_size
def __getitem__(self, idx):
try:
image = Image.open(self.images[idx])
image = self.transform(image)
except Exception as e:
print(self.images[idx])
image = torch.zeros(3, self.size[0], self.size[1], dtype=torch.uint8)
# print(image)
metadata = dict(
path = self.images[idx],
root = self.root,
)
return image #, metadata
def __len__(self):
return len(self.images)

src/data/dataset/randn.py: new file, +41 lines

@@ -0,0 +1,41 @@
import os.path
import random
import torch
from torch.utils.data import Dataset
class RandomNDataset(Dataset):
def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, selected_classes:list=None, seeds=None, max_num_instances=50000, ):
self.selected_classes = selected_classes
if selected_classes is not None:
num_classes = len(selected_classes)
max_num_instances = 10*num_classes
self.num_classes = num_classes
self.seeds = seeds
if seeds is not None:
self.max_num_instances = len(seeds)*num_classes
self.num_seeds = len(seeds)
else:
self.num_seeds = (max_num_instances + num_classes - 1) // num_classes
self.max_num_instances = self.num_seeds*num_classes
self.latent_shape = latent_shape
def __getitem__(self, idx):
label = idx // self.num_seeds
if self.selected_classes:
label = self.selected_classes[label]
seed = random.randint(0, 1<<31) #idx % self.num_seeds
if self.seeds is not None:
seed = self.seeds[idx % self.num_seeds]
# cls_dir = os.path.join(self.root, f"{label}")
filename = f"{label}_{seed}.png",
generator = torch.Generator().manual_seed(seed)
latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)
return latent, label, filename
def __len__(self):
return self.max_num_instances
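
RandomNDataset pairs each (class, seed) with a reproducible latent: when explicit seeds are supplied, the seed feeds a torch.Generator, so the same index always yields the same noise tensor. A short check of that property, independent of the dataset class:

import torch

def make_latent(seed, shape=(4, 32, 32)):
    generator = torch.Generator().manual_seed(seed)
    return torch.randn(shape, generator=generator, dtype=torch.float32)

a = make_latent(123)
b = make_latent(123)
print(torch.equal(a, b))  # True: identical seed, identical latent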

src/data/var_training.py: new file, +145 lines

@@ -0,0 +1,145 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor
from typing import List
from PIL import Image
import torch
import random
import numpy as np
import copy
import torchvision.transforms.functional as tvtf
from src.models.vae import uint82fp
def center_crop_arr(pil_image, width, height):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while pil_image.size[0] >= 2 * width and pil_image.size[1] >= 2 * height:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
scale = max(width / pil_image.size[0], height / pil_image.size[1])
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y = random.randint(0, (arr.shape[0] - height))
crop_x = random.randint(0, (arr.shape[1] - width))
return Image.fromarray(arr[crop_y: crop_y + height, crop_x: crop_x + width])
def process_fn(width, height, data, hflip=0.5):
image, label = data
if random.uniform(0, 1) > hflip: # hflip
image = tvtf.hflip(image)
image = center_crop_arr(image, width, height) # crop
image = np.array(image).transpose(2, 0, 1)
return image, label
class VARCandidate:
def __init__(self, aspect_ratio, width, height, buffer, max_buffer_size=1024):
self.aspect_ratio = aspect_ratio
self.width = int(width)
self.height = int(height)
self.buffer = buffer
self.max_buffer_size = max_buffer_size
def add_sample(self, data):
self.buffer.append(data)
self.buffer = self.buffer[-self.max_buffer_size:]
def ready(self, batch_size):
return len(self.buffer) >= batch_size
def get_batch(self, batch_size):
batch = self.buffer[:batch_size]
self.buffer = self.buffer[batch_size:]
batch = [copy.deepcopy(b.result()) for b in batch]
x, y = zip(*batch)
x = torch.stack([torch.from_numpy(im).cuda() for im in x], dim=0)
x = list(map(uint82fp, x))
return x, y
class VARTransformEngine:
def __init__(self,
base_image_size,
num_aspect_ratios,
min_aspect_ratio,
max_aspect_ratio,
num_workers = 8,
):
self.base_image_size = base_image_size
self.num_aspect_ratios = num_aspect_ratios
self.min_aspect_ratio = min_aspect_ratio
self.max_aspect_ratio = max_aspect_ratio
self.aspect_ratios = np.linspace(self.min_aspect_ratio, self.max_aspect_ratio, self.num_aspect_ratios)
self.aspect_ratios = self.aspect_ratios.tolist()
self.candidates_pool = []
for i in range(self.num_aspect_ratios):
candidate = VARCandidate(
aspect_ratio=self.aspect_ratios[i],
width=int(self.base_image_size * self.aspect_ratios[i] ** 0.5 // 16 * 16),
height=int(self.base_image_size * self.aspect_ratios[i] ** -0.5 // 16 * 16),
buffer=[],
max_buffer_size=1024
)
self.candidates_pool.append(candidate)
self.default_candidate = VARCandidate(
aspect_ratio=1.0,
width=self.base_image_size,
height=self.base_image_size,
buffer=[],
max_buffer_size=1024,
)
self.executor_pool = ProcessPoolExecutor(max_workers=num_workers)
self._prefill_count = 100
def find_candidate(self, data):
image = data[0]
aspect_ratio = image.size[0] / image.size[1]
min_distance = 1000000
min_candidate = None
for candidate in self.candidates_pool:
dis = abs(aspect_ratio - candidate.aspect_ratio)
if dis < min_distance:
min_distance = dis
min_candidate = candidate
return min_candidate
def __call__(self, batch_data):
self._prefill_count -= 1
if isinstance(batch_data[0], torch.Tensor):
batch_data[0] = batch_data[0].unbind(0)
batch_data = list(zip(*batch_data))
for data in batch_data:
candidate = self.find_candidate(data)
future = self.executor_pool.submit(process_fn, candidate.width, candidate.height, data)
candidate.add_sample(future)
if self._prefill_count >= 0:
future = self.executor_pool.submit(process_fn,
self.default_candidate.width,
self.default_candidate.height,
data)
self.default_candidate.add_sample(future)
batch_size = len(batch_data)
random.shuffle(self.candidates_pool)
for candidate in self.candidates_pool:
if candidate.ready(batch_size=batch_size):
return candidate.get_batch(batch_size=batch_size)
# fallback to default 256
for data in batch_data:
future = self.executor_pool.submit(process_fn,
self.default_candidate.width,
self.default_candidate.height,
data)
self.default_candidate.add_sample(future)
return self.default_candidate.get_batch(batch_size=batch_size)
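
VARTransformEngine buckets images by aspect ratio; each bucket's width and height come from the base image size scaled by the square root of the ratio and rounded down to a multiple of 16, exactly as in the constructor above. A quick look at the resulting resolutions for a 256 base size, using example min/max/num values that are not taken from any config in this commit:

import numpy as np

base = 256
ratios = np.linspace(0.5, 2.0, 5)  # illustrative aspect-ratio range
for r in ratios:
    w = int(base * r ** 0.5 // 16 * 16)
    h = int(base * r ** -0.5 // 16 * 16)
    print(f"ratio={r:.2f}: {w}x{h}")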


@@ -0,0 +1,60 @@
import torch
def simple_guidance_fn(out, cfg):
uncondition, condition = out.chunk(2, dim=0)
out = uncondition + cfg * (condition - uncondition)
return out
def c3_guidance_fn(out, cfg):
# guidance function in DiT/SiT, seems like a bug not a feature?
uncondition, condition = out.chunk(2, dim=0)
out = condition
out[:, :3] = uncondition[:, :3] + cfg * (condition[:, :3] - uncondition[:, :3])
return out
def c4_guidance_fn(out, cfg):
# guidance function in DiT/SiT, seems like a bug not a feature?
uncondition, condition = out.chunk(2, dim=0)
out = condition
out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:])
return out
def c4_p05_guidance_fn(out, cfg):
# guidance function in DiT/SiT, seems like a bug not a feature?
uncondition, condition = out.chunk(2, dim=0)
out = condition
out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:])
return out
def c4_p10_guidance_fn(out, cfg):
# guidance function in DiT/SiT, seems like a bug not a feature?
uncondition, condition = out.chunk(2, dim=0)
out = condition
out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
out[:, 4:] = uncondition[:, 4:] + 1.10 * (condition[:, 4:] - uncondition[:, 4:])
return out
def c4_p15_guidance_fn(out, cfg):
# guidance function in DiT/SiT, seems like a bug not a feature?
uncondition, condition = out.chunk(2, dim=0)
out = condition
out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
out[:, 4:] = uncondition[:, 4:] + 1.15 * (condition[:, 4:] - uncondition[:, 4:])
return out
def c4_p20_guidance_fn(out, cfg):
# guidance function in DiT/SiT, seems like a bug not a feature?
uncondition, condition = out.chunk(2, dim=0)
out = condition
out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
out[:, 4:] = uncondition[:, 4:] + 1.20 * (condition[:, 4:] - uncondition[:, 4:])
return out
def p4_guidance_fn(out, cfg):
# guidance function in DiT/SiT, seems like a bug not a feature?
uncondition, condition = out.chunk(2, dim=0)
out = condition
out[:, 4:] = uncondition[:, 4:] + cfg * (condition[:, 4:] - uncondition[:, 4:])
return out
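
simple_guidance_fn is plain classifier-free guidance: the batch is laid out as [unconditional; conditional] along dim 0, and the result extrapolates from the unconditional prediction toward the conditional one by the guidance scale. A toy check with constant tensors (values chosen only to make the arithmetic obvious):

import torch

def simple_guidance_fn(out, cfg):
    uncondition, condition = out.chunk(2, dim=0)
    return uncondition + cfg * (condition - uncondition)

uncond = torch.zeros(2, 4, 8, 8)
cond = torch.ones(2, 4, 8, 8)
out = torch.cat([uncond, cond], dim=0)
print(simple_guidance_fn(out, cfg=2.0).mean())  # tensor(2.): pushed past the conditional branch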


@@ -0,0 +1,31 @@
from typing import Union, List
import torch
import torch.nn as nn
from typing import Callable
from src.diffusion.base.scheduling import BaseScheduler
class BaseSampler(nn.Module):
def __init__(self,
scheduler: BaseScheduler = None,
guidance_fn: Callable = None,
num_steps: int = 250,
guidance: Union[float, List[float]] = 1.0,
*args,
**kwargs
):
super(BaseSampler, self).__init__()
self.num_steps = num_steps
self.guidance = guidance
self.guidance_fn = guidance_fn
self.scheduler = scheduler
def _impl_sampling(self, net, noise, condition, uncondition):
raise NotImplementedError
def __call__(self, net, noise, condition, uncondition):
denoised = self._impl_sampling(net, noise, condition, uncondition)
return denoised


@@ -0,0 +1,32 @@
import torch
from torch import Tensor
class BaseScheduler:
def alpha(self, t) -> Tensor:
...
def sigma(self, t) -> Tensor:
...
def dalpha(self, t) -> Tensor:
...
def dsigma(self, t) -> Tensor:
...
def dalpha_over_alpha(self, t) -> Tensor:
return self.dalpha(t) / self.alpha(t)
def dsigma_mul_sigma(self, t) -> Tensor:
return self.dsigma(t)*self.sigma(t)
def drift_coefficient(self, t):
alpha, sigma = self.alpha(t), self.sigma(t)
dalpha, dsigma = self.dalpha(t), self.dsigma(t)
return dalpha/(alpha + 1e-6)
def diffuse_coefficient(self, t):
alpha, sigma = self.alpha(t), self.sigma(t)
dalpha, dsigma = self.dalpha(t), self.dsigma(t)
return dsigma*sigma - dalpha/(alpha + 1e-6)*sigma**2
def w(self, t):
return self.sigma(t)
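
For the linear flow-matching schedule used throughout the configs (alpha(t) = t, sigma(t) = 1 - t, so dalpha = 1 and dsigma = -1), the coefficients above reduce to drift = 1 / (t + 1e-6) and diffuse = -(1 - t) - (1 - t)^2 / (t + 1e-6). The helper below is a hypothetical standalone sketch that just mirrors those formulas numerically:

import torch

def linear_coeffs(t: torch.Tensor):
    alpha, sigma = t, 1 - t
    dalpha, dsigma = torch.ones_like(t), -torch.ones_like(t)
    drift = dalpha / (alpha + 1e-6)
    diffuse = dsigma * sigma - dalpha / (alpha + 1e-6) * sigma ** 2
    return drift, diffuse

print(linear_coeffs(torch.tensor([0.25, 0.5, 0.75])))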


@@ -0,0 +1,29 @@
import time
import torch
import torch.nn as nn
class BaseTrainer(nn.Module):
def __init__(self,
null_condition_p=0.1,
log_var=False,
):
super(BaseTrainer, self).__init__()
self.null_condition_p = null_condition_p
self.log_var = log_var
def preprocess(self, raw_images, x, condition, uncondition):
bsz = x.shape[0]
if self.null_condition_p > 0:
mask = torch.rand((bsz), device=condition.device) < self.null_condition_p
mask = mask.expand_as(condition)
condition[mask] = uncondition[mask]
return raw_images, x, condition
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
raise NotImplementedError
def __call__(self, net, ema_net, raw_images, x, condition, uncondition):
raw_images, x, condition = self.preprocess(raw_images, x, condition, uncondition)
return self._impl_trainstep(net, ema_net, raw_images, x, condition)
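
The preprocessing step implements classifier-free guidance dropout: with probability null_condition_p each sample's label is replaced by the null (unconditional) label, so one network learns both branches used by the guidance functions at sampling time. A toy illustration on integer class labels, with the null class set to 1000 as in LabelConditioner:

import torch

null_condition_p = 0.1
condition = torch.randint(0, 1000, (8,))        # class labels
uncondition = torch.full_like(condition, 1000)  # null class

mask = torch.rand(8) < null_condition_p
condition[mask] = uncondition[mask]
print(condition)  # roughly 10% of the labels replaced by 1000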


@@ -0,0 +1,40 @@
import torch
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *
from typing import Callable
import logging
logger = logging.getLogger(__name__)
class DDIMSampler(BaseSampler):
def __init__(
self,
train_num_steps=1000,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.train_num_steps = train_num_steps
assert self.scheduler is not None
def _impl_sampling(self, net, noise, condition, uncondition):
batch_size = noise.shape[0]
steps = torch.linspace(0.0, self.train_num_steps-1, self.num_steps, device=noise.device)
steps = torch.flip(steps, dims=[0])
cfg_condition = torch.cat([uncondition, condition], dim=0)
x = x0 = noise
for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
t_cur = t_cur.repeat(batch_size)
t_next = t_next.repeat(batch_size)
sigma = self.scheduler.sigma(t_cur)
alpha = self.scheduler.alpha(t_cur)
sigma_next = self.scheduler.sigma(t_next)
alpha_next = self.scheduler.alpha(t_next)
cfg_x = torch.cat([x, x], dim=0)
t = t_cur.repeat(2)
out = net(cfg_x, t, cfg_condition)
out = self.guidance_fn(out, self.guidance)
x0 = (x - sigma * out) / alpha
x = alpha_next * x0 + sigma_next * out
return x0
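
Each DDIM step above first recovers the clean-sample estimate x0 = (x - sigma * eps) / alpha, then re-noises it to the next, lower-noise level with the same predicted eps. With a perfect eps the update lands exactly on the forward marginal at the next level, which the toy check below illustrates with made-up alpha/sigma pairs:

import torch

x0_true = torch.randn(2, 4, 8, 8)
eps = torch.randn_like(x0_true)
alpha, sigma = 0.5, (1 - 0.5 ** 2) ** 0.5      # current noise level
alpha_n, sigma_n = 0.9, (1 - 0.9 ** 2) ** 0.5  # next, less noisy level

x = alpha * x0_true + sigma * eps              # forward-diffused sample
x0_pred = (x - sigma * eps) / alpha            # exact eps recovers exact x0
x_next = alpha_n * x0_pred + sigma_n * eps
print(torch.allclose(x_next, alpha_n * x0_true + sigma_n * eps))  # True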


@@ -0,0 +1,102 @@
import math
import torch
from src.diffusion.base.scheduling import *
class DDPMScheduler(BaseScheduler):
def __init__(
self,
beta_min=0.0001,
beta_max=0.02,
num_steps=1000,
):
super().__init__()
self.beta_min = beta_min
self.beta_max = beta_max
self.num_steps = num_steps
self.betas_table = torch.linspace(self.beta_min, self.beta_max, self.num_steps, device="cuda")
self.alphas_table = torch.cumprod(1-self.betas_table, dim=0)
self.sigmas_table = 1-self.alphas_table
def beta(self, t) -> Tensor:
t = t.to(torch.long)
return self.betas_table[t].view(-1, 1, 1, 1)
def alpha(self, t) -> Tensor:
t = t.to(torch.long)
return self.alphas_table[t].view(-1, 1, 1, 1)**0.5
def sigma(self, t) -> Tensor:
t = t.to(torch.long)
return self.sigmas_table[t].view(-1, 1, 1, 1)**0.5
def dsigma(self, t) -> Tensor:
raise NotImplementedError("wrong usage")
def dalpha_over_alpha(self, t) ->Tensor:
raise NotImplementedError("wrong usage")
def dsigma_mul_sigma(self, t) ->Tensor:
raise NotImplementedError("wrong usage")
def dalpha(self, t) -> Tensor:
raise NotImplementedError("wrong usage")
def drift_coefficient(self, t):
raise NotImplementedError("wrong usage")
def diffuse_coefficient(self, t):
raise NotImplementedError("wrong usage")
def w(self, t):
raise NotImplementedError("wrong usage")
class VPScheduler(BaseScheduler):
def __init__(
self,
beta_min=0.1,
beta_max=20,
):
super().__init__()
self.beta_min = beta_min
self.beta_d = beta_max - beta_min
def beta(self, t) -> Tensor:
t = torch.clamp(t, min=1e-3, max=1)
return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1)
def sigma(self, t) -> Tensor:
t = torch.clamp(t, min=1e-3, max=1)
inter_beta:Tensor = 0.5*self.beta_d*t**2 + self.beta_min* t
return (1-torch.exp_(-inter_beta)).sqrt().view(-1, 1, 1, 1)
def dsigma(self, t) -> Tensor:
raise NotImplementedError("wrong usage")
def dalpha_over_alpha(self, t) ->Tensor:
raise NotImplementedError("wrong usage")
def dsigma_mul_sigma(self, t) ->Tensor:
raise NotImplementedError("wrong usage")
def dalpha(self, t) -> Tensor:
raise NotImplementedError("wrong usage")
def alpha(self, t) -> Tensor:
t = torch.clamp(t, min=1e-3, max=1)
inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t
return torch.exp(-0.5*inter_beta).view(-1, 1, 1, 1)
def drift_coefficient(self, t):
raise NotImplementedError("wrong usage")
def diffuse_coefficient(self, t):
raise NotImplementedError("wrong usage")
def w(self, t):
return self.diffuse_coefficient(t)
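
In DDPMScheduler above, alpha(t) and sigma(t) are square roots of the cumulative-product table and its complement, so alpha(t)^2 + sigma(t)^2 = 1 at every index. A small sanity check of that identity (run on CPU here purely for the sketch; the class builds its tables on CUDA):

import torch

num_steps = 1000
betas = torch.linspace(0.0001, 0.02, num_steps)
alphas_sq = torch.cumprod(1 - betas, dim=0)  # alpha(t)^2
sigmas_sq = 1 - alphas_sq                    # sigma(t)^2
print(torch.allclose(alphas_sq + sigmas_sq, torch.ones(num_steps)))  # True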


@@ -0,0 +1,83 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class VPTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
train_max_t=1000,
lognorm_t=False,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.train_max_t = train_max_t
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
sigma = self.scheduler.sigma(t)
x_t = alpha * x + noise * sigma
out = net(x_t, t*self.train_max_t, y)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - noise)**2
out = dict(
loss=loss.mean(),
)
return out
class DDPMTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn: Callable = constant,
train_max_t=1000,
lognorm_t=False,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.train_max_t = train_max_t
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
t = torch.randint(0, self.train_max_t, (batch_size,), device=x.device)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
sigma = self.scheduler.sigma(t)
x_t = alpha * x + noise * sigma
out = net(x_t, t, y)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight * (out - noise) ** 2
out = dict(
loss=loss.mean(),
)
return out


@@ -0,0 +1,59 @@
import torch
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *
from typing import Callable
def ode_step_fn(x, eps, beta, sigma, dt):
return x + (-0.5*beta*x + 0.5*eps*beta/sigma)*dt
def sde_step_fn(x, eps, beta, sigma, dt):
return x + (-0.5*beta*x + eps*beta/sigma)*dt + torch.sqrt(dt.abs()*beta)*torch.randn_like(x)
import logging
logger = logging.getLogger(__name__)
class VPEulerSampler(BaseSampler):
def __init__(
self,
train_max_t=1000,
guidance_fn: Callable = None,
step_fn: Callable = ode_step_fn,
last_step=None,
last_step_fn: Callable = ode_step_fn,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.guidance_fn = guidance_fn
self.step_fn = step_fn
self.last_step = last_step
self.last_step_fn = last_step_fn
self.train_max_t = train_max_t
if self.last_step is None or self.num_steps == 1:
self.last_step = 1.0 / self.num_steps
assert self.last_step > 0.0
assert self.scheduler is not None
def _impl_sampling(self, net, noise, condition, uncondition):
batch_size = noise.shape[0]
steps = torch.linspace(1.0, self.last_step, self.num_steps, device=noise.device)
steps = torch.cat([steps, torch.tensor([0.0], device=noise.device)], dim=0)
cfg_condition = torch.cat([uncondition, condition], dim=0)
x = noise
for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
dt = t_next - t_cur
t_cur = t_cur.repeat(batch_size)
sigma = self.scheduler.sigma(t_cur)
beta = self.scheduler.beta(t_cur)
cfg_x = torch.cat([x, x], dim=0)
cfg_t = t_cur.repeat(2)
out = net(cfg_x, cfg_t*self.train_max_t, cfg_condition)
eps = self.guidance_fn(out, self.guidance)
if i < self.num_steps -1 :
x0 = self.last_step_fn(x, eps, beta, sigma, -t_cur[0])
x = self.step_fn(x, eps, beta, sigma, dt)
else:
x = x0 = self.last_step_fn(x, eps, beta, sigma, -self.last_step)
return x


@@ -0,0 +1,107 @@
import math
from src.diffusion.base.sampling import *
from src.diffusion.base.scheduling import *
from src.diffusion.pre_integral import *
from typing import Callable, List, Tuple
def ode_step_fn(x, v, dt, s, w):
return x + v * dt
def t2snr(t):
if isinstance(t, torch.Tensor):
return (t.clip(min=1e-8)/(1-t + 1e-8))
if isinstance(t, List) or isinstance(t, Tuple):
return [t2snr(t) for t in t]
t = max(t, 1e-8)
return (t/(1-t + 1e-8))
def t2logsnr(t):
if isinstance(t, torch.Tensor):
return torch.log(t.clip(min=1e-3)/(1-t + 1e-3))
if isinstance(t, List) or isinstance(t, Tuple):
return [t2logsnr(t) for t in t]
t = max(t, 1e-3)
return math.log(t/(1-t + 1e-3))
def t2isnr(t):
return 1/t2snr(t)
def nop(t):
return t
def shift_respace_fn(t, shift=3.0):
return t / (t + (1 - t) * shift)
import logging
logger = logging.getLogger(__name__)
class AdamLMSampler(BaseSampler):
def __init__(
self,
order: int = 2,
timeshift: float = 1.0,
lms_transform_fn: Callable = nop,
w_scheduler: BaseScheduler = None,
step_fn: Callable = ode_step_fn,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.step_fn = step_fn
self.w_scheduler = w_scheduler
assert self.scheduler is not None
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
self.order = order
self.lms_transform_fn = lms_transform_fn
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
self.timesteps = shift_respace_fn(timesteps, timeshift)
self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
self._reparameterize_coeffs()
def _reparameterize_coeffs(self):
solver_coeffs = [[] for _ in range(self.num_steps)]
for i in range(0, self.num_steps):
pre_vs = [1.0, ]*(i+1)
pre_ts = self.lms_transform_fn(self.timesteps[:i+1])
int_t_start = self.lms_transform_fn(self.timesteps[i])
int_t_end = self.lms_transform_fn(self.timesteps[i+1])
order_annealing = self.order #self.num_steps - i
order = min(self.order, i + 1, order_annealing)
_, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end)
solver_coeffs[i] = coeffs
self.solver_coeffs = solver_coeffs
def _impl_sampling(self, net, noise, condition, uncondition):
"""
sampling process of Euler sampler
-
"""
batch_size = noise.shape[0]
cfg_condition = torch.cat([uncondition, condition], dim=0)
x = x0 = noise
pred_trajectory = []
t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype)
timedeltas = self.timedeltas
solver_coeffs = self.solver_coeffs
for i in range(self.num_steps):
cfg_x = torch.cat([x, x], dim=0)
cfg_t = t_cur.repeat(2)
out = net(cfg_x, cfg_t, cfg_condition)
out = self.guidance_fn(out, self.guidances[i])
pred_trajectory.append(out)
out = torch.zeros_like(out)
order = len(self.solver_coeffs[i])
for j in range(order):
out += solver_coeffs[i][j] * pred_trajectory[-order:][j]
v = out
dt = timedeltas[i]
x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0)
x = self.step_fn(x, v, dt, s=0, w=0)
t_cur += dt
return x


@@ -0,0 +1,179 @@
import torch
from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *
from typing import Callable
def shift_respace_fn(t, shift=3.0):
return t / (t + (1 - t) * shift)
def ode_step_fn(x, v, dt, s, w):
return x + v * dt
def sde_mean_step_fn(x, v, dt, s, w):
return x + v * dt + s * w * dt
def sde_step_fn(x, v, dt, s, w):
return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x)
def sde_preserve_step_fn(x, v, dt, s, w):
return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x)
import logging
logger = logging.getLogger(__name__)
class EulerSampler(BaseSampler):
def __init__(
self,
w_scheduler: BaseScheduler = None,
timeshift=1.0,
step_fn: Callable = ode_step_fn,
last_step=None,
last_step_fn: Callable = ode_step_fn,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.step_fn = step_fn
self.last_step = last_step
self.last_step_fn = last_step_fn
self.w_scheduler = w_scheduler
self.timeshift = timeshift
if self.last_step is None or self.num_steps == 1:
self.last_step = 1.0 / self.num_steps
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
self.timesteps = shift_respace_fn(timesteps, self.timeshift)
assert self.last_step > 0.0
assert self.scheduler is not None
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
if self.w_scheduler is not None:
if self.step_fn == ode_step_fn:
logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
def _impl_sampling(self, net, noise, condition, uncondition):
"""
sampling process of Euler sampler
-
"""
batch_size = noise.shape[0]
steps = self.timesteps.to(noise.device)
cfg_condition = torch.cat([uncondition, condition], dim=0)
x = noise
for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
dt = t_next - t_cur
t_cur = t_cur.repeat(batch_size)
sigma = self.scheduler.sigma(t_cur)
dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur)
dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
if self.w_scheduler:
w = self.w_scheduler.w(t_cur)
else:
w = 0.0
cfg_x = torch.cat([x, x], dim=0)
cfg_t = t_cur.repeat(2)
out = net(cfg_x, cfg_t, cfg_condition)
out = self.guidance_fn(out, self.guidance)
v = out
s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma)
if i < self.num_steps -1 :
x = self.step_fn(x, v, dt, s=s, w=w)
else:
x = self.last_step_fn(x, v, dt, s=s, w=w)
return x
class HeunSampler(BaseSampler):
def __init__(
self,
scheduler: BaseScheduler = None,
w_scheduler: BaseScheduler = None,
exact_henu=False,
timeshift=1.0,
step_fn: Callable = ode_step_fn,
last_step=None,
last_step_fn: Callable = ode_step_fn,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.scheduler = scheduler
self.exact_henu = exact_henu
self.step_fn = step_fn
self.last_step = last_step
self.last_step_fn = last_step_fn
self.w_scheduler = w_scheduler
self.timeshift = timeshift
if self.last_step is None or self.num_steps == 1:
self.last_step = 1.0 / self.num_steps
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
self.timesteps = shift_respace_fn(timesteps, self.timeshift)
assert self.last_step > 0.0
assert self.scheduler is not None
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
if self.w_scheduler is not None:
if self.step_fn == ode_step_fn:
logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
def _impl_sampling(self, net, noise, condition, uncondition):
"""
sampling process of Henu sampler
-
"""
batch_size = noise.shape[0]
steps = self.timesteps.to(noise.device)
cfg_condition = torch.cat([uncondition, condition], dim=0)
x = noise
v_hat, s_hat = 0.0, 0.0
for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
dt = t_next - t_cur
t_cur = t_cur.repeat(batch_size)
sigma = self.scheduler.sigma(t_cur)
alpha_over_dalpha = 1/self.scheduler.dalpha_over_alpha(t_cur)
dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
t_hat = t_next
t_hat = t_hat.repeat(batch_size)
sigma_hat = self.scheduler.sigma(t_hat)
alpha_over_dalpha_hat = 1 / self.scheduler.dalpha_over_alpha(t_hat)
dsigma_mul_sigma_hat = self.scheduler.dsigma_mul_sigma(t_hat)
if self.w_scheduler:
w = self.w_scheduler.w(t_cur)
else:
w = 0.0
if i == 0 or self.exact_henu:
cfg_x = torch.cat([x, x], dim=0)
cfg_t_cur = t_cur.repeat(2)
out = net(cfg_x, cfg_t_cur, cfg_condition)
out = self.guidance_fn(out, self.guidance)
v = out
s = ((alpha_over_dalpha)*v - x)/(sigma**2 - (alpha_over_dalpha)*dsigma_mul_sigma)
else:
v = v_hat
s = s_hat
x_hat = self.step_fn(x, v, dt, s=s, w=w)
# Heun corrector step
if i < self.num_steps -1:
cfg_x_hat = torch.cat([x_hat, x_hat], dim=0)
cfg_t_hat = t_hat.repeat(2)
out = net(cfg_x_hat, cfg_t_hat, cfg_condition)
out = self.guidance_fn(out, self.guidance)
v_hat = out
s_hat = ((alpha_over_dalpha_hat)* v_hat - x_hat) / (sigma_hat ** 2 - (alpha_over_dalpha_hat) * dsigma_mul_sigma_hat)
v = (v + v_hat) / 2
s = (s + s_hat) / 2
x = self.step_fn(x, v, dt, s=s, w=w)
else:
x = self.last_step_fn(x, v, dt, s=s, w=w)
return x
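
For the linear scheduler (alpha = t, sigma = 1 - t), the score estimate computed inside EulerSampler simplifies to s = (t * v - x) / (1 - t): the noise is recovered from the velocity prediction via eps = x - t * v, and the score along the Gaussian path is -eps / (1 - t). A small numeric check of that identity on dummy tensors:

import torch

t = torch.full((2, 1, 1, 1), 0.3)
x1 = torch.randn(2, 4, 8, 8)      # clean sample
eps = torch.randn_like(x1)        # noise
x_t = t * x1 + (1 - t) * eps      # linear interpolation path
v = x1 - eps                      # ground-truth velocity

s_from_v = (t * v - x_t) / (1 - t)
s_direct = -eps / (1 - t)
print(torch.allclose(s_from_v, s_direct, atol=1e-6))  # True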


@@ -0,0 +1,39 @@
import math
import torch
from src.diffusion.base.scheduling import *
class LinearScheduler(BaseScheduler):
def alpha(self, t) -> Tensor:
return (t).view(-1, 1, 1, 1)
def sigma(self, t) -> Tensor:
return (1-t).view(-1, 1, 1, 1)
def dalpha(self, t) -> Tensor:
return torch.full_like(t, 1.0).view(-1, 1, 1, 1)
def dsigma(self, t) -> Tensor:
return torch.full_like(t, -1.0).view(-1, 1, 1, 1)
# SoTA for ImageNet!
class GVPScheduler(BaseScheduler):
def alpha(self, t) -> Tensor:
return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
def sigma(self, t) -> Tensor:
return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
def dalpha(self, t) -> Tensor:
return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
def dsigma(self, t) -> Tensor:
return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
def w(self, t):
return torch.sin(t)**2
class ConstScheduler(BaseScheduler):
def w(self, t):
return torch.ones(1, 1, 1, 1).to(t.device, t.dtype)
from src.diffusion.ddpm.scheduling import VPScheduler
class VPBetaScheduler(VPScheduler):
def w(self, t):
return self.beta(t).view(-1, 1, 1, 1)


@@ -0,0 +1,55 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class FlowMatchingTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out = net(x_t, t, y)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
out = dict(
loss=loss.mean(),
)
return out
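
One flow-matching training step in isolation: sample t, build x_t = alpha(t) * x + sigma(t) * noise, and regress the network output onto the target velocity v_t = dalpha(t) * x + dsigma(t) * noise. The sketch below only illustrates that loss construction; it uses a throwaway network that ignores t and the label, plus the linear schedule.

import torch
import torch.nn as nn

net = nn.Conv2d(4, 4, 1)              # stand-in denoiser
x = torch.randn(8, 4, 32, 32)         # latents
t = torch.rand(8).view(-1, 1, 1, 1)
noise = torch.randn_like(x)

alpha, sigma = t, 1 - t               # LinearScheduler
dalpha, dsigma = torch.ones_like(t), -torch.ones_like(t)
x_t = alpha * x + sigma * noise
v_t = dalpha * x + dsigma * noise     # target velocity = x - noise

loss = ((net(x_t) - v_t) ** 2).mean()
loss.backward()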


@@ -0,0 +1,59 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class COSTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out = net(x_t, t, y)
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = weight*(out - v_t)**2
cos_sim = torch.nn.functional.cosine_similarity(out, v_t, dim=1)
cos_loss = 1 - cos_sim
out = dict(
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
loss=fm_loss.mean() + cos_loss.mean(),
)
return out

View File

@@ -0,0 +1,68 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class PyramidTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
output_pyramid = []
def feature_hook(module, input, output):
output_pyramid.extend(output)
handle = net.decoder.register_forward_hook(feature_hook)
net(x_t, t, y)
handle.remove()
loss = 0.0
out_dict = dict()
cur_v_t = v_t
for i in range(len(output_pyramid)):
cur_out = output_pyramid[i]
loss_i = (cur_v_t - cur_out) ** 2
loss += loss_i.mean()
out_dict["loss_{}".format(i)] = loss_i.mean()
cur_v_t = torch.nn.functional.interpolate(cur_v_t, scale_factor=0.5, mode='bilinear', align_corners=False)
out_dict["loss"] = loss
return out_dict
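
A small sketch of the target pyramid built in the loop above (shapes and level count are hypothetical): after each level's loss, the velocity target is halved in resolution with bilinear interpolation.

import torch

v_t = torch.randn(2, 4, 32, 32)
targets = [v_t]
for _ in range(2):  # two extra pyramid levels, hypothetical
    targets.append(torch.nn.functional.interpolate(
        targets[-1], scale_factor=0.5, mode='bilinear', align_corners=False))
print([tuple(t.shape[-2:]) for t in targets])  # [(32, 32), (16, 16), (8, 8)]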

View File

@@ -0,0 +1,142 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class REPATrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
encoder_weight_path=None,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.encoder = DINOv2(encoder_weight_path)
no_grad(self.encoder)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
src_feature = []
def forward_hook(net, input, output):
src_feature.append(output)
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
out = net(x_t, t, y)
src_feature = self.proj(src_feature[0])
handle.remove()
with torch.no_grad():
dst_feature = self.encoder(raw_images)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = weight*(out - v_t)**2
out = dict(
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
# only the trainable projection head is checkpointed; return its state dict
return self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)
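
A minimal sketch of the REPA alignment term computed above (dimensions are hypothetical; the real projector is the three-layer MLP defined in __init__): project intermediate denoiser tokens to the encoder width and penalize 1 - cosine similarity per token.

import torch
import torch.nn as nn

B, L = 2, 256
denoiser_dim, encoder_dim = 1152, 768            # hypothetical widths
proj = nn.Linear(denoiser_dim, encoder_dim)      # stand-in for the projection MLP

src = proj(torch.randn(B, L, denoiser_dim))      # denoiser tokens captured at align_layer
dst = torch.randn(B, L, encoder_dim)             # frozen DINOv2 patch tokens
cos_loss = (1 - torch.nn.functional.cosine_similarity(src, dst, dim=-1)).mean()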

View File

@@ -0,0 +1,152 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class REPATrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
mask_ratio=0.0,
mask_patch_size=2,
encoder_weight_path=None,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.mask_ratio = mask_ratio
self.mask_patch_size = mask_patch_size
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.encoder = DINOv2(encoder_weight_path)
no_grad(self.encoder)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
patch_mask = torch.rand((batch_size, 1, height//self.mask_patch_size, width//self.mask_patch_size), device=x.device)
patch_mask = (patch_mask < self.mask_ratio).float()
mask = torch.nn.functional.interpolate(patch_mask, size=(height, width), mode='nearest')
masked_x = x*(1-mask)# + torch.randn_like(x)*(mask)
x_t = alpha*masked_x + sigma*noise
v_t = dalpha*x + dsigma*noise
src_feature = []
def forward_hook(net, input, output):
src_feature.append(output)
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
v_t_out, x0_out = net(x_t, t, y)
src_feature = self.proj(src_feature[0])
handle.remove()
with torch.no_grad():
dst_feature = self.encoder(raw_images)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = (1-mask)*weight*(v_t_out - v_t)**2/(1-mask.mean())
mask_loss = mask*weight*(x0_out - x)**2/(mask.mean())
out = dict(
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
mask_loss=mask_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + mask_loss.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
# only the trainable projection head is checkpointed; return its state dict
return self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)

View File

@@ -0,0 +1,143 @@
import torch
# lagrange interpolation
def lagrange_preint_o1(t1, v1, int_t_start, int_t_end):
'''
Lagrange interpolation of order 1.
Args:
t1: timestep t1
v1: value field at t1
int_t_start: integration start time
int_t_end: integration end time
Returns:
integrated value and the normalized interpolation weights
'''
int1 = (int_t_end-int_t_start)
return int1*v1, (int1/int1, )
def lagrange_preint_o2(t1, t2, v1, v2, int_t_start, int_t_end):
'''
Lagrange interpolation of order 2.
Args:
t1: timestep t1
t2: timestep t2
v1: value field at t1
v2: value field at t2
int_t_start: integration start time
int_t_end: integration end time
Returns:
integrated value and the normalized interpolation weights
'''
int1 = 0.5/(t1-t2)*((int_t_end-t2)**2 - (int_t_start-t2)**2)
int2 = 0.5/(t2-t1)*((int_t_end-t1)**2 - (int_t_start-t1)**2)
int_sum = int1+int2
return int1*v1 + int2*v2, (int1/int_sum, int2/int_sum)
def lagrange_preint_o3(t1, t2, t3, v1, v2, v3, int_t_start, int_t_end):
'''
Lagrange interpolation of order 3.
Args:
t1: timestep t1
t2: timestep t2
t3: timestep t3
v1: value field at t1
v2: value field at t2
v3: value field at t3
int_t_start: integration start time
int_t_end: integration end time
Returns:
integrated value and the normalized interpolation weights
'''
int1_denom = (t1-t2)*(t1-t3)
int1_end = 1/3*(int_t_end)**3 - 1/2*(t2+t3)*(int_t_end)**2 + (t2*t3)*int_t_end
int1_start = 1/3*(int_t_start)**3 - 1/2*(t2+t3)*(int_t_start)**2 + (t2*t3)*int_t_start
int1 = (int1_end - int1_start)/int1_denom
int2_denom = (t2-t1)*(t2-t3)
int2_end = 1/3*(int_t_end)**3 - 1/2*(t1+t3)*(int_t_end)**2 + (t1*t3)*int_t_end
int2_start = 1/3*(int_t_start)**3 - 1/2*(t1+t3)*(int_t_start)**2 + (t1*t3)*int_t_start
int2 = (int2_end - int2_start)/int2_denom
int3_denom = (t3-t1)*(t3-t2)
int3_end = 1/3*(int_t_end)**3 - 1/2*(t1+t2)*(int_t_end)**2 + (t1*t2)*int_t_end
int3_start = 1/3*(int_t_start)**3 - 1/2*(t1+t2)*(int_t_start)**2 + (t1*t2)*int_t_start
int3 = (int3_end - int3_start)/int3_denom
int_sum = int1+int2+int3
return int1*v1 + int2*v2 + int3*v3, (int1/int_sum, int2/int_sum, int3/int_sum)
def lagrange_preint_o4(t1, t2, t3, t4, v1, v2, v3, v4, int_t_start, int_t_end):
'''
Lagrange interpolation of order 4.
Args:
t1: timestep t1
t2: timestep t2
t3: timestep t3
t4: timestep t4
v1: value field at t1
v2: value field at t2
v3: value field at t3
v4: value field at t4
int_t_start: integration start time
int_t_end: integration end time
Returns:
integrated value and the normalized interpolation weights
'''
int1_denom = (t1-t2)*(t1-t3)*(t1-t4)
int1_end = 1/4*(int_t_end)**4 - 1/3*(t2+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_end**2 - t2*t3*t4*int_t_end
int1_start = 1/4*(int_t_start)**4 - 1/3*(t2+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_start**2 - t2*t3*t4*int_t_start
int1 = (int1_end - int1_start)/int1_denom
int2_denom = (t2-t1)*(t2-t3)*(t2-t4)
int2_end = 1/4*(int_t_end)**4 - 1/3*(t1+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_end**2 - t1*t3*t4*int_t_end
int2_start = 1/4*(int_t_start)**4 - 1/3*(t1+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_start**2 - t1*t3*t4*int_t_start
int2 = (int2_end - int2_start)/int2_denom
int3_denom = (t3-t1)*(t3-t2)*(t3-t4)
int3_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t4)*(int_t_end)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_end**2 - t1*t2*t4*int_t_end
int3_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t4)*(int_t_start)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_start**2 - t1*t2*t4*int_t_start
int3 = (int3_end - int3_start)/int3_denom
int4_denom = (t4-t1)*(t4-t2)*(t4-t3)
int4_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t3)*(int_t_end)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_end**2 - t1*t2*t3*int_t_end
int4_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t3)*(int_t_start)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_start**2 - t1*t2*t3*int_t_start
int4 = (int4_end - int4_start)/int4_denom
int_sum = int1+int2+int3+int4
return int1*v1 + int2*v2 + int3*v3 + int4*v4, (int1/int_sum, int2/int_sum, int3/int_sum, int4/int_sum)
def lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end):
'''
Lagrange pre-integration dispatcher: uses up to `order` of the most recent (t, v) pairs.
Args:
order: order of interpolation
pre_vs: value fields at pre_ts
pre_ts: timesteps
int_t_start: integration start time
int_t_end: integration end time
Returns:
integrated value and the normalized interpolation weights
'''
order = min(order, len(pre_vs), len(pre_ts))
if order == 1:
return lagrange_preint_o1(pre_ts[-1], pre_vs[-1], int_t_start, int_t_end)
elif order == 2:
return lagrange_preint_o2(pre_ts[-2], pre_ts[-1], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
elif order == 3:
return lagrange_preint_o3(pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
elif order == 4:
return lagrange_preint_o4(pre_ts[-4], pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-4], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
else:
raise ValueError('Invalid order')
def polynomial_integral(coeffs, int_t_start, int_t_end):
'''
Exact integral of a polynomial given its coefficients (lowest order first).
Args:
coeffs: coefficients of the polynomial
int_t_start: integration start time
int_t_end: integration end time
Returns:
integrated value
'''
orders = len(coeffs)
int_val = 0
for o in range(orders):
int_val += coeffs[o]/(o+1)*(int_t_end**(o+1)-int_t_start**(o+1))
return int_val
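
A quick exactness check (a sketch; the module path `src.diffusion.pre_integral` is taken from the sampler import further below): order-2 pre-integration is exact for a linear value field and matches polynomial_integral.

from src.diffusion.pre_integral import lagrange_preint, polynomial_integral

f = lambda t: 2 + 3 * t                         # linear value field
pre_ts, pre_vs = [0.2, 0.6], [f(0.2), f(0.6)]
approx, weights = lagrange_preint(2, pre_vs, pre_ts, 0.6, 0.9)
exact = polynomial_integral([2, 3], 0.6, 0.9)   # coefficients: [c0, c1]
assert abs(approx - exact) < 1e-8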

View File

@@ -0,0 +1,112 @@
import math
from src.diffusion.base.sampling import *
from src.diffusion.base.scheduling import *
from src.diffusion.pre_integral import *
from typing import Callable, List, Tuple
def ode_step_fn(x, v, dt, s, w):
return x + v * dt
def t2snr(t):
if isinstance(t, torch.Tensor):
return (t.clip(min=1e-8)/(1-t + 1e-8))
if isinstance(t, List) or isinstance(t, Tuple):
return [t2snr(t) for t in t]
t = max(t, 1e-8)
return (t/(1-t + 1e-8))
def t2logsnr(t):
if isinstance(t, torch.Tensor):
return torch.log(t.clip(min=1e-3)/(1-t + 1e-3))
if isinstance(t, List) or isinstance(t, Tuple):
return [t2logsnr(t) for t in t]
t = max(t, 1e-3)
return math.log(t/(1-t + 1e-3))
def t2isnr(t):
return 1/t2snr(t)
def nop(t):
return t
def shift_respace_fn(t, shift=3.0):
return t / (t + (1 - t) * shift)
import logging
logger = logging.getLogger(__name__)
class AdamLMSampler(BaseSampler):
def __init__(
self,
order: int = 2,
timeshift: float = 1.0,
state_refresh_rate: int = 1,
lms_transform_fn: Callable = nop,
w_scheduler: BaseScheduler = None,
step_fn: Callable = ode_step_fn,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.step_fn = step_fn
self.w_scheduler = w_scheduler
self.state_refresh_rate = state_refresh_rate
assert self.scheduler is not None
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
self.order = order
self.lms_transform_fn = lms_transform_fn
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
self.timesteps = shift_respace_fn(timesteps, timeshift)
self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
self._reparameterize_coeffs()
def _reparameterize_coeffs(self):
solver_coeffs = [[] for _ in range(self.num_steps)]
for i in range(0, self.num_steps):
pre_vs = [1.0, ]*(i+1)
pre_ts = self.lms_transform_fn(self.timesteps[:i+1])
int_t_start = self.lms_transform_fn(self.timesteps[i])
int_t_end = self.lms_transform_fn(self.timesteps[i+1])
order_annealing = self.order #self.num_steps - i
order = min(self.order, i + 1, order_annealing)
_, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end)
solver_coeffs[i] = coeffs
self.solver_coeffs = solver_coeffs
def _impl_sampling(self, net, noise, condition, uncondition):
"""
sampling process of Euler sampler
-
"""
batch_size = noise.shape[0]
cfg_condition = torch.cat([uncondition, condition], dim=0)
x = x0 = noise
state = None
pred_trajectory = []
t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype)
timedeltas = self.timedeltas
solver_coeffs = self.solver_coeffs
for i in range(self.num_steps):
cfg_x = torch.cat([x, x], dim=0)
cfg_t = t_cur.repeat(2)
if i % self.state_refresh_rate == 0:
state = None
out, state = net(cfg_x, cfg_t, cfg_condition, state)
out = self.guidance_fn(out, self.guidances[i])
pred_trajectory.append(out)
out = torch.zeros_like(out)
order = len(self.solver_coeffs[i])
for j in range(order):
out += solver_coeffs[i][j] * pred_trajectory[-order:][j]
v = out
dt = timedeltas[i]
x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0)
x = self.step_fn(x, v, dt, s=0, w=0)
t_cur += dt
return x
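
A small standalone sketch of the multistep update performed in the loop above (values are placeholders): the precomputed coefficients for step i weight the last `order` velocity predictions before an Euler-style step.

import torch

pred_trajectory = [torch.randn(2, 4, 8, 8) for _ in range(3)]   # past velocity predictions
coeffs = [0.2, -0.5, 1.3]                                       # placeholder solver_coeffs[i]
order = len(coeffs)
v = sum(c * p for c, p in zip(coeffs, pred_trajectory[-order:]))
x = torch.randn(2, 4, 8, 8)
dt = 0.05
x = x + v * dt   # ode_step_fn(x, v, dt, s=0, w=0)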

View File

@@ -0,0 +1,122 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class Discriminator(nn.Module):
def __init__(self, in_channels, hidden_size):
super().__init__()
self.head = nn.Sequential(
nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
)
def forward(self, feature):
B, L, C = feature.shape
H = W = int(math.sqrt(L))
feature = feature.permute(0, 2, 1)
feature = feature.view(B, C, H, W)
out = self.head(feature).sigmoid().clamp(0.01, 0.99)
return out
class AdvTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
adv_weight=1.0,
adv_encoder_layer=4,
adv_in_channels=768,
adv_hidden_size=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.adv_weight = adv_weight
self.adv_encoder_layer = adv_encoder_layer
self.dis_head = Discriminator(
in_channels=adv_in_channels,
hidden_size=adv_hidden_size,
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
adv_feature = []
def forward_hook(net, input, output):
adv_feature.append(output)
handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook)
out, _ = net(x_t, t, y)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
pred_x0 = (x_t + out * sigma)
pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma
real_feature = adv_feature.pop()
net(pred_xt, t, y, classify_layer=self.adv_encoder_layer)
fake_feature = adv_feature.pop()
handle.remove()
real_score_gan = self.dis_head(real_feature.detach())
fake_score_gan = self.dis_head(fake_feature.detach())
fake_score = self.dis_head(fake_feature)
loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
acc_real = (real_score_gan > 0.5).float()
acc_fake = (fake_score_gan < 0.5).float()
loss_adv = -torch.log(fake_score)
loss_adv_hack = torch.log(fake_score_gan)
out = dict(
adv_loss=loss_adv.mean(),
gan_loss=loss_gan.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
acc_real=acc_real.mean(),
acc_fake=acc_fake.mean(),
)
return out

View File

@@ -0,0 +1,127 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class Discriminator(nn.Module):
def __init__(self, in_channels, hidden_size):
super().__init__()
self.head = nn.Sequential(
nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
)
def forward(self, feature):
B, L, C = feature.shape
H = W = int(math.sqrt(L))
feature = feature.permute(0, 2, 1)
feature = feature.view(B, C, H, W)
out = self.head(feature).sigmoid().clamp(0.01, 0.99)
return out
class AdvTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
adv_weight=1.0,
lpips_weight=1.0,
adv_encoder_layer=4,
adv_in_channels=768,
adv_hidden_size=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.adv_weight = adv_weight
self.lpips_weight = lpips_weight
self.adv_encoder_layer = adv_encoder_layer
self.dis_head = Discriminator(
in_channels=adv_in_channels,
hidden_size=adv_hidden_size,
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
pred_x0 = (x_t + out * sigma)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
with torch.no_grad():
_, real_features = net(x, clean_t, y, classify_layer=self.adv_encoder_layer)
_, fake_features = net(pred_x0, clean_t, y, classify_layer=self.adv_encoder_layer)
real_score_gan = self.dis_head(real_features[-1].detach())
fake_score_gan = self.dis_head(fake_features[-1].detach())
fake_score = self.dis_head(fake_features[-1])
loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
acc_real = (real_score_gan > 0.5).float()
acc_fake = (fake_score_gan < 0.5).float()
loss_adv = -torch.log(fake_score)
loss_adv_hack = torch.log(fake_score_gan)
lpips_loss = []
for r, f in zip(real_features, fake_features):
r = torch.nn.functional.normalize(r, dim=-1)
f = torch.nn.functional.normalize(f, dim=-1)
lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
lpips_loss = sum(lpips_loss)
out = dict(
adv_loss=loss_adv.mean(),
gan_loss=loss_gan.mean(),
lpips_loss=lpips_loss.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean() + self.lpips_weight*lpips_loss.mean(),
acc_real=acc_real.mean(),
acc_fake=acc_fake.mean(),
)
return out

View File

@@ -0,0 +1,159 @@
import random
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class MaskREPATrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
encoder_weight_path=None,
mask_groups=4,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.mask_groups = mask_groups
self.encoder = DINOv2(encoder_weight_path)
no_grad(self.encoder)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def fetch_mask(self, length=256, groups=4, device=torch.device('cuda')):
mask = torch.zeros(1, length, length, device=device, dtype=torch.bool)
random_seq = torch.randperm(length, device=device)
for i in range(groups):
group_start = (length+groups-1)//groups*i
group_end = (length+groups-1)//groups*(i+1)
group_random_seq = random_seq[group_start:group_end]
y, x = torch.meshgrid(group_random_seq, group_random_seq, indexing='ij')
mask[:, y, x] = True
return mask
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
src_feature = []
def forward_hook(net, input, output):
src_feature.append(output)
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
mask_groups = random.randint(1, self.mask_groups)
mask = self.fetch_mask(length=256, groups=mask_groups, device=x.device)
out, _ = net(x_t, t, y, mask=mask)
src_feature = self.proj(src_feature[0])
handle.remove()
with torch.no_grad():
dst_feature = self.encoder(raw_images)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = weight*(out - v_t)**2
out = dict(
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
# only the trainable projection head is checkpointed; return its state dict
return self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)
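
A tiny illustration of the attention mask produced by fetch_mask above (length=8, groups=2 for readability): a random permutation is split into contiguous chunks, and tokens may only attend within their own chunk.

import torch

length, groups = 8, 2
mask = torch.zeros(1, length, length, dtype=torch.bool)
perm = torch.randperm(length)
per_group = (length + groups - 1) // groups
for i in range(groups):
    idx = perm[i * per_group:(i + 1) * per_group]
    y, x = torch.meshgrid(idx, idx, indexing='ij')
    mask[:, y, x] = True
print(mask[0].int())   # block structure, up to the random permutation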

View File

@@ -0,0 +1,179 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t, mul=1000):
t_freq = self.timestep_embedding(t * mul, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class BatchNormWithTimeEmbedding(nn.Module):
def __init__(self, num_features):
super().__init__()
# self.bn = nn.BatchNorm2d(num_features, affine=False)
self.bn = nn.GroupNorm(16, num_features, affine=False)
# self.bn = nn.SyncBatchNorm(num_features, affine=False)
self.embedder = TimestepEmbedder(num_features * 2)
# nn.init.zeros_(self.embedder.mlp[-1].weight)
nn.init.trunc_normal_(self.embedder.mlp[-1].weight, std=0.01)
nn.init.zeros_(self.embedder.mlp[-1].bias)
def forward(self, x, t):
embed = self.embedder(t)
embed = embed[:, :, None, None]
gamma, beta = embed.chunk(2, dim=1)
gamma = 1.0 + gamma
normed = self.bn(x)
out = normed * gamma + beta
return out
class DisBlock(nn.Module):
def __init__(self, in_channels, hidden_size):
super().__init__()
self.conv = nn.Conv2d(
kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=4, padding=0
)
self.norm = BatchNormWithTimeEmbedding(hidden_size)
self.act = nn.SiLU()
def forward(self, x, t):
x = self.conv(x)
x = self.norm(x, t)
x = self.act(x)
return x
class Discriminator(nn.Module):
def __init__(self, num_blocks, in_channels, hidden_size):
super().__init__()
self.blocks = nn.ModuleList()
for i in range(num_blocks):
self.blocks.append(
DisBlock(
in_channels=in_channels,
hidden_size=hidden_size,
)
)
in_channels = hidden_size
self.classifier = nn.Conv2d(
kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=1
)
def forward(self, feature, t):
B, C, H, W = feature.shape
for block in self.blocks:
feature = block(feature, t)
out = self.classifier(feature).view(B, -1)
out = out.sigmoid().clamp(0.01, 0.99)
return out
class AdvTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
adv_weight=1.0,
adv_blocks=3,
adv_in_channels=3,
adv_hidden_size=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.adv_weight = adv_weight
self.discriminator = Discriminator(
num_blocks=adv_blocks,
in_channels=adv_in_channels*2,
hidden_size=adv_hidden_size,
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
pred_x0 = x_t + sigma * out
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
real_feature = torch.cat([x_t, x], dim=1)
fake_feature = torch.cat([x_t, pred_x0], dim=1)
real_score_gan = self.discriminator(real_feature.detach(), t)
fake_score_gan = self.discriminator(fake_feature.detach(), t)
fake_score = self.discriminator(fake_feature, t)
loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
acc_real = (real_score_gan > 0.5).float()
acc_fake = (fake_score_gan < 0.5).float()
loss_adv = -torch.log(fake_score)
loss_adv_hack = torch.log(fake_score_gan)
out = dict(
adv_loss=loss_adv.mean(),
gan_loss=loss_gan.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
acc_real=acc_real.mean(),
acc_fake=acc_fake.mean(),
)
return out
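
A minimal sketch (hypothetical sizes) of the time-conditioned normalization in BatchNormWithTimeEmbedding above: the timestep embedding is split into a scale and a shift applied after an affine-free GroupNorm.

import torch
import torch.nn as nn

C = 64
feat = torch.randn(2, C, 8, 8)
t_embed = torch.randn(2, 2 * C)                         # stand-in for TimestepEmbedder(t)
gamma, beta = t_embed[:, :, None, None].chunk(2, dim=1)
normed = nn.GroupNorm(16, C, affine=False)(feat)
out = normed * (1.0 + gamma) + beta                     # FiLM-style scale-and-shift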

View File

@@ -0,0 +1,154 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class REPAJiTTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
jit_deltas=0.01,
encoder_weight_path=None,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.jit_deltas = jit_deltas
self.encoder = DINOv2(encoder_weight_path)
no_grad(self.encoder)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
t2 = base_t + (torch.rand_like(base_t)-0.5) * self.jit_deltas
t2 = torch.clip(t2, 0, 1)
alpha = self.scheduler.alpha(t2)
dalpha = self.scheduler.dalpha(t2)
sigma = self.scheduler.sigma(t2)
dsigma = self.scheduler.dsigma(t2)
x_t2 = alpha * x + noise * sigma
v_t2 = dalpha * x + dsigma * noise
src_feature = []
def forward_hook(net, input, output):
src_feature.append(output)
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
_, s = net(x_t, t, y, only_s=True)
out, _ = net(x_t2, t2, y, s)
src_feature = self.proj(src_feature[0])
handle.remove()
with torch.no_grad():
dst_feature = self.encoder(raw_images)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = weight*(out - v_t2)**2
out = dict(
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
# only the trainable projection head is checkpointed; return its state dict
return self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)

View File

@@ -0,0 +1,90 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class SelfConsistentTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
lpips_weight=1.0,
lpips_encoder_layer=4,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.lpips_encoder_layer = lpips_encoder_layer
self.lpips_weight = lpips_weight
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
real_features = []
def forward_hook(net, input, output):
real_features.append(output)
handles = []
for i in range(self.lpips_encoder_layer):
handle = net.encoder.blocks[i].register_forward_hook(forward_hook)
handles.append(handle)
out, _ = net(x_t, t, y)
for handle in handles:
handle.remove()
pred_x0 = (x_t + out * sigma)
pred_xt = alpha * pred_x0 + noise * sigma
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
_, fake_features = net(pred_xt, t, y, classify_layer=self.lpips_encoder_layer)
lpips_loss = []
for r, f in zip(real_features, fake_features):
r = torch.nn.functional.normalize(r, dim=-1)
f = torch.nn.functional.normalize(f, dim=-1)
lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
lpips_loss = sum(lpips_loss)
out = dict(
lpips_loss=lpips_loss.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + self.lpips_weight*lpips_loss.mean(),
)
return out

View File

@@ -0,0 +1,81 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class SelfLPIPSTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
lpips_weight=1.0,
lpips_encoder_layer=4,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.lpips_encoder_layer = lpips_encoder_layer
self.lpips_weight = lpips_weight
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
pred_x0 = (x_t + out * sigma)
pred_xt = alpha * pred_x0 + noise * sigma
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
with torch.no_grad():
_, real_features = net(x, clean_t, y, classify_layer=self.lpips_encoder_layer)
_, fake_features = net(pred_x0, clean_t, y, classify_layer=self.lpips_encoder_layer)
lpips_loss = []
for r, f in zip(real_features, fake_features):
r = torch.nn.functional.normalize(r, dim=-1)
f = torch.nn.functional.normalize(f, dim=-1)
lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
lpips_loss = sum(lpips_loss)
out = dict(
lpips_loss=lpips_loss.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + self.lpips_weight*lpips_loss.mean(),
)
return out

View File

@@ -0,0 +1,78 @@
import torch
from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *
from typing import Callable
def shift_respace_fn(t, shift=3.0):
return t / (t + (1 - t) * shift)
def ode_step_fn(x, v, dt, s, w):
return x + v * dt
import logging
logger = logging.getLogger(__name__)
class CMSampler(BaseSampler):
def __init__(
self,
w_scheduler: BaseScheduler = None,
timeshift=1.0,
guidance_interval_min: float = 0.0,
guidance_interval_max: float = 1.0,
state_refresh_rate=1,
last_step=None,
step_fn=None,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.last_step = last_step
self.timeshift = timeshift
self.state_refresh_rate = state_refresh_rate
self.guidance_interval_min = guidance_interval_min
self.guidance_interval_max = guidance_interval_max
if self.last_step is None or self.num_steps == 1:
self.last_step = 1.0 / self.num_steps
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
self.timesteps = shift_respace_fn(timesteps, self.timeshift)
assert self.last_step > 0.0
assert self.scheduler is not None
def _impl_sampling(self, net, noise, condition, uncondition):
"""
sampling process of Euler sampler
-
"""
batch_size = noise.shape[0]
steps = self.timesteps.to(noise.device)
cfg_condition = torch.cat([uncondition, condition], dim=0)
x = noise
state = None
for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
cfg_t = t_cur.repeat(batch_size*2)
cfg_x = torch.cat([x, x], dim=0)
if i % self.state_refresh_rate == 0:
state = None
out, state = net(cfg_x, cfg_t, cfg_condition, state)
if t_cur > self.guidance_interval_min and t_cur < self.guidance_interval_max:
out = self.guidance_fn(out, self.guidance)
else:
out = self.guidance_fn(out, 1.0)
v = out
x0 = x + v * (1-t_cur)
alpha_next = self.scheduler.alpha(t_next)
sigma_next = self.scheduler.sigma(t_next)
x = alpha_next * x0 + sigma_next * torch.randn_like(x)
# print(alpha_next, sigma_next)
return x
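
A sketch of the update inside the loop above (scalar placeholders; alpha/sigma written out for the linear schedule): predict x0 by following the current velocity to t=1, then re-noise it to the next timestep.

import torch

x = torch.randn(2, 4, 8, 8)
v = torch.randn_like(x)                       # placeholder for the network's velocity output
t_cur, t_next = 0.3, 0.5
x0 = x + v * (1 - t_cur)                      # jump to the data endpoint along v
alpha_next, sigma_next = t_next, 1 - t_next   # LinearScheduler.alpha / .sigma at t_next
x = alpha_next * x0 + sigma_next * torch.randn_like(x)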

View File

@@ -0,0 +1,103 @@
import torch
from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *
from typing import Callable
def shift_respace_fn(t, shift=3.0):
return t / (t + (1 - t) * shift)
def ode_step_fn(x, v, dt, s, w):
return x + v * dt
def sde_mean_step_fn(x, v, dt, s, w):
return x + v * dt + s * w * dt
def sde_step_fn(x, v, dt, s, w):
return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x)
def sde_preserve_step_fn(x, v, dt, s, w):
return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x)
import logging
logger = logging.getLogger(__name__)
class EulerSampler(BaseSampler):
def __init__(
self,
w_scheduler: BaseScheduler = None,
timeshift=1.0,
guidance_interval_min: float = 0.0,
guidance_interval_max: float = 1.0,
state_refresh_rate=1,
step_fn: Callable = ode_step_fn,
last_step=None,
last_step_fn: Callable = ode_step_fn,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.step_fn = step_fn
self.last_step = last_step
self.last_step_fn = last_step_fn
self.w_scheduler = w_scheduler
self.timeshift = timeshift
self.state_refresh_rate = state_refresh_rate
self.guidance_interval_min = guidance_interval_min
self.guidance_interval_max = guidance_interval_max
if self.last_step is None or self.num_steps == 1:
self.last_step = 1.0 / self.num_steps
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
self.timesteps = shift_respace_fn(timesteps, self.timeshift)
assert self.last_step > 0.0
assert self.scheduler is not None
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
if self.w_scheduler is not None:
if self.step_fn == ode_step_fn:
logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
def _impl_sampling(self, net, noise, condition, uncondition):
"""
sampling process of Euler sampler
-
"""
batch_size = noise.shape[0]
steps = self.timesteps.to(noise.device)
cfg_condition = torch.cat([uncondition, condition], dim=0)
x = noise
state = None
for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
dt = t_next - t_cur
t_cur = t_cur.repeat(batch_size)
sigma = self.scheduler.sigma(t_cur)
dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur)
dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
if self.w_scheduler:
w = self.w_scheduler.w(t_cur)
else:
w = 0.0
cfg_x = torch.cat([x, x], dim=0)
cfg_t = t_cur.repeat(2)
if i % self.state_refresh_rate == 0:
state = None
out, state = net(cfg_x, cfg_t, cfg_condition, state)
if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
out = self.guidance_fn(out, self.guidance)
else:
out = self.guidance_fn(out, 1.0)
v = out
s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma)
if i < self.num_steps -1 :
x = self.step_fn(x, v, dt, s=s, w=w)
else:
x = self.last_step_fn(x, v, dt, s=s, w=w)
return x
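
A standalone check of the velocity-to-score conversion used above (a sketch under the linear schedule and a point-mass data distribution, where both quantities are known in closed form).

import torch

x0 = torch.randn(2, 4, 8, 8, dtype=torch.float64)
t = 0.7
alpha, sigma, dalpha, dsigma = t, 1 - t, 1.0, -1.0
eps = torch.randn_like(x0)
x_t = alpha * x0 + sigma * eps

v = dalpha * x0 + dsigma * eps              # exact marginal velocity for a point mass
score = -(x_t - alpha * x0) / sigma**2      # exact score of N(alpha*x0, sigma^2 I)

dalpha_over_alpha = dalpha / alpha
dsigma_mul_sigma = dsigma * sigma
s = ((1 / dalpha_over_alpha) * v - x_t) / (sigma**2 - (1 / dalpha_over_alpha) * dsigma_mul_sigma)
assert torch.allclose(s, score)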

View File

@@ -0,0 +1,39 @@
import math
import torch
from src.diffusion.base.scheduling import *
class LinearScheduler(BaseScheduler):
def alpha(self, t) -> Tensor:
return (t).view(-1, 1, 1, 1)
def sigma(self, t) -> Tensor:
return (1-t).view(-1, 1, 1, 1)
def dalpha(self, t) -> Tensor:
return torch.full_like(t, 1.0).view(-1, 1, 1, 1)
def dsigma(self, t) -> Tensor:
return torch.full_like(t, -1.0).view(-1, 1, 1, 1)
# SoTA for ImageNet!
class GVPScheduler(BaseScheduler):
def alpha(self, t) -> Tensor:
return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
def sigma(self, t) -> Tensor:
return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
def dalpha(self, t) -> Tensor:
return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
def dsigma(self, t) -> Tensor:
return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
def w(self, t):
return torch.sin(t)**2
class ConstScheduler(BaseScheduler):
def w(self, t):
return torch.ones(1, 1, 1, 1).to(t.device, t.dtype)
from src.diffusion.ddpm.scheduling import VPScheduler
class VPBetaScheduler(VPScheduler):
def w(self, t):
return self.beta(t).view(-1, 1, 1, 1)

View File

@@ -0,0 +1,149 @@
import torch
from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *
from typing import Callable
def shift_respace_fn(t, shift=3.0):
return t / (t + (1 - t) * shift)
def ode_step_fn(x, v, dt, s, w):
return x + v * dt
import logging
logger = logging.getLogger(__name__)
class EulerSampler(BaseSampler):
def __init__(
self,
w_scheduler: BaseScheduler = None,
timeshift=1.0,
guidance_interval_min: float = 0.0,
guidance_interval_max: float = 1.0,
state_refresh_rate=1,
step_fn: Callable = ode_step_fn,
last_step=None,
last_step_fn: Callable = ode_step_fn,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.step_fn = step_fn
self.last_step = last_step
self.last_step_fn = last_step_fn
self.w_scheduler = w_scheduler
self.timeshift = timeshift
self.state_refresh_rate = state_refresh_rate
self.guidance_interval_min = guidance_interval_min
self.guidance_interval_max = guidance_interval_max
if self.last_step is None or self.num_steps == 1:
self.last_step = 1.0 / self.num_steps
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
self.timesteps = shift_respace_fn(timesteps, self.timeshift)
assert self.last_step > 0.0
assert self.scheduler is not None
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
if self.w_scheduler is not None:
if self.step_fn == ode_step_fn:
logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
# init recompute
self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
self.recompute_timesteps = list(range(self.num_steps))
def sharing_dp(self, net, noise, condition, uncondition):
_, C, H, W = noise.shape
B = 8
template_noise = torch.randn((B, C, H, W), generator=torch.Generator("cuda").manual_seed(0), device=noise.device)
template_condition = torch.randint(0, 1000, (B,), generator=torch.Generator("cuda").manual_seed(0), device=condition.device)
template_uncondition = torch.full((B, ), 1000, device=condition.device)
_, state_list = self._impl_sampling(net, template_noise, template_condition, template_uncondition)
states = torch.stack(state_list)
N, B, L, C = states.shape
states = states.view(N, B*L, C )
states = states.permute(1, 0, 2)
states = torch.nn.functional.normalize(states, dim=-1)
with torch.autocast(device_type="cuda", dtype=torch.float64):
sim = torch.bmm(states, states.transpose(1, 2))
sim = torch.mean(sim, dim=0).cpu()
error_map = (1-sim).tolist()
# init cum-error
for i in range(1, self.num_steps):
for j in range(0, i):
error_map[i][j] = error_map[i-1][j] + error_map[i][j]
# init dp and force 0 start
C = [[0.0, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
P = [[-1, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
for i in range(1, self.num_steps+1):
C[1][i] = error_map[i - 1][0]
P[1][i] = 0
# dp state
for step in range(2, self.num_recompute_timesteps+1):
for i in range(step, self.num_steps+1):
min_value = 99999
min_index = -1
for j in range(step-1, i):
value = C[step-1][j] + error_map[i-1][j]
if value < min_value:
min_value = value
min_index = j
C[step][i] = min_value
P[step][i] = min_index
# trace back
timesteps = [self.num_steps,]
for i in range(self.num_recompute_timesteps, 0, -1):
idx = timesteps[-1]
timesteps.append(P[i][idx])
timesteps.reverse()
print("recompute timesteps solved by DP: ", timesteps)
return timesteps[:-1]
def _impl_sampling(self, net, noise, condition, uncondition):
"""
sampling process of Euler sampler
-
"""
batch_size = noise.shape[0]
steps = self.timesteps.to(noise.device)
cfg_condition = torch.cat([uncondition, condition], dim=0)
x = noise
state = None
pooled_state_list = []
for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
dt = t_next - t_cur
t_cur = t_cur.repeat(batch_size)
cfg_x = torch.cat([x, x], dim=0)
cfg_t = t_cur.repeat(2)
if i in self.recompute_timesteps:
state = None
out, state = net(cfg_x, cfg_t, cfg_condition, state)
if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
out = self.guidance_fn(out, self.guidance)
else:
out = self.guidance_fn(out, 1.0)
v = out
if i < self.num_steps -1 :
x = self.step_fn(x, v, dt, s=0.0, w=0.0)
else:
x = self.last_step_fn(x, v, dt, s=0.0, w=0.0)
pooled_state_list.append(state)
return x, pooled_state_list
def __call__(self, net, noise, condition, uncondition):
if len(self.recompute_timesteps) != self.num_recompute_timesteps:
self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
return denoised

View File

@@ -0,0 +1,55 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class FlowMatchingTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
out = dict(
loss=loss.mean(),
)
return out
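As a worked example, assuming the configured scheduler is the linear one (alpha(t)=t, sigma(t)=1-t, hence dalpha=1 and dsigma=-1; an assumption, since the scheduler is injected), the construction above reduces to x_t = t*x + (1-t)*noise with regression target v_t = x - noise:

import torch

# minimal sketch of the flow-matching target under a linear path (assumption: alpha(t)=t, sigma(t)=1-t)
x = torch.randn(2, 4, 32, 32)          # clean latent
noise = torch.randn_like(x)
t = torch.rand(2).view(-1, 1, 1, 1)    # per-sample time
x_t = t * x + (1 - t) * noise          # interpolated sample
v_t = x - noise                        # dalpha * x + dsigma * noise with dalpha=1, dsigma=-1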

View File

@@ -0,0 +1,122 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class Discriminator(nn.Module):
def __init__(self, in_channels, hidden_size):
super().__init__()
self.head = nn.Sequential(
nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 4x4 -> 2x2
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
nn.SiLU(),
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
)
def forward(self, feature):
B, L, C = feature.shape
H = W = int(math.sqrt(L))
feature = feature.permute(0, 2, 1)
feature = feature.view(B, C, H, W)
out = self.head(feature).sigmoid().clamp(0.01, 0.99)
return out
class AdvTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
adv_weight=1.0,
adv_encoder_layer=4,
adv_in_channels=768,
adv_hidden_size=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.adv_weight = adv_weight
self.adv_encoder_layer = adv_encoder_layer
self.dis_head = Discriminator(
in_channels=adv_in_channels,
hidden_size=adv_hidden_size,
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
adv_feature = []
def forward_hook(net, input, output):
adv_feature.append(output)
handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook)
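# the hook records the encoder activation at adv_encoder_layer: first for the real noisy latent, then for the re-noised prediction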
out, _ = net(x_t, t, y)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
pred_x0 = (x_t + out * sigma)
pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma
real_feature = adv_feature.pop()
net(pred_xt, t, y, classify_layer=self.adv_encoder_layer)
fake_feature = adv_feature.pop()
handle.remove()
real_score_gan = self.dis_head(real_feature.detach())
fake_score_gan = self.dis_head(fake_feature.detach())
fake_score = self.dis_head(fake_feature)
loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
acc_real = (real_score_gan > 0.5).float()
acc_fake = (fake_score_gan < 0.5).float()
loss_adv = -torch.log(fake_score)
loss_adv_hack = torch.log(fake_score_gan)
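# loss_adv_hack mirrors loss_adv on the detached score, so their gradients into the discriminator head cancel
# and the adversarial term only updates the denoiser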
out = dict(
adv_loss=loss_adv.mean(),
gan_loss=loss_gan.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
acc_real=acc_real.mean(),
acc_fake=acc_fake.mean(),
)
return out

View File

@@ -0,0 +1,141 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bilinear')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class DistillDINOTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
encoder_weight_path=None,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.encoder = DINOv2(encoder_weight_path)
self.proj_encoder_dim = proj_encoder_dim
no_grad(self.encoder)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
sigma = self.scheduler.sigma(t)
x_t = alpha * x + noise * sigma
_, s = net(x_t, t, y)
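# project the denoiser's internal state tokens into the DINOv2 embedding space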
src_feature = self.proj(s)
with torch.no_grad():
dst_feature = self.encoder(raw_images)
if dst_feature.shape[1] != src_feature.shape[1]:
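# token counts differ: reshape the DINO tokens to a grid and bilinearly resample to the denoiser's token grid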
dst_length = dst_feature.shape[1]
rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
dst_height = (dst_length)**0.5 * (height/width)**0.5
dst_width = (dst_length)**0.5 * (width/height)**0.5
dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
dst_feature = dst_feature.permute(0, 3, 1, 2)
dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
dst_feature = dst_feature.permute(0, 2, 3, 1)
dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
out = dict(
cos_loss=cos_loss.mean(),
loss=cos_loss.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
# only the projection head holds trainable parameters, so it is the only part saved
return self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)

View File

@@ -0,0 +1,71 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class LPIPSTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
lpips_weight=1.0,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.lpips_weight = lpips_weight
self.lpips = _NoTrainLpips(net="vgg")
self.lpips = self.lpips.to(torch.bfloat16)
# self.lpips = torch.compile(self.lpips)
no_grad(self.lpips)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
weight = self.loss_weight_fn(alpha, sigma)
loss = weight*(out - v_t)**2
pred_x0 = (x_t + out*sigma)
target_x0 = x
# fixbug lpips std
lpips = self.lpips(pred_x0*0.5, target_x0*0.5)
out = dict(
lpips_loss=lpips.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + lpips.mean()*self.lpips_weight,
)
return out
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
return

View File

@@ -0,0 +1,74 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class LPIPSTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
lognorm_t=False,
lpips_weight=1.0,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = False
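# note: the lognorm_t argument is ignored here; t is always drawn uniformly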
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.lpips_weight = lpips_weight
self.lpips = _NoTrainLpips(net="vgg")
self.lpips = self.lpips.to(torch.bfloat16)
# self.lpips = torch.compile(self.lpips)
no_grad(self.lpips)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size = x.shape[0]
if self.lognorm_t:
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
else:
t = torch.rand(batch_size).to(x.device, x.dtype)
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
w = self.scheduler.w(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
out, _ = net(x_t, t, y)
fm_weight = t*(1-t)**2/0.25
lpips_weight = t
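# time-dependent weights: the flow-matching term peaks at intermediate t, the LPIPS term is weighted by t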
loss = (out - v_t)**2 * fm_weight[:, None, None, None]
pred_x0 = (x_t + out*sigma)
target_x0 = x
# fixbug lpips std
lpips = self.lpips(pred_x0*0.5, target_x0*0.5)*lpips_weight[:, None, None, None]
out = dict(
lpips_loss=lpips.mean(),
fm_loss=loss.mean(),
loss=loss.mean() + lpips.mean()*self.lpips_weight,
)
return out
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
return

View File

@@ -0,0 +1,157 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class REPATrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
encoder_weight_path=None,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.encoder = DINOv2(encoder_weight_path)
self.proj_encoder_dim = proj_encoder_dim
no_grad(self.encoder)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
src_feature = []
def forward_hook(net, input, output):
src_feature.append(output)
if getattr(net, "blocks", None) is not None:
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
else:
handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
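# the hook captures the activation of the align_layer-th block; the first branch covers plain DiT,
# the second the encoder-decoder variant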
out, _ = net(x_t, t, y)
src_feature = self.proj(src_feature[0])
handle.remove()
with torch.no_grad():
dst_feature = self.encoder(raw_images)
if dst_feature.shape[1] != src_feature.shape[1]:
dst_length = dst_feature.shape[1]
rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
dst_height = (dst_length)**0.5 * (height/width)**0.5
dst_width = (dst_length)**0.5 * (width/height)**0.5
dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
dst_feature = dst_feature.permute(0, 3, 1, 2)
dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
dst_feature = dst_feature.permute(0, 2, 3, 1)
dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = weight*(out - v_t)**2
out = dict(
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
# only the projection head holds trainable parameters, so it is the only part saved
return self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)

View File

@@ -0,0 +1,170 @@
import torch
import copy
import timm
from torch.nn import Parameter
from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from torchmetrics.image.lpip import _NoTrainLpips
def inverse_sigma(alpha, sigma):
return 1/sigma**2
def snr(alpha, sigma):
return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
return 1
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load(
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
weight_path,
source="local",
skip_validation=True
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
return feature
class REPALPIPSTrainer(BaseTrainer):
def __init__(
self,
scheduler: BaseScheduler,
loss_weight_fn:Callable=constant,
feat_loss_weight: float=0.5,
lognorm_t=False,
lpips_weight=1.0,
encoder_weight_path=None,
align_layer=8,
proj_denoiser_dim=256,
proj_hidden_dim=256,
proj_encoder_dim=256,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.lognorm_t = lognorm_t
self.scheduler = scheduler
self.loss_weight_fn = loss_weight_fn
self.feat_loss_weight = feat_loss_weight
self.align_layer = align_layer
self.encoder = DINOv2(encoder_weight_path)
self.proj_encoder_dim = proj_encoder_dim
no_grad(self.encoder)
self.lpips_weight = lpips_weight
self.lpips = _NoTrainLpips(net="vgg")
self.lpips = self.lpips.to(torch.bfloat16)
no_grad(self.lpips)
self.proj = nn.Sequential(
nn.Sequential(
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_hidden_dim),
nn.SiLU(),
nn.Linear(proj_hidden_dim, proj_encoder_dim),
)
)
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
batch_size, c, height, width = x.shape
if self.lognorm_t:
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
else:
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
t = base_t
noise = torch.randn_like(x)
alpha = self.scheduler.alpha(t)
dalpha = self.scheduler.dalpha(t)
sigma = self.scheduler.sigma(t)
dsigma = self.scheduler.dsigma(t)
x_t = alpha * x + noise * sigma
v_t = dalpha * x + dsigma * noise
src_feature = []
def forward_hook(net, input, output):
src_feature.append(output)
if getattr(net, "blocks", None) is not None:
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
else:
handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
out, _ = net(x_t, t, y)
src_feature = self.proj(src_feature[0])
handle.remove()
with torch.no_grad():
dst_feature = self.encoder(raw_images)
if dst_feature.shape[1] != src_feature.shape[1]:
dst_length = dst_feature.shape[1]
rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
dst_height = (dst_length)**0.5 * (height/width)**0.5
dst_width = (dst_length)**0.5 * (width/height)**0.5
dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
dst_feature = dst_feature.permute(0, 3, 1, 2)
dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
dst_feature = dst_feature.permute(0, 2, 3, 1)
dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
cos_loss = 1 - cos_sim
weight = self.loss_weight_fn(alpha, sigma)
fm_loss = weight*(out - v_t)**2
pred_x0 = (x_t + out * sigma)
target_x0 = x
# fixbug lpips std
lpips = self.lpips(pred_x0 * 0.5, target_x0 * 0.5)
out = dict(
lpips_loss=lpips.mean(),
fm_loss=fm_loss.mean(),
cos_loss=cos_loss.mean(),
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + self.lpips_weight*lpips.mean(),
)
return out
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
# only the projection head holds trainable parameters, so it is the only part saved
return self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)

162
src/lightning_data.py Normal file
View File

@@ -0,0 +1,162 @@
from typing import Any
import torch
import copy
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
from torch.utils.data import DataLoader
from src.data.dataset.randn import RandomNDataset
from src.data.var_training import VARTransformEngine
def collate_fn(batch):
new_batch = copy.deepcopy(batch)
new_batch = list(zip(*new_batch))
for i in range(len(new_batch)):
if isinstance(new_batch[i][0], torch.Tensor):
try:
new_batch[i] = torch.stack(new_batch[i], dim=0)
except Exception:
print("Warning: could not stack tensors")
return new_batch
class DataModule(pl.LightningDataModule):
def __init__(self,
train_root,
test_nature_root,
test_gen_root,
train_image_size=64,
train_batch_size=64,
train_num_workers=8,
var_transform_engine: VARTransformEngine = None,
train_prefetch_factor=2,
train_dataset: str = None,
eval_batch_size=32,
eval_num_workers=4,
eval_max_num_instances=50000,
pred_batch_size=32,
pred_num_workers=4,
pred_seeds:str=None,
pred_selected_classes=None,
num_classes=1000,
latent_shape=(4,64,64),
):
super().__init__()
pred_seeds = list(map(lambda x: int(x), pred_seeds.strip().split(","))) if pred_seeds is not None else None
self.train_root = train_root
self.train_image_size = train_image_size
self.train_dataset = train_dataset
# stupid data_convert override, just to make nebular happy
self.train_batch_size = train_batch_size
self.train_num_workers = train_num_workers
self.train_prefetch_factor = train_prefetch_factor
self.test_nature_root = test_nature_root
self.test_gen_root = test_gen_root
self.eval_max_num_instances = eval_max_num_instances
self.pred_seeds = pred_seeds
self.num_classes = num_classes
self.latent_shape = latent_shape
self.eval_batch_size = eval_batch_size
self.pred_batch_size = pred_batch_size
self.pred_num_workers = pred_num_workers
self.eval_num_workers = eval_num_workers
self.pred_selected_classes = pred_selected_classes
self._train_dataloader = None
self.var_transform_engine = var_transform_engine
def setup(self, stage: str) -> None:
if stage == "fit":
assert self.train_dataset is not None
if self.train_dataset == "pix_imagenet64":
from src.data.dataset.imagenet import PixImageNet64
self.train_dataset = PixImageNet64(
root=self.train_root,
)
elif self.train_dataset == "pix_imagenet128":
from src.data.dataset.imagenet import PixImageNet128
self.train_dataset = PixImageNet128(
root=self.train_root,
)
elif self.train_dataset == "imagenet256":
from src.data.dataset.imagenet import ImageNet256
self.train_dataset = ImageNet256(
root=self.train_root,
)
elif self.train_dataset == "pix_imagenet256":
from src.data.dataset.imagenet import PixImageNet256
self.train_dataset = PixImageNet256(
root=self.train_root,
)
elif self.train_dataset == "imagenet512":
from src.data.dataset.imagenet import ImageNet512
self.train_dataset = ImageNet512(
root=self.train_root,
)
elif self.train_dataset == "pix_imagenet512":
from src.data.dataset.imagenet import PixImageNet512
self.train_dataset = PixImageNet512(
root=self.train_root,
)
else:
raise NotImplementedError("no such dataset")
def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
if self.var_transform_engine and self.trainer.training:
batch = self.var_transform_engine(batch)
return batch
def train_dataloader(self) -> TRAIN_DATALOADERS:
global_rank = self.trainer.global_rank
world_size = self.trainer.world_size
from torch.utils.data import DistributedSampler
sampler = DistributedSampler(self.train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
self._train_dataloader = DataLoader(
self.train_dataset,
self.train_batch_size,
timeout=6000,
num_workers=self.train_num_workers,
prefetch_factor=self.train_prefetch_factor,
sampler=sampler,
collate_fn=collate_fn,
)
return self._train_dataloader
def val_dataloader(self) -> EVAL_DATALOADERS:
global_rank = self.trainer.global_rank
world_size = self.trainer.world_size
self.eval_dataset = RandomNDataset(
latent_shape=self.latent_shape,
num_classes=self.num_classes,
max_num_instances=self.eval_max_num_instances,
)
from torch.utils.data import DistributedSampler
sampler = DistributedSampler(self.eval_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
return DataLoader(self.eval_dataset, self.eval_batch_size,
num_workers=self.eval_num_workers,
prefetch_factor=2,
collate_fn=collate_fn,
sampler=sampler
)
def predict_dataloader(self) -> EVAL_DATALOADERS:
global_rank = self.trainer.global_rank
world_size = self.trainer.world_size
self.pred_dataset = RandomNDataset(
seeds= self.pred_seeds,
max_num_instances=50000,
num_classes=self.num_classes,
selected_classes=self.pred_selected_classes,
latent_shape=self.latent_shape,
)
from torch.utils.data import DistributedSampler
sampler = DistributedSampler(self.pred_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
return DataLoader(self.pred_dataset, batch_size=self.pred_batch_size,
num_workers=self.pred_num_workers,
prefetch_factor=4,
collate_fn=collate_fn,
sampler=sampler
)

123
src/lightning_model.py Normal file
View File

@@ -0,0 +1,123 @@
from typing import Callable, Iterable, Any, Optional, Union, Sequence, Mapping, Dict
import os.path
import copy
import torch
import torch.nn as nn
import lightning.pytorch as pl
from lightning.pytorch.utilities.types import OptimizerLRScheduler, STEP_OUTPUT
from torch.optim.lr_scheduler import LRScheduler
from torch.optim import Optimizer
from lightning.pytorch.callbacks import Callback
from src.models.vae import BaseVAE, fp2uint8
from src.models.conditioner import BaseConditioner
from src.utils.model_loader import ModelLoader
from src.callbacks.simple_ema import SimpleEMA
from src.diffusion.base.sampling import BaseSampler
from src.diffusion.base.training import BaseTrainer
from src.utils.no_grad import no_grad, filter_nograd_tensors
from src.utils.copy import copy_params
EMACallable = Callable[[nn.Module, nn.Module], SimpleEMA]
OptimizerCallable = Callable[[Iterable], Optimizer]
LRSchedulerCallable = Callable[[Optimizer], LRScheduler]
class LightningModel(pl.LightningModule):
def __init__(self,
vae: BaseVAE,
conditioner: BaseConditioner,
denoiser: nn.Module,
diffusion_trainer: BaseTrainer,
diffusion_sampler: BaseSampler,
ema_tracker: Optional[EMACallable] = None,
optimizer: OptimizerCallable = None,
lr_scheduler: LRSchedulerCallable = None,
):
super().__init__()
self.vae = vae
self.conditioner = conditioner
self.denoiser = denoiser
self.ema_denoiser = copy.deepcopy(self.denoiser)
self.diffusion_sampler = diffusion_sampler
self.diffusion_trainer = diffusion_trainer
self.ema_tracker = ema_tracker
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
# self.model_loader = ModelLoader()
self._strict_loading = False
def configure_model(self) -> None:
self.trainer.strategy.barrier()
# self.denoiser = self.model_loader.load(self.denoiser)
copy_params(src_model=self.denoiser, dst_model=self.ema_denoiser)
# self.denoiser = torch.compile(self.denoiser)
# disable grad for conditioner and vae
no_grad(self.conditioner)
no_grad(self.vae)
no_grad(self.diffusion_sampler)
no_grad(self.ema_denoiser)
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
return [ema_tracker]
def configure_optimizers(self) -> OptimizerLRScheduler:
params_denoiser = filter_nograd_tensors(self.denoiser.parameters())
params_trainer = filter_nograd_tensors(self.diffusion_trainer.parameters())
optimizer: torch.optim.Optimizer = self.optimizer([*params_trainer, *params_denoiser])
if self.lr_scheduler is None:
return dict(
optimizer=optimizer
)
else:
lr_scheduler = self.lr_scheduler(optimizer)
return dict(
optimizer=optimizer,
lr_scheduler=lr_scheduler
)
def training_step(self, batch, batch_idx):
raw_images, x, y = batch
with torch.no_grad():
x = self.vae.encode(x)
condition, uncondition = self.conditioner(y)
loss = self.diffusion_trainer(self.denoiser, self.ema_denoiser, raw_images, x, condition, uncondition)
self.log_dict(loss, prog_bar=True, on_step=True, sync_dist=False)
return loss["loss"]
def predict_step(self, batch, batch_idx):
xT, y, metadata = batch
with torch.no_grad():
condition, uncondition = self.conditioner(y)
# Sample images:
samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition)
samples = self.vae.decode(samples)
# fp32 -1,1 -> uint8 0,255
samples = fp2uint8(samples)
return samples
def validation_step(self, batch, batch_idx):
samples = self.predict_step(batch, batch_idx)
return samples
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
if destination is None:
destination = {}
self._save_to_state_dict(destination, prefix, keep_vars)
self.denoiser.state_dict(
destination=destination,
prefix=prefix+"denoiser.",
keep_vars=keep_vars)
self.ema_denoiser.state_dict(
destination=destination,
prefix=prefix+"ema_denoiser.",
keep_vars=keep_vars)
self.diffusion_trainer.state_dict(
destination=destination,
prefix=prefix+"diffusion_trainer.",
keep_vars=keep_vars)
return destination

0
src/models/__init__.py Normal file
View File

26
src/models/conditioner.py Normal file
View File

@@ -0,0 +1,26 @@
import torch
import torch.nn as nn
class BaseConditioner(nn.Module):
def __init__(self):
super(BaseConditioner, self).__init__()
def _impl_condition(self, y):
...
def _impl_uncondition(self, y):
...
def __call__(self, y):
condition = self._impl_condition(y)
uncondition = self._impl_uncondition(y)
return condition, uncondition
class LabelConditioner(BaseConditioner):
def __init__(self, null_class):
super().__init__()
self.null_condition = null_class
def _impl_condition(self, y):
return torch.tensor(y).long().cuda()
def _impl_uncondition(self, y):
return torch.full((len(y),), self.null_condition, dtype=torch.long).cuda()

View File

View File

@@ -0,0 +1,383 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
from src.utils.model_loader import ModelLoader
from src.utils.no_grad import no_grad
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenDiTEncoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
return (pos_rope, pos_ape)
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
def forward(self, x, t, y, mask=None):
B, _, H, W = x.shape
pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
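# unfold patchifies the latent into (B, N, patch_size**2 * in_channels) tokens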
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
s = self.s_embedder(x)
# s = s + pos_ape.to(s.dtype)
for i in range(self.num_blocks):
s = self.blocks[i](s, c, pos_rope, mask)
return None, s
class FlattenDiTDecoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s, mask=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
c = torch.nn.functional.silu(t + y)
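# the normalized encoder state is concatenated to each patch token before embedding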
x = torch.cat([x, s], dim=-1)
x = self.x_embedder(x)
for i in range(self.num_blocks):
x = self.blocks[i](x, c, pos, None)
x = self.final_layer(x, c)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size,
stride=self.patch_size)
return x
class FlattenDiT(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
ModelLoader().load(encoder)
self.encoder = self.encoder.to(torch.bfloat16)
no_grad(self.encoder)
def forward(self, x, t, y, s=None):
if s is None:
with torch.no_grad():
_, s = self.encoder(x, t, y)
x = self.decoder(x, t, y, s)
return x, s
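A minimal driver sketch (hypothetical caller, not part of this commit) showing how a sampler can exploit this signature: the frozen encoder is re-run only when the cached state is dropped, mirroring the recompute_timesteps logic in the Euler sampler above:

import torch

def euler_sample_with_cached_state(net, x, times, y, recompute_steps):
    # net follows FlattenDiT.forward: net(x, t, y, s) -> (velocity, state)
    state = None
    for i, (t_cur, t_next) in enumerate(zip(times[:-1], times[1:])):
        if i in recompute_steps:
            state = None                                   # force the frozen encoder to run again
        v, state = net(x, t_cur.repeat(x.shape[0]), y, state)
        x = x + v * (t_next - t_cur)                       # plain Euler step, CFG omitted for brevity
    return x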

View File

@@ -0,0 +1,447 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
from src.utils.model_loader import ModelLoader
from src.utils.no_grad import no_grad
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
class ResBlock(nn.Module):
def __init__(self, dim:int, groups:int=8, hidden_dim:int=256):
super().__init__()
self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
self.norm1 = nn.GroupNorm(groups, dim)
self.norm2 = nn.GroupNorm(groups, dim)
self.embed_proj = nn.Linear(hidden_dim, dim)
def forward(self, x, c):
c = self.embed_proj(c)[:, :, None, None]
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = torch.nn.functional.silu(x)
x = x * c
x = self.conv2(x)
x = self.norm2(x)
x = torch.nn.functional.silu(x)
return residual + x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenDiTEncoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
return (pos_rope, pos_ape)
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
def forward(self, x, t, y, mask=None, classify_layer=None):
B, _, H, W = x.shape
pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
s = self.s_embedder(x)
# s = s + pos_ape.to(s.dtype)
classify_feats = []
for i in range(self.num_blocks):
s = self.blocks[i](s, c, pos_rope, mask)
if classify_layer is not None and i < classify_layer:
classify_feats.append(s)
if i == classify_layer - 1:
return None, classify_feats
return None, s
class FlattenDiTDecoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_mid_blocks=18,
num_res_blocks=[1, 1, 1],
num_res_channels=[64, 384, 768],
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_mid_blocks = num_mid_blocks
self.num_res_blocks = num_res_blocks
self.num_res_channels = num_res_channels
self.patch_size = 2**(len(num_res_blocks))
self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.down_res_blocks = nn.ModuleList()
previous_channel = self.in_channels
for num, channels in zip(num_res_blocks, num_res_channels):
self.down_res_blocks.append(
nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
)
self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
previous_channel = channels
self.up_res_blocks = []
previous_channel = self.in_channels
for num, channels in zip(num_res_blocks, num_res_channels):
self.up_res_blocks.append(
nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
)
self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
previous_channel = channels
self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
self.weight_path = weight_path
self.load_ema = load_ema
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
def forward(self, x, t, y, s, mask=None):
B, _, H, W = x.shape
t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
y = self.y_embedder(y).view(B, self.hidden_size)
s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
c = torch.nn.functional.silu(t + y)
residual = []
for i, block in enumerate(self.down_res_blocks):
if isinstance(block, nn.Conv2d):
residual.append(x)
x = block(x)
else:
x = block(x, c)
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = x.view(B, self.hidden_size, -1).transpose(1, 2)
mid_c = torch.nn.functional.silu(t[:, None, :] + s)
for i in range(self.num_mid_blocks):
x = self.blocks[i](x, mid_c, pos, None)
x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size)
residual[0] = 0.0
for i, block in enumerate(self.up_res_blocks):
if isinstance(block, nn.ConvTranspose2d):
x = block(x) + residual.pop()
else:
x = block(x, c)
return x
class FlattenDiT(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
ModelLoader().load(encoder)
self.encoder = self.encoder.to(torch.bfloat16)
no_grad(self.encoder)
def forward(self, x, t, y, s=None, classify_layer=None):
if s is None:
_, s = self.encoder(x, t, y, classify_layer=classify_layer)
if classify_layer is not None:
return None, s
x = self.decoder(x, t, y, s)
return x, s
class FlattenDiT_jointtraining(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, x, t, y, s=None):
if s is None:
_, s = self.encoder(x, t, y)
x = self.decoder(x, t, y, s)
return x, s
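# --- Hedged usage sketch (not part of the original source) --------------------
# A minimal smoke test of the encoder/decoder split above, run with made-up
# sizes: a 32x32x4 latent, a reduced width of 768 and only two blocks per
# stage. These values are illustrative assumptions, not taken from any config
# in this commit.
if __name__ == "__main__":
    enc = FlattenDiTEncoder(in_channels=4, num_groups=12, hidden_size=768, num_blocks=2)
    dec = FlattenDiTDecoder(in_channels=4, num_groups=12, hidden_size=768, num_mid_blocks=2,
                            num_res_blocks=[1], num_res_channels=[768])
    model = FlattenDiT_jointtraining(enc, dec)
    x = torch.randn(2, 4, 32, 32)              # noisy latents
    t = torch.rand(2)                          # flow-matching timesteps in [0, 1]
    y = torch.randint(0, 1000, (2,))           # class labels
    out, s = model(x, t, y)                    # s: one 768-d token per 2x2 patch
    assert out.shape == x.shape and s.shape == (2, 16 * 16, 768)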

View File

@@ -0,0 +1,448 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
from src.utils.model_loader import ModelLoader
from src.utils.no_grad import no_grad
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
class ResBlock(nn.Module):
def __init__(self, dim:int, groups:int=8, hidden_dim:int=256):
super().__init__()
self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
self.norm1 = nn.GroupNorm(groups, dim)
self.norm2 = nn.GroupNorm(groups, dim)
self.embed_proj = nn.Linear(hidden_dim, dim)
def forward(self, x, c):
c = self.embed_proj(c)[:, :, None, None]
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = torch.nn.functional.silu(x)
x = x * c
x = self.conv2(x)
x = self.norm2(x)
x = torch.nn.functional.silu(x)
return residual + x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
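# --- Hedged sketch: shape check for the 2D rotary helpers ---------------------
# A standalone sanity check of precompute_freqs_cis_2d / apply_rotary_emb with
# made-up sizes (head_dim=64, an 8x8 token grid, 4 heads). The complex table
# has head_dim/2 entries per token, matching the pairing done in
# apply_rotary_emb.
if __name__ == "__main__":
    head_dim, h, w, heads = 64, 8, 8, 4
    freqs = precompute_freqs_cis_2d(head_dim, h, w)        # (h*w, head_dim/2), complex
    q = torch.randn(2, h * w, heads, head_dim)
    k = torch.randn(2, h * w, heads, head_dim)
    q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis=freqs)
    assert freqs.shape == (h * w, head_dim // 2) and q_rot.shape == q.shape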
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
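# --- Hedged sketch: adaLN conditioning shapes ---------------------------------
# FlattenDiTBlock modulates with whatever broadcasts against (B, N, D): the
# encoder feeds a per-sample condition of shape (B, 1, D), while the decoder's
# mid blocks feed a per-token condition of shape (B, N, D). Sizes below are
# illustrative assumptions.
if __name__ == "__main__":
    blk = FlattenDiTBlock(hidden_size=256, groups=4)
    pos = precompute_freqs_cis_2d(256 // 4, 8, 8)
    x = torch.randn(2, 64, 256)
    out_global = blk(x, torch.randn(2, 1, 256), pos)    # sequence-level condition
    out_token = blk(x, torch.randn(2, 64, 256), pos)    # token-level condition
    assert out_global.shape == out_token.shape == x.shape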
class FlattenDiTEncoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
return (pos_rope, pos_ape)
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
def forward(self, x, t, y, mask=None, classify_layer=None):
B, _, H, W = x.shape
pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
s = self.s_embedder(x)
# s = s + pos_ape.to(s.dtype)
classify_feats = []
for i in range(self.num_blocks):
s = self.blocks[i](s, c, pos_rope, mask)
if classify_layer is not None and i < classify_layer:
classify_feats.append(s)
if i == classify_layer - 1:
return None, classify_feats
return None, s
class FlattenDiTDecoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_mid_blocks=18,
num_res_blocks=[1, 1, 1],
num_res_channels=[64, 384, 768],
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_mid_blocks = num_mid_blocks
self.num_res_blocks = num_res_blocks
self.num_res_channels = num_res_channels
self.patch_size = 2**(len(num_res_blocks))
self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
self.t_embedder = TimestepEmbedder(hidden_size)
self.down_res_blocks = nn.ModuleList()
previous_channel = self.in_channels
for num, channels in zip(num_res_blocks, num_res_channels):
self.down_res_blocks.append(
nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
)
self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
previous_channel = channels
self.up_res_blocks = []
previous_channel = self.in_channels
for num, channels in zip(num_res_blocks, num_res_channels):
self.up_res_blocks.append(
nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
)
self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
previous_channel = channels
self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
self.weight_path = weight_path
self.load_ema = load_ema
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
def forward(self, x, t, y, s, mask=None):
B, _, H, W = x.shape
t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
y = self.y_embedder(y).view(B, self.hidden_size)
s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
c = torch.nn.functional.silu(t + y)
residual = []
for i, block in enumerate(self.down_res_blocks):
if isinstance(block, nn.Conv2d):
residual.append(x)
x = block(x)
else:
x = block(x, c)
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = x.view(B, self.hidden_size, -1).transpose(1, 2)
mid_c = torch.nn.functional.silu(t[:, None, :] + s)
for i in range(self.num_mid_blocks):
x = self.blocks[i](x, mid_c, pos, None)
x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size)
residual[0] = 0.0
for i, block in enumerate(self.up_res_blocks):
if isinstance(block, nn.ConvTranspose2d):
x = block(x) + residual.pop()
else:
x = block(x, c)
return x
class FlattenDiT(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
ModelLoader().load(encoder, "encoder.")
ModelLoader().load(decoder, "decoder.")
self.encoder = self.encoder.to(torch.bfloat16)
no_grad(self.encoder)
def forward(self, x, t, y, s=None, classify_layer=None):
if s is None:
_, s = self.encoder(x, t, y, classify_layer=classify_layer)
if classify_layer is not None:
return None, s
x = self.decoder(x, t, y, s)
return x, s
class FlattenDiT_jointtraining(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, x, t, y, s=None):
if s is None:
_, s = self.encoder(x, t, y)
x = self.decoder(x, t, y, s)
return x, s
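# --- Hedged sketch: early-exit feature extraction -----------------------------
# The encoder's `classify_layer` argument returns the hidden states of the
# first k blocks instead of the final state, which is convenient for linear
# probing. Sizes below are illustrative, not taken from any config.
if __name__ == "__main__":
    enc = FlattenDiTEncoder(in_channels=4, num_groups=12, hidden_size=768, num_blocks=6)
    x = torch.randn(2, 4, 32, 32)
    t = torch.rand(2)
    y = torch.randint(0, 1000, (2,))
    _, feats = enc(x, t, y, classify_layer=4)
    # One feature map per block up to (and including) block index 3.
    assert len(feats) == 4 and feats[0].shape == (2, 256, 768)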

View File

@@ -0,0 +1,464 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
from src.utils.model_loader import ModelLoader
from src.utils.no_grad import no_grad
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
class ResBlock(nn.Module):
def __init__(self, dim:int, groups:int=8, hidden_dim:int=256):
super().__init__()
self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
self.norm1 = nn.GroupNorm(groups, dim)
self.norm2 = nn.GroupNorm(groups, dim)
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = torch.nn.functional.silu(x)
x = self.conv2(x)
x = self.norm2(x)
x = torch.nn.functional.silu(x)
return residual + x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenDiTEncoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
return (pos_rope, pos_ape)
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
def forward(self, x, t, y, mask=None, classify_layer=None):
B, _, H, W = x.shape
pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
s = self.s_embedder(x)
# s = s + pos_ape.to(s.dtype)
classify_feats = []
for i in range(self.num_blocks):
s = self.blocks[i](s, c, pos_rope, mask)
if classify_layer is not None and i < classify_layer:
classify_feats.append(s)
if i == classify_layer - 1:
return None, classify_feats
return None, s
class FlattenDiTDecoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_mid_blocks=18,
num_res_blocks=[1, 1, 1],
num_res_channels=[64, 384, 768],
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_mid_blocks = num_mid_blocks
self.num_res_blocks = num_res_blocks
self.num_res_channels = num_res_channels
self.patch_size = 2**(len(num_res_blocks))
self.t_embedder = TimestepEmbedder(hidden_size)
self.down_res_blocks = nn.ModuleList()
previous_channel = self.in_channels
for num, channels in zip(num_res_blocks, num_res_channels):
self.down_res_blocks.append(
nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
)
self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
previous_channel = channels
self.up_res_blocks = []
previous_channel = self.in_channels
for num, channels in zip(num_res_blocks, num_res_channels):
self.up_res_blocks.append(
nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
)
self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
previous_channel = channels
self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
self.weight_path = weight_path
self.load_ema = load_ema
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in SiT blocks:
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
for block in self.down_res_blocks:
if isinstance(block, ResBlock):
nn.init.constant_(block.conv1.weight, 0)
nn.init.constant_(block.conv1.bias, 0)
nn.init.constant_(block.norm1.weight, 0)
nn.init.constant_(block.norm2.weight, 0)
nn.init.constant_(block.conv2.weight, 0)
nn.init.constant_(block.conv2.bias, 0)
for block in self.up_res_blocks:
if isinstance(block, ResBlock):
nn.init.constant_(block.conv1.weight, 0)
nn.init.constant_(block.conv1.bias, 0)
nn.init.constant_(block.norm1.weight, 0)
nn.init.constant_(block.norm2.weight, 0)
nn.init.constant_(block.conv2.weight, 0)
nn.init.constant_(block.conv2.bias, 0)
def forward(self, x, t, y, s, mask=None):
B, _, H, W = x.shape
t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
residual = []
for i, block in enumerate(self.down_res_blocks):
if isinstance(block, nn.Conv2d):
residual.append(x)
x = block(x)
else:
x = block(x)
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = x.view(B, self.hidden_size, -1).transpose(1, 2)
mid_c = torch.nn.functional.silu(t[:, None, :] + s)
for i in range(self.num_mid_blocks):
x = self.blocks[i](x, mid_c, pos, None)
x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size)
residual[0] = 0.0
for i, block in enumerate(self.up_res_blocks):
if isinstance(block, nn.ConvTranspose2d):
x = block(x) + residual.pop()
else:
x = block(x)
return x
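# --- Hedged sketch: zero-initialised adaLN makes a block start as identity ----
# initialize_weights above zeroes each block's adaLN projection, so shift,
# scale and both gates start at 0 and a freshly built block initially passes
# its input through unchanged. A standalone check with illustrative sizes:
if __name__ == "__main__":
    blk = FlattenDiTBlock(hidden_size=256, groups=4)
    nn.init.constant_(blk.adaLN_modulation[-1].weight, 0)
    nn.init.constant_(blk.adaLN_modulation[-1].bias, 0)
    pos = precompute_freqs_cis_2d(256 // 4, 8, 8)
    x = torch.randn(2, 64, 256)
    out = blk(x, torch.randn(2, 1, 256), pos)
    assert torch.allclose(out, x, atol=1e-6)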
class FlattenDiT(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
ModelLoader().load(encoder, "encoder.")
ModelLoader().load(decoder, "decoder.")
self.encoder = self.encoder.to(torch.bfloat16)
no_grad(self.encoder)
def forward(self, x, t, y, s=None, classify_layer=None):
if s is None:
_, s = self.encoder(x, t, y, classify_layer=classify_layer)
if classify_layer is not None:
return None, s
x = self.decoder(x, t, y, s)
return x, s
class FlattenDiT_jointtraining(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, x, t, y, s=None):
if s is None:
_, s = self.encoder(x, t, y)
x = self.decoder(x, t, y, s)
return x, s

View File

@@ -0,0 +1,274 @@
import torch
import torch.nn as nn
import math
from functools import lru_cache  # lru_cache lives in functools; the numba CUDA path was an auto-import mistake
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
flex_attention = torch.compile(flex_attention)
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = False,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = self.norm(x)
x = modulate(x, shift, scale)
x = self.linear(x)
return x
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
self.fc1 = nn.Linear(dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, dim)
self.act = nn.GELU(approximate="tanh")
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale: float=16):
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
x_pos = x_pos.reshape(-1)
y_pos = y_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
freqs_cis = torch.cat([x_freqs.sin(), x_freqs.cos(), y_freqs.sin(), y_freqs.cos()], dim=1)
return freqs_cis
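# --- Hedged sketch: this file's positional table is additive, not rotary ------
# Unlike the complex-valued RoPE tables elsewhere in the repo, this
# precompute_freqs_cis_2d concatenates sin/cos features and is simply added to
# the token embeddings (see ConDiT.forward below). Illustrative check:
if __name__ == "__main__":
    pos = precompute_freqs_cis_2d(1152, height=16, width=16)
    # Four groups of dim//4 features (x-sin, x-cos, y-sin, y-cos) per token.
    assert pos.shape == (16 * 16, 1152)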
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
qk_norm: bool = False,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = nn.LayerNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
# import pdb; pdb.set_trace()
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q).to(q.dtype)
k = self.k_norm(k).to(k.dtype)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
# x = flex_attention(q, k, v, block_mask=mask)
# with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
x = scaled_dot_product_attention(q, k, v, mask)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class DiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = Attention(hidden_size, num_heads=groups, qkv_bias=True, qk_norm=False)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class ConDiT(nn.Module):
def __init__(
self,
in_channels=4,
out_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
self.final_layer = FinalLayer(hidden_size, out_channels * patch_size ** 2)
self.num_cond_blocks = num_cond_blocks
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
DiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
@lru_cache
def fetch_pos(self, height, width, device):
pos = precompute_freqs_cis_2d(self.hidden_size, height//self.patch_size, width//self.patch_size).to(device)[None, ...]
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H, W, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
if s is None:
# semantic encoder
s = self.s_embedder(x) + pos
c = nn.functional.silu(t + y)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s
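# --- Hedged sketch: unfold/fold patchify round trip ---------------------------
# ConDiT patchifies with F.unfold and un-patchifies with F.fold using
# kernel_size == stride == patch_size, which is an exact inverse whenever the
# spatial size is divisible by the patch size. Illustrative sizes:
if __name__ == "__main__":
    p = 2
    z = torch.randn(2, 4, 32, 32)
    tokens = torch.nn.functional.unfold(z, kernel_size=p, stride=p).transpose(1, 2)  # (B, N, C*p*p)
    back = torch.nn.functional.fold(tokens.transpose(1, 2), (32, 32), kernel_size=p, stride=p)
    assert torch.equal(back, z)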

View File

@@ -0,0 +1,314 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None, mask=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
if s is None:
s = self.s_embedder(x)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos, mask)
# s = nn.functional.silu(t + s)
s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6)
x = torch.cat((x, s), dim=-1)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, c, pos, None)
x = self.final_layer(x, c)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s
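# --- Hedged sketch: reusing the semantic state `s` across calls ---------------
# forward() builds `s` only when it is not supplied, so a sampler can run the
# conditioning blocks once and feed the same `s` back on later steps. Sizes are
# illustrative and the model is randomly initialised.
if __name__ == "__main__":
    net = FlattenConDiT(in_channels=4, num_groups=4, hidden_size=256, num_blocks=4, num_cond_blocks=2)
    x = torch.randn(2, 4, 32, 32)
    y = torch.randint(0, 1000, (2,))
    out1, s = net(x, torch.full((2,), 0.9), y)          # computes s internally
    out2, s2 = net(x, torch.full((2,), 0.8), y, s=s)    # skips the cond blocks
    assert out2.shape == x.shape and torch.allclose(s, s2)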

View File

@@ -0,0 +1,340 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConvBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, kernel_size=3):
super().__init__()
self.hidden_size = hidden_size
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = nn.Conv2d(hidden_size, hidden_size, groups=groups, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
attn_x = modulate(self.norm1(x), shift_msa, scale_msa)
attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous()
attn_x = self.attn(attn_x)
attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2)
x = x + gate_msa * attn_x
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
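# --- Hedged sketch: FlattenConvBlock assumes a 16x16 token grid ---------------
# The reshape above hard-codes view(-1, hidden_size, 16, 16), so this block
# only accepts sequences of exactly 256 tokens (e.g. a 32x32 latent with
# patch_size 2). Illustrative check with small sizes:
if __name__ == "__main__":
    blk = FlattenConvBlock(hidden_size=128, groups=4)
    x = torch.randn(2, 256, 128)                # 256 tokens are required
    c = torch.randn(2, 1, 128)
    out = blk(x, c, pos=None)                   # pos is unused by the conv path
    assert out.shape == x.shape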
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
kernel_size=3,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([])
for i in range(self.num_cond_blocks):
self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
for i in range(self.num_blocks-self.num_cond_blocks):
self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups, kernel_size=kernel_size))  # pass by keyword so kernel_size is not consumed as mlp_ratio
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
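        # two-pass structure: first build the conditioning state s through the cond blocks,
        # then denoise x with s acting as per-token conditioning for the remaining blocks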
if s is None:
s = self.s_embedder(x)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos, None)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s
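# --- Illustrative smoke test for the FlattenConDiT variant above ---
# All sizes here are assumptions (not the training config); the conv tail blocks
# hard-code a 16x16 token grid, so we use 32x32 latents with patch_size=2.
if __name__ == "__main__":
    import torch

    model = FlattenConDiT(
        in_channels=4, num_groups=16, hidden_size=256,
        num_blocks=8, num_cond_blocks=4, patch_size=2, kernel_size=3,
    )
    x = torch.randn(2, 4, 32, 32)        # latent batch: unfold -> 16x16 = 256 tokens
    t = torch.rand(2)                    # continuous time in [0, 1]
    y = torch.randint(0, 1001, (2,))     # class ids; index 1000 is the null class
    out, s = model(x, t, y)              # s can be fed back via s= to skip the cond blocks
    print(out.shape, s.shape)            # [2, 4, 32, 32] and [2, 256, 256]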

View File

@@ -0,0 +1,339 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
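        # NOTE: max_period defaults to 10 rather than the usual 10000, matching a
        # continuous t in [0, 1] instead of integer diffusion steps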
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConvBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.hidden_size = hidden_size
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = nn.Conv2d(hidden_size, hidden_size, groups=hidden_size, kernel_size=7, stride=1, padding=3)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
attn_x = modulate(self.norm1(x), shift_msa, scale_msa)
attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous()
attn_x = self.attn(attn_x)
attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2)
x = x + gate_msa * attn_x
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([])
for i in range(self.num_cond_blocks):
self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
for i in range(self.num_blocks-self.num_cond_blocks):
self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups))
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
if s is None:
s = self.s_embedder(x)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos, None)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s
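# --- Shape walk-through for the 2-D rotary embedding shared by these denoisers ---
# The head_dim, grid size, batch and head counts below are arbitrary assumptions.
if __name__ == "__main__":
    import torch

    head_dim, height, width = 32, 16, 16
    pos = precompute_freqs_cis_2d(head_dim, height, width)
    print(pos.shape, pos.dtype)                      # torch.Size([256, 16]) torch.complex64

    B, H = 2, 6
    q = torch.randn(B, height * width, H, head_dim)  # B N H Hc, as built inside RAttention
    k = torch.randn(B, height * width, H, head_dim)
    q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis=pos)
    print(q_rot.shape)                               # torch.Size([2, 256, 6, 32]); each 2-D pair keeps its norm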

View File

@@ -0,0 +1,313 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None, mask=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
if s is None:
s = self.s_embedder(x)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos, mask)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s
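# --- Illustrative sampling loop that caches the conditioning state s (assumptions throughout) ---
# forward() accepts s= and returns it, so a sampler can refresh the state only every few steps;
# the time grid and plain Euler update below are stand-ins for whatever the real sampler does.
@torch.no_grad()
def _sample_with_cached_state(model, x, y, num_steps=8, refresh_every=2):
    s = None
    ts = torch.linspace(0.0, 1.0, num_steps + 1)
    for i in range(num_steps):
        t = ts[i].expand(x.shape[0])
        if i % refresh_every == 0:
            s = None                     # recompute the state through the cond blocks
        v, s = model(x, t, y, s=s)       # otherwise reuse the cached state
        x = x + (ts[i + 1] - ts[i]) * v  # plain Euler update (illustration only)
    return x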

View File

@@ -0,0 +1,314 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None, mask=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
if s is None:
s = self.s_embedder(x)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos, mask)
s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s
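# --- This variant L2-normalizes s per token before combining it with t, ---
# pinning the scale of the state handed to the decoder blocks. Quick check with assumed sizes:
if __name__ == "__main__":
    import torch

    s = torch.randn(2, 256, 1152) * 37.0
    s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6)
    print(s.norm(dim=-1)[0, :3])   # ~1.0 for every token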

View File

@@ -0,0 +1,429 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
from src.utils.model_loader import ModelLoader
from src.utils.no_grad import no_grad
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenDiTEncoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
return (pos_rope, pos_ape)
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
def forward(self, x, t, y, mask=None):
B, _, H, W = x.shape
pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
s = self.s_embedder(x)
# s = s + pos_ape.to(s.dtype)
for i in range(self.num_blocks):
s = self.blocks[i](s, c, pos_rope, mask)
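        # the encoder has no output head; return (None, s) so callers unpack (prediction, state) uniformly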
return None, s
class FlattenDiTDecoder(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
weight_path=None,
load_ema=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)]
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s, mask=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
# s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
s = torch.nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size,
stride=self.patch_size)
return x
class FlattenDiT(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
joint_training=False,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
ModelLoader().load(encoder)
if not joint_training:
self.encoder = self.encoder.to(torch.bfloat16)
no_grad(self.encoder)
self.joint_training = joint_training
def forward(self, x, t, y, s=None):
if s is None:
_, s = self.encoder(x, t, y)
x = self.decoder(x, t, y, s)
return x, s
class FlattenDiTScalingEncoder(nn.Module):
def __init__(
self,
encoder:FlattenDiTEncoder,
decoder:FlattenDiTDecoder,
):
super().__init__()
self.encoder = encoder
self.decoder = decoder
no_grad(self.decoder)
if self.encoder.weight_path:
weight = torch.load(self.encoder.weight_path, map_location=torch.device('cpu'))
if self.encoder.load_ema:
prefix = "ema_denoiser."
else:
prefix = "denoiser."
            for k, v in self.encoder.state_dict().items():
                try:
                    v.copy_(weight["state_dict"][prefix + k])
                except Exception:
                    print(f"Failed to copy {prefix + k} to denoiser weight")
if self.decoder.weight_path:
weight = torch.load(self.decoder.weight_path, map_location=torch.device('cpu'))
if self.decoder.load_ema:
prefix = "ema_denoiser."
else:
prefix = "denoiser."
            for k, v in self.decoder.state_dict().items():
                if "blocks." in k:
                    # shift block indices by 8 so decoder block i loads the pretrained block i+8;
                    # parse the full index so multi-digit block ids are handled correctly
                    blockid = int(k.split("blocks.")[-1].split(".")[0])
                    k = k.replace(f"blocks.{blockid}.", f"blocks.{blockid + 8}.", 1)
                try:
                    v.copy_(weight["state_dict"][prefix + k])
                except Exception:
                    print(f"Failed to copy {prefix + k} to denoiser weight")
self.decoder = decoder.to(torch.bfloat16)
def forward(self, x, t, y, s=None):
if s is None:
_, s = self.encoder(x, t, y)
x = self.decoder(x, t, y, s)
return x, s
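# --- Manually composing the split encoder/decoder (assumed sizes; bypasses ModelLoader/no_grad) ---
# FlattenDiT wires the two halves through the repo's ModelLoader and optionally freezes the
# encoder; the sketch below just checks the shapes of the hand-off.
if __name__ == "__main__":
    enc = FlattenDiTEncoder(in_channels=4, num_groups=16, hidden_size=256, num_blocks=4, patch_size=2)
    dec = FlattenDiTDecoder(in_channels=4, num_groups=16, hidden_size=256, num_blocks=4, patch_size=2)

    x = torch.randn(2, 4, 32, 32)
    t = torch.rand(2)
    y = torch.randint(0, 1001, (2,))

    _, s = enc(x, t, y)        # encoder only produces the conditioning state
    out = dec(x, t, y, s)      # decoder denoises under that state
    print(out.shape, s.shape)  # [2, 4, 32, 32] and [2, 256, 256]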

View File

@@ -0,0 +1,334 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenMLPBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = FeedForward(hidden_size, mlp_hidden_dim)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([])
for i in range(self.num_cond_blocks):
self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
for i in range(self.num_blocks-self.num_cond_blocks):
self.blocks.append(FlattenMLPBlock(self.hidden_size, self.num_groups))
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
if s is None:
s = self.s_embedder(x)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos, None)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s
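if __name__ == "__main__":
    # Illustrative smoke test (not part of the original commit); the 32x32 latent,
    # batch size, and uniform t below are assumptions for the example only.
    model = FlattenConDiT()
    x = torch.randn(2, 4, 32, 32)          # latent images (B, C, H, W)
    t = torch.rand(2)                      # flow-matching time per sample
    y = torch.randint(0, 1000, (2,))       # class labels
    out, s = model(x, t, y)                # first call builds the state s
    out, s = model(x, t, y, s=s)           # later steps may reuse s
    print(out.shape)                       # torch.Size([2, 4, 32, 32])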

View File

@@ -0,0 +1,321 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
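# Shape note: for an h x w token grid, precompute_freqs_cis_2d(head_dim, h, w)
# returns a complex tensor of shape (h*w, head_dim // 2); apply_rotary_emb takes
# q and k of shape (B, N, num_heads, head_dim) and returns tensors of the same shape.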
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.s_embedder = Embed(in_channels*(2*patch_size)**2, hidden_size, bias=True)  # matches unfold(kernel_size=2*patch_size) in forward
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.s_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.s_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None, mask=None):
B, _, H, W = x.shape
pos_x = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
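# No cached state s: build it from a 2x-coarser tokenization of x, run it through
# the first num_cond_blocks, then bilinearly upsample the token grid back to the
# resolution of the x tokens before it conditions the remaining blocks.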
if s is None:
pos_s = self.fetch_pos(H//self.patch_size//2, W//self.patch_size//2, x.device)
s = torch.nn.functional.unfold(x, kernel_size=self.patch_size*2, stride=self.patch_size*2).transpose(1, 2)
s = self.s_embedder(s)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos_s, mask)
s = s.view(B, H//self.patch_size//2, W//self.patch_size//2, self.hidden_size)
s = torch.permute(s, (0, 3, 1, 2))
s = torch.nn.functional.interpolate(s, scale_factor=2, mode='bilinear', align_corners=False)
s = torch.permute(s, (0, 2, 3, 1))
s = s.view(B, -1, self.hidden_size)
s = nn.functional.silu(t + s)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos_x, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s

View File

@@ -0,0 +1,306 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return (self.weight * hidden_states).to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device, dtype):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device, dtype)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, masks=None):
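# Normalize masks into one (possibly None) attention mask per block: accept None,
# a stacked tensor (unbound along dim 0), or a shorter sequence padded with None.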
if masks is None:
masks = [None, ]*self.num_blocks
if isinstance(masks, torch.Tensor):
masks = masks.unbind(0)
if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks:
masks = list(masks) + [None] * (self.num_blocks - len(masks))
B, _, H, W = x.shape
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
x = self.x_embedder(x)
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
B, L, C = x.shape
t = self.t_embedder(t.view(-1)).view(B, -1, C)
y = self.y_embedder(y).view(B, 1, C)
condition = nn.functional.silu(t + y)
for i, block in enumerate(self.blocks):
x = block(x, condition, pos, masks[i])
x = self.final_layer(x, condition)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x

View File

@@ -0,0 +1,311 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return (self.weight * hidden_states).to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, 2*in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device, dtype):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device, dtype)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, masks=None):
if masks is None:
masks = [None, ]*self.num_blocks
if isinstance(masks, torch.Tensor):
masks = masks.unbind(0)
if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks:
masks = list(masks) + [None] * (self.num_blocks - len(masks))
B, _, H, W = x.shape
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
x = self.x_embedder(x)
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
B, L, C = x.shape
t = self.t_embedder(t.view(-1)).view(B, -1, C)
y = self.y_embedder(y).view(B, 1, C)
condition = nn.functional.silu(t + y)
for i, block in enumerate(self.blocks):
x = block(x, condition, pos, masks[i])
x = self.final_layer(x, condition)
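# final_layer is twice as wide as the patch output (2*in_channels*patch_size**2),
# so the result splits into a denoised estimate x0 and a velocity field v; both
# are returned in training mode, only v at inference.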
x0, v = x.chunk(2, dim=-1)
x0 = torch.nn.functional.fold(x0.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
v = torch.nn.functional.fold(v.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
if self.training:
return v, x0
else:
return v

View File

@@ -0,0 +1,308 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from torch.nn.modules.module import T
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
def modulate(x, shift, scale):
return x * (1 + scale) + shift
class Embed(nn.Module):
def __init__(
self,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
def __init__(self, num_classes, hidden_size):
super().__init__()
self.embedding_table = nn.Embedding(num_classes, hidden_size)
self.num_classes = num_classes
def forward(self, labels,):
embeddings = self.embedding_table(labels)
return embeddings
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 2*hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
return x
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
freqs_cis = freqs_cis.reshape(height*width, -1)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
freqs_cis = freqs_cis[None, :, None, :]
# xq : B N H Hc
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class RAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
attn_drop: float = 0.,
proj_drop: float = 0.,
norm_layer: nn.Module = RMSNorm,
) -> None:
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class FlattenDiTBlock(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c, pos, mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlattenConDiT(nn.Module):
def __init__(
self,
in_channels=4,
num_groups=12,
hidden_size=1152,
num_blocks=18,
num_cond_blocks=4,
patch_size=2,
num_classes=1000,
learn_sigma=True,
deep_supervision=0,
weight_path=None,
load_ema=False,
):
super().__init__()
self.deep_supervision = deep_supervision
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels
self.hidden_size = hidden_size
self.num_groups = num_groups
self.num_blocks = num_blocks
self.num_cond_blocks = num_cond_blocks
self.patch_size = patch_size
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
self.weight_path = weight_path
self.load_ema = load_ema
self.blocks = nn.ModuleList([
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
])
self.initialize_weights()
self.precompute_pos = dict()
def fetch_pos(self, height, width, device):
if (height, width) in self.precompute_pos:
return self.precompute_pos[(height, width)].to(device)
else:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
self.precompute_pos[(height, width)] = pos
return pos
def initialize_weights(self):
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize label embedding table:
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# # Zero-out adaLN modulation layers in SiT blocks:
# for block in self.blocks:
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, t, y, s=None, mask=None):
B, _, H, W = x.shape
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
y = self.y_embedder(y).view(B, 1, self.hidden_size)
c = nn.functional.silu(t + y)
if s is None:
s = self.x_embedder(x)
for i in range(self.num_cond_blocks):
s = self.blocks[i](s, c, pos, mask)
s = nn.functional.silu(t + s)
x = self.x_embedder(x)
for i in range(self.num_cond_blocks, self.num_blocks):
x = self.blocks[i](x, s, pos, None)
x = self.final_layer(x, s)
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return x, s

View File

@@ -0,0 +1,160 @@
import torch
import torch.nn as nn
import math
from torch.nn.init import zeros_
from src.models.denoiser.base_model import BaseModel
from src.ops.triton_kernels.function import DCNFunction
def modulate(x, shift, scale):
return x * (1 + scale[:, None, None]) + shift[:, None, None]
class PatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer = None,
bias: bool = True,
):
super().__init__()
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
def forward(self, x):
b, h, w, c = x.shape
x = x.view(b, h*w, c)
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
x = x.view(b, h, w, c)
return x
class MultiScaleDCN(nn.Module):
def __init__(self, in_channels, groups, channels, kernels, deformable_biass=True):
super().__init__()
self.in_channels = in_channels
self.groups = groups
self.channels = channels
self.kernels = kernels
self.v = nn.Linear(in_channels, groups * channels, bias=True)
self.qk_deformables = nn.Linear(in_channels, groups * kernels * 2, bias=True)
self.qk_scales = nn.Linear(in_channels, groups * kernels, bias=False)
self.qk_weights = nn.Linear(in_channels, groups*kernels, bias=True)
self.out = nn.Linear(groups * channels, in_channels)
self.deformables_prior = nn.Parameter(torch.randn((1, 1, 1, 1, kernels, 2)), requires_grad=False)
self.deformables_scale = nn.Parameter(torch.ones((1, 1, 1, groups, 1, 1)), requires_grad=True)
self.max_scale = 6
self._init_weights()
def _init_weights(self):
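# Start close to identity sampling: zero the predicted offsets/scales/weights,
# seed the offset prior with a sqrt(kernels) x sqrt(kernels) grid on [-1, 1], and
# give each group a progressively larger base scale via an inverse-sigmoid init
# (mapped back through sigmoid() * max_scale in forward).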
zeros_(self.qk_deformables.weight.data)
zeros_(self.qk_scales.weight.data)
zeros_(self.qk_deformables.bias.data)
zeros_(self.qk_weights.weight.data)
zeros_(self.v.bias.data)
zeros_(self.out.bias.data)
num_prior = int(self.kernels ** 0.5)
dx = torch.linspace(-1, 1, num_prior, device="cuda")
dy = torch.linspace(-1, 1, num_prior, device="cuda")
dxy = torch.meshgrid([dx, dy], indexing="xy")
dxy = torch.stack(dxy, dim=-1)
dxy = dxy.view(-1, 2)
self.deformables_prior.data[..., :num_prior*num_prior, :] = dxy
for i in range(self.groups):
scale = (i+1)/self.groups - 0.0001
inv_scale = math.log((scale)/(1-scale))
self.deformables_scale.data[..., i, :, :] = inv_scale
def forward(self, x):
B, H, W, _ = x.shape
v = self.v(x).view(B, H, W, self.groups, self.channels)
deformables = self.qk_deformables(x).view(B, H, W, self.groups, self.kernels, 2)
scale = self.qk_scales(x).view(B, H, W, self.groups, self.kernels, 1) + self.deformables_scale
deformables = (deformables + self.deformables_prior ) * scale.sigmoid()*self.max_scale
weights = self.qk_weights(x).view(B, H, W, self.groups, self.kernels)
out = DCNFunction.apply(v, deformables, weights)
out = out.view(B, H, W, -1)
out = self.out(out)
return out
class FlowDCNBlock(nn.Module):
def __init__(self, hidden_size, groups, kernels=9, mlp_ratio=4.0, deformable_biass=True):
super().__init__()
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
self.attn = MultiScaleDCN(hidden_size, groups=groups, channels=hidden_size//groups, kernels=kernels, deformable_biass=deformable_biass)
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa[:, None, None] * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_mlp[:, None, None] * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FlowDCN(BaseModel):
def __init__(self, deformable_biass=True, *args, **kwargs):
super().__init__(*args, **kwargs)
self.blocks = nn.ModuleList([
FlowDCNBlock(self.hidden_size, self.num_groups, kernels=9, deformable_biass=deformable_biass) for _ in range(self.num_blocks)
])
self.x_embedder = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, bias=True)
self.initialize_weights()
def forward(self, x, t, y):
batch_size, _, height, width = x.shape
x = self.x_embedder(x) # (N, D, h, w)
x = x.permute(0, 2, 3, 1).reshape(batch_size, height*width//self.patch_size**2, -1)
t = self.t_embedder(t) # (N, D)
y = self.y_embedder(y, self.training) # (N, D)
c = t + y # (N, D)
B, L, C = x.shape
x = x.view(B, height//self.patch_size, width//self.patch_size, C)
for block in self.blocks:
x = block(x, c) # (N, T, D)
x = x.view(B, L, C)
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
x = torch.nn.functional.fold(x.transpose(1, 2), (height, width), kernel_size=self.patch_size, stride=self.patch_size)
if self.learn_sigma:
x, _ = torch.split(x, self.out_channels // 2, dim=1)
return x

132
src/models/encoder.py Normal file
View File

@@ -0,0 +1,132 @@
import torch
import copy
import os
import timm
import transformers
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from torchvision.transforms import Normalize
class RandViT(nn.Module):
def __init__(self, model_id, weight_path:str=None):
super(RandViT, self).__init__()
self.encoder = timm.create_model(
model_id,
num_classes=0,
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.shifts = nn.Parameter(torch.tensor([0.0
]), requires_grad=False)
self.scales = nn.Parameter(torch.tensor([1.0
]), requires_grad=False)
def forward(self, x):
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
feature = feature.transpose(1, 2)
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
return feature
class MAE(nn.Module):
def __init__(self, model_id, weight_path:str):
super(MAE, self).__init__()
if os.path.isdir(weight_path):
weight_path = os.path.join(weight_path, "pytorch_model.bin")
self.encoder = timm.create_model(
model_id,
checkpoint_path=weight_path,
num_classes=0,
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.shifts = nn.Parameter(torch.tensor([0.0
]), requires_grad=False)
self.scales = nn.Parameter(torch.tensor([1.0
]), requires_grad=False)
def forward(self, x):
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
feature = feature.transpose(1, 2)
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
return feature
class DINO(nn.Module):
def __init__(self, model_id, weight_path:str):
super(DINO, self).__init__()
if os.path.isdir(weight_path):
weight_path = os.path.join(weight_path, "pytorch_model.bin")
self.encoder = timm.create_model(
model_id,
checkpoint_path=weight_path,
num_classes=0,
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.shifts = nn.Parameter(torch.tensor([ 0.0,
]), requires_grad=False)
self.scales = nn.Parameter(torch.tensor([ 1.0,
]), requires_grad=False)
def forward(self, x):
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
feature = feature.transpose(1, 2)
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
return feature
class CLIP(nn.Module):
def __init__(self, model_id, weight_path:str):
super(CLIP, self).__init__()
self.encoder = transformers.CLIPVisionModel.from_pretrained(weight_path)
self.patch_size = self.encoder.vision_model.embeddings.patch_embedding.kernel_size
self.shifts = nn.Parameter(torch.tensor([0.0,
]), requires_grad=False)
self.scales = nn.Parameter(torch.tensor([1.0,
]), requires_grad=False)
def forward(self, x):
x = Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)(x)
x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
feature = self.encoder(x)['last_hidden_state'][:, 1:]
feature = feature.transpose(1, 2)
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
return feature
class DINOv2(nn.Module):
def __init__(self, model_id, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = transformers.Dinov2Model.from_pretrained(weight_path)
self.patch_size = self.encoder.embeddings.patch_embeddings.projection.kernel_size
def forward(self, x):
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
feature = self.encoder.forward(x)['last_hidden_state'][:, 1:]
feature = feature.transpose(1, 2)
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
return feature
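if __name__ == "__main__":
    # Illustrative shape check (not part of the original commit). The weight path
    # is a placeholder for a local DINOv2-base checkpoint.
    encoder = DINOv2("dinov2_vitb14", weight_path="path/to/dinov2-base")
    feats = encoder(torch.rand(1, 3, 256, 256))   # inputs are resized to 224x224
    print(feats.shape)                            # e.g. (1, 768, 16, 16) for ViT-B/14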

81
src/models/vae.py Normal file
View File

@@ -0,0 +1,81 @@
import torch
import subprocess
import lightning.pytorch as pl
import logging
logger = logging.getLogger(__name__)
def class_fn_from_str(class_str):
class_module, from_class = class_str.rsplit(".", 1)
class_module = __import__(class_module, fromlist=[from_class])
return getattr(class_module, from_class)
class BaseVAE(torch.nn.Module):
def __init__(self, scale=1.0, shift=0.0):
super().__init__()
self.model = torch.nn.Identity()
self.scale = scale
self.shift = shift
def encode(self, x):
return x/self.scale+self.shift
def decode(self, x):
return (x-self.shift)*self.scale
# Note: nearest-neighbour resampling caused severe artifacts here, so bicubic interpolation is used below.
class DownSampleVAE(BaseVAE):
def __init__(self, down_ratio, scale=1.0, shift=0.0):
super().__init__()
self.model = torch.nn.Identity()
self.scale = scale
self.shift = shift
self.down_ratio = down_ratio
def encode(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=1/self.down_ratio, mode='bicubic', align_corners=False)
return x/self.scale+self.shift
def decode(self, x):
x = (x-self.shift)*self.scale
x = torch.nn.functional.interpolate(x, scale_factor=self.down_ratio, mode='bicubic', align_corners=False)
return x
class LatentVAE(BaseVAE):
def __init__(self, precompute=False, weight_path:str=None):
super().__init__()
self.precompute = precompute
self.model = None
self.weight_path = weight_path
from diffusers.models import AutoencoderKL
setattr(self, "model", AutoencoderKL.from_pretrained(self.weight_path))
self.scaling_factor = self.model.config.scaling_factor
@torch.no_grad()
def encode(self, x):
assert self.model is not None
if self.precompute:
return x.mul_(self.scaling_factor)
return self.model.encode(x).latent_dist.sample().mul_(self.scaling_factor)
@torch.no_grad()
def decode(self, x):
assert self.model is not None
return self.model.decode(x.div_(self.scaling_factor)).sample
def uint82fp(x):
x = x.to(torch.float32)
x = (x - 127.5) / 127.5
return x
def fp2uint8(x):
x = torch.clip_((x + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
return x
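if __name__ == "__main__":
    # Illustrative encode/decode round trip (not part of the original commit).
    # "path/to/sd-vae-ft-ema" is a placeholder for a local AutoencoderKL checkpoint.
    vae = LatentVAE(precompute=False, weight_path="path/to/sd-vae-ft-ema")
    img = uint82fp(torch.randint(0, 256, (1, 3, 256, 256), dtype=torch.uint8))
    latent = vae.encode(img)               # (1, 4, 32, 32) scaled latents
    recon = fp2uint8(vae.decode(latent))   # back to uint8 pixels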

View File

@@ -0,0 +1,346 @@
#include <iostream>
#include <string>
#include <fstream>
#include <chrono>
#include <vector_types.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <cuda_runtime.h>
#include <cuda_fp16.hpp>
#include <cuda_bf16.hpp>
#include <ATen/cuda/CUDAContext.h>
#include <pybind11/pybind11.h>
#include <cooperative_groups/memcpy_async.h>
#include <cuda/pipeline>
namespace cg = cooperative_groups;
template<typename scalar_t>
__device__ __always_inline int toInt(scalar_t val);
template<>
__device__ __always_inline int toInt(float val){
return static_cast<int>(val);
}
template<>
__device__ __always_inline int toInt(half val){
return __half2int_rz(val);
}
template<typename scalar_t>
__device__ __always_inline scalar_t fromInt(int val);
template<>
__device__ __always_inline float fromInt(int val){
return static_cast<float>(val);
}
template<>
__device__ __always_inline half fromInt(int val){
return __int2half_rz(val);
}
template<typename scalar_t>
__device__ __always_inline scalar_t constVal(float val);
template<>
__device__ __always_inline float constVal<float>(float val) {
return (float)val;
}
template<>
__device__ __always_inline half constVal<half>(float val) {
return __float2half(val); // Using float to half conversion
}
template<>
__device__ __always_inline nv_bfloat16 constVal<nv_bfloat16>(float val){
return __float2bfloat16(val);
}
// B, H, W, C, BLOCK_DIM must be multiple of C
template <typename scalar_t, typename vec2_t, int pipeline_stages, int TILE_C, int TILE_THREADS>
__global__ void dcn_backward_pipeline_kernel(
const int H,
const int W,
const int G,
const int K,
const int C,
scalar_t* ptr_values,
scalar_t* ptr_deformables,
scalar_t* ptr_weights,
scalar_t* ptr_grad_out,
scalar_t* ptr_grad_values,
scalar_t* ptr_grad_deformables,
scalar_t* ptr_grad_weights
) {
auto block = cg::this_thread_block();
auto self_thread = cg::this_thread();
auto tile_threads = cg::tiled_partition<TILE_THREADS>(block);
int local_thread_id = block.thread_rank();
int local_tile_id = tile_threads.meta_group_rank();
int num_local_tiles = tile_threads.meta_group_size();
int global_tile_id = block.group_index().x*num_local_tiles + local_tile_id;
extern __shared__ int shm[];
auto GradBuffer = reinterpret_cast<scalar_t*>(shm);
scalar_t* Buffer = reinterpret_cast<scalar_t*>(shm) + num_local_tiles*C;
if(global_tile_id >= H*W*G) return;
int bid = block.group_index().y;
int gid = global_tile_id % G;
int wid = global_tile_id / G % W;
int hid = global_tile_id / G / W;
    int global_offset = bid*H*W*G*C + global_tile_id*C;
    cg::memcpy_async(tile_threads, GradBuffer+local_tile_id*C, ptr_grad_out+global_offset, sizeof(scalar_t)*C);
int shared_offset[pipeline_stages];
for (int s = 0; s < pipeline_stages; ++s) {
shared_offset[s] = (s+pipeline_stages*local_thread_id)*(TILE_C*4);
}
auto pipeline = cuda::make_pipeline();
const int num_tiles_per_thread = C/TILE_C/TILE_THREADS;
for(int k=0; k<K; k++) {
int offset = bid * K * H * W * G + hid * W * K * G + wid * K * G + gid * K + k;
scalar_t x, y, weight;
if (tile_threads.thread_rank() == 0) {
x = ptr_deformables[offset*2] + fromInt<scalar_t>(wid);
y = ptr_deformables[offset*2 + 1] + fromInt<scalar_t>(hid);
// x = fromInt<scalar_t>(wid);
// y = fromInt<scalar_t>(hid);
weight = ptr_weights[offset];
}
tile_threads.sync();
x = tile_threads.shfl(x, 0);
y = tile_threads.shfl(y, 0);
weight = tile_threads.shfl(weight, 0);
int floor_x = toInt<scalar_t>(x);
int floor_y = toInt<scalar_t>(y);
int ceil_x = floor_x + 1;
int ceil_y = floor_y + 1;
scalar_t dodx = constVal<scalar_t>(0.0f);
scalar_t dody = constVal<scalar_t>(0.0f);
scalar_t dodw = constVal<scalar_t>(0.0f);
int start_c = tile_threads.thread_rank() * (C / TILE_THREADS);
bool tl_flag = (floor_x >=0) and (floor_x <W) and (floor_y>=0) and (floor_y<H);
bool tr_flag = (ceil_x >=0) and (ceil_x <W) and (floor_y>=0) and (floor_y<H);
bool bl_flag = (floor_x >=0) and (floor_x <W) and (ceil_y>=0) and (ceil_y<H);
bool br_flag = (ceil_x >=0) and (ceil_x <W) and (ceil_y>=0) and (ceil_y<H);
int tl_global_base = (bid * H * W * G + floor_y * W * G + floor_x * G + gid)*C + start_c;
int tr_global_base = (bid * H * W * G + floor_y * W * G + ceil_x * G + gid)*C + start_c;
int bl_global_base = (bid * H * W * G + ceil_y * W * G + floor_x * G + gid)*C +start_c;
int br_global_base = (bid * H * W * G + ceil_y * W * G + ceil_x * G + gid)*C +start_c;
auto asmem_load_fn = [&](int shm_offset, int hbm_offset, bool flag){
if(flag){
cuda::memcpy_async(Buffer + shm_offset, ptr_values + hbm_offset,
TILE_C * sizeof(scalar_t), pipeline);
}else{
            memset(Buffer + shm_offset, 0, TILE_C * sizeof(scalar_t));  // zero-fill out-of-bounds taps
}
};
// pipeline-compute&load
for (int compute_n = 0, fetch_n=0; compute_n < num_tiles_per_thread; compute_n++) {
for (; fetch_n < compute_n + pipeline_stages and fetch_n < num_tiles_per_thread; fetch_n++) {
pipeline.producer_acquire();
int buffer_offset = shared_offset[fetch_n % pipeline_stages];
// tl
asmem_load_fn(buffer_offset, tl_global_base + fetch_n * TILE_C, tl_flag);
// tr
asmem_load_fn(buffer_offset+TILE_C, tr_global_base + fetch_n * TILE_C, tr_flag);
// bl
asmem_load_fn(buffer_offset+TILE_C*2, bl_global_base + fetch_n * TILE_C, bl_flag);
// br
asmem_load_fn(buffer_offset+TILE_C*3, br_global_base + fetch_n * TILE_C, br_flag);
pipeline.producer_commit();
}
pipeline.consumer_wait();
int buffer_id = compute_n % pipeline_stages;
int ibuffer_offset = shared_offset[buffer_id];
int gbuffer_offset = local_tile_id * C + start_c + compute_n * TILE_C;
for (int j = 0; j < TILE_C; j+=2) {
if(tl_flag){
// tl
dodw = dodw + (fromInt<scalar_t>(ceil_x) - x) * (fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
dodx = dodx + -weight*(fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
dody = dody + -weight*(fromInt<scalar_t>(ceil_x) - x) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
dodw = dodw + (fromInt<scalar_t>(ceil_x) - x) * (fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1];
dodx = dodx + -weight*(fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+j+ 1] * GradBuffer[gbuffer_offset + j + 1];
dody = dody + -weight*(fromInt<scalar_t>(ceil_x) - x) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1];
{
vec2_t vtl_di;
vtl_di.x = weight* (fromInt<scalar_t>(ceil_x) - x) * (fromInt<scalar_t>(ceil_y) - y) * GradBuffer[gbuffer_offset + j];
vtl_di.y = weight* (fromInt<scalar_t>(ceil_x) - x) * (fromInt<scalar_t>(ceil_y) - y) * GradBuffer[gbuffer_offset + j + 1];
atomicAdd((vec2_t*)(ptr_grad_values + tl_global_base + compute_n * TILE_C + j), vtl_di);
}
}
if(tr_flag){
// tr
dodw = dodw + (x - fromInt<scalar_t>(floor_x)) * (fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
dodx = dodx + weight*(fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
dody = dody + -weight*(x - fromInt<scalar_t>(floor_x)) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
dodw = dodw + (x - fromInt<scalar_t>(floor_x)) * (fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j+1] * GradBuffer[gbuffer_offset + j+1];
dodx = dodx + weight*(fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+ 1];
dody = dody + -weight*(x - fromInt<scalar_t>(floor_x)) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+1];
{
vec2_t vtr_di;
vtr_di.x = weight* (x - fromInt<scalar_t>(floor_x)) * (fromInt<scalar_t>(ceil_y) - y) * GradBuffer[gbuffer_offset + j];
vtr_di.y = weight* (x - fromInt<scalar_t>(floor_x)) * (fromInt<scalar_t>(ceil_y) - y) * GradBuffer[gbuffer_offset + j+1];
atomicAdd((vec2_t*)(ptr_grad_values + tr_global_base + compute_n * TILE_C + j), vtr_di);
}
}
if(bl_flag){
// bl
dodw = dodw + (fromInt<scalar_t>(ceil_x) - x) * (y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
dodx = dodx + -weight*(y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
dody = dody + weight*(fromInt<scalar_t>(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
dodw = dodw + (fromInt<scalar_t>(ceil_x) - x) * (y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
dodx = dodx + -weight*(y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
dody = dody + weight*(fromInt<scalar_t>(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
{
vec2_t vbl_di;
vbl_di.x = weight* (fromInt<scalar_t>(ceil_x) - x) * (y - fromInt<scalar_t>(floor_y)) * GradBuffer[gbuffer_offset + j];
vbl_di.y = weight* (fromInt<scalar_t>(ceil_x) - x) * (y - fromInt<scalar_t>(floor_y)) * GradBuffer[gbuffer_offset + j+1];
atomicAdd((vec2_t*)(ptr_grad_values + bl_global_base + compute_n * TILE_C + j), vbl_di);
}
}
if(br_flag){
                // br
dodw = dodw + (x - fromInt<scalar_t>(floor_x)) * (y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
dodx = dodx + weight*(y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
dody = dody + weight*(x - fromInt<scalar_t>(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
dodw = dodw + (x - fromInt<scalar_t>(floor_x)) * (y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
dodx = dodx + weight*(y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
dody = dody + weight*(x - fromInt<scalar_t>(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
{
vec2_t vbr_di;
vbr_di.x = weight* (x - fromInt<scalar_t>(floor_x)) * (y - fromInt<scalar_t>(floor_y)) * GradBuffer[gbuffer_offset + j];
vbr_di.y = weight* (x - fromInt<scalar_t>(floor_x)) * (y - fromInt<scalar_t>(floor_y)) * GradBuffer[gbuffer_offset + j+1];
atomicAdd((vec2_t*)(ptr_grad_values + br_global_base + compute_n * TILE_C + j), vbr_di);
}
}
}
pipeline.consumer_release();
}
for (int i = TILE_THREADS>>1; i > 0; i/=2) {
dodx = dodx + tile_threads.shfl_down(dodx, i);
dody = dody + tile_threads.shfl_down(dody, i);
dodw = dodw + tile_threads.shfl_down(dodw, i);
}
if (tile_threads.thread_rank() == 0) {
cuda::memcpy_async(ptr_grad_deformables + offset * 2, &dodx, sizeof(scalar_t), pipeline);
cuda::memcpy_async(ptr_grad_deformables + offset * 2 + 1, &dody, sizeof(scalar_t), pipeline);
cuda::memcpy_async(ptr_grad_weights + offset, &dodw, sizeof(scalar_t), pipeline);
}
}
}
using namespace torch;
template<int pipeline_stages, int TILE_C, int TILE_THREADS, int THREADS>
void backward(const int B,
const int H,
const int W,
const int G,
const int K,
const int C,
torch::Tensor values,
torch::Tensor deformables,
torch::Tensor weights,
torch::Tensor grad_out,
torch::Tensor grad_values,
torch::Tensor grad_deformables,
torch::Tensor grad_weights
) {
int num_local_tiles =(THREADS/TILE_THREADS);
int num_global_tiles = (H*W*G+num_local_tiles-1)/num_local_tiles;
dim3 launch_threads_per_block(THREADS);
dim3 launch_blocks(num_global_tiles, B);
int deformable_shm_size = 0;
int grad_out_shm_size = num_local_tiles*C;
int pipeline_shm_size = (pipeline_stages*TILE_C*4*THREADS);
int shm_size = deformable_shm_size+grad_out_shm_size+pipeline_shm_size;
// printf("shm_size: %d\n", shm_size/512);
// printf("pipeline_size: %d\n", pipeline_shm_size/512);
// printf("grad_out_size: %d\n", grad_out_shm_size/512);
    switch (values.scalar_type()) {
case at::ScalarType::Half:
return dcn_backward_pipeline_kernel<half, half2, pipeline_stages, TILE_C, TILE_THREADS><<<launch_blocks, launch_threads_per_block, shm_size*sizeof(half)>>>(
H, W, G, K, C,
reinterpret_cast<half*>(values.data_ptr<at::Half>()),
reinterpret_cast<half*>(deformables.data_ptr<at::Half>()),
reinterpret_cast<half*>(weights.data_ptr<at::Half>()),
reinterpret_cast<half*>(grad_out.data_ptr<at::Half>()),
reinterpret_cast<half*>(grad_values.data_ptr<at::Half>()),
reinterpret_cast<half*>(grad_deformables.data_ptr<at::Half>()),
reinterpret_cast<half*>(grad_weights.data_ptr<at::Half>())
);
// case at::ScalarType::BFloat16:
// return dcn_backward_pipeline_kernel<nv_bfloat16, nv_bfloat162, pipeline_stages, TILE_C, TILE_THREADS><<<launch_blocks, launch_threads_per_block, shm_size*sizeof(nv_bfloat16)>>>(
// H, W, G, K, C,
// reinterpret_cast<nv_bfloat16*>(values.data_ptr<at::BFloat16>()),
// reinterpret_cast<nv_bfloat16*>(deformables.data_ptr<at::BFloat16>()),
// reinterpret_cast<nv_bfloat16*>(weights.data_ptr<at::BFloat16>()),
// reinterpret_cast<nv_bfloat16*>(grad_out.data_ptr<at::BFloat16>()),
// reinterpret_cast<nv_bfloat16*>(grad_values.data_ptr<at::BFloat16>()),
// reinterpret_cast<nv_bfloat16*>(grad_deformables.data_ptr<at::BFloat16>()),
// reinterpret_cast<nv_bfloat16*>(grad_weights.data_ptr<at::BFloat16>())
// );
default:
            printf("dcn backward: unsupported dtype\n");
}
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("backward_p1_c2_tile16_thread128", &backward<1, 2, 16, 128>, "");
m.def("backward_p2_c2_tile16_thread128", &backward<2, 2, 16, 128>, "");
m.def("backward_p1_c4_tile16_thread128", &backward<1, 4, 16, 128>, "");
m.def("backward_p1_c2_tile16_thread256", &backward<1, 2, 16, 256>, "");
m.def("backward_p2_c2_tile16_thread256", &backward<2, 2, 16, 256>, "");
m.def("backward_p1_c4_tile16_thread256", &backward<1, 4, 16, 256>, "");
m.def("backward_p1_c2_tile16_thread384", &backward<1, 2, 16, 384>, "");
m.def("backward_p2_c2_tile16_thread384", &backward<2, 2, 16, 384>, "");
m.def("backward_p1_c4_tile16_thread384", &backward<1, 4, 16, 384>, "");
m.def("backward_p1_c2_tile16_thread512", &backward<1, 2, 16, 512>, "");
m.def("backward_p2_c2_tile16_thread512", &backward<2, 2, 16, 512>, "");
m.def("backward_p1_c4_tile16_thread512", &backward<1, 4, 16, 512>, "");
m.def("backward_p1_c2_tile16_thread768", &backward<1, 2, 16, 768>, "");
m.def("backward_p2_c2_tile16_thread768", &backward<2, 2, 16, 768>, "");
m.def("backward_p1_c4_tile16_thread768", &backward<1, 4, 16, 768>, "");
// m.def("backward_p1_c2_tile16_thread1024", &backward<1, 2, 16, 1024>, "");
// m.def("backward_p2_c2_tile16_thread1024", &backward<2, 2, 16, 1024>, "");
// m.def("backward_p1_c4_tile16_thread1024", &backward<1, 4, 16, 1024>, "");
m.def("backward_p1_c2_tile32_thread128", &backward<1, 2, 32, 128>, "");
m.def("backward_p1_c2_tile32_thread256", &backward<1, 2, 32, 256>, "");
m.def("backward_p1_c2_tile32_thread384", &backward<1, 2, 32, 384>, "");
m.def("backward_p1_c2_tile32_thread512", &backward<1, 2, 32, 512>, "");
}
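
For reference, a hedged sketch of calling one of these bindings once the extension is built as dcn_cuda_backward; the argument order follows the backward() template above and the autograd wrapper in functions.py, and all shapes are illustrative assumptions (contiguous, channels-last [B, H, W, G, C] half-precision buffers with C divisible by TILE_C*TILE_THREADS).

# Sketch only; shapes are assumptions consistent with the wrapper in functions.py.
import torch
import dcn_cuda_backward

B, H, W, G, C, K = 2, 16, 16, 4, 64, 9
values      = torch.randn(B, H, W, G, C, device="cuda", dtype=torch.float16)
deformables = torch.randn(B, H, W, G, K, 2, device="cuda", dtype=torch.float16)
weights     = torch.randn(B, H, W, G, K, device="cuda", dtype=torch.float16)
grad_out    = torch.randn_like(values)
grad_values, grad_deformables, grad_weights = (
    torch.zeros_like(values), torch.zeros_like(deformables), torch.zeros_like(weights))
dcn_cuda_backward.backward_p1_c2_tile16_thread256(
    B, H, W, G, K, C, values, deformables, weights,
    grad_out, grad_values, grad_deformables, grad_weights)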

View File

@@ -0,0 +1,289 @@
#include <iostream>
#include <string>
#include <fstream>
#include <chrono>
#include <vector_types.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_pipeline_primitives.h>
#include <cuda_fp16.hpp>
#include <cuda_fp16.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
template <typename TA, typename TB>
__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = (TA)(*ptr_a + (*ptr_b) * weight);
ptr_a += stride_a;
ptr_b += stride_b;
}
}
template <typename TA, typename TB>
__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = (TA)((*ptr_b) * weight);
ptr_a += stride_a;
ptr_b += stride_b;
}
}
template <typename TA, typename TB>
__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = (TA)((*ptr_b));
ptr_a += stride_a;
ptr_b += stride_b;
}
}
template <typename TA>
__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = 0;
ptr_a += stride;
}
}
// B, H, W, C, BLOCK_DIM must be multiple of C
template <typename math_t, typename scalar_t, int transfer_length, int K, int L, int BLOCK_DIM>
__global__ void dcn_forward_kernel(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
int work_id = threadIdx.x;
int bid = blockIdx.z;
int gid = blockIdx.y;
int G = gridDim.y;
int c_blockid = blockIdx.x;
int work_load = (H*W/blockDim.x);
__shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM
// __shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM
math_t register_bufferA[BLOCK_DIM] = {0};
int base_c = c_blockid*BLOCK_DIM;
int num_transfers = BLOCK_DIM;
#pragma unroll
for(int i=0; i<work_load; i++){
int job_id = work_load*work_id + i;
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
for(int j=0; j<num_transfers; j++){
if((base_c+j) < C){
// __pipeline_memcpy_async((long*)(&math_buffer[job_id]) + j, (long *)ptr_value + offset2 + j, sizeof(long));
math_buffer[job_id][j] = (math_t)*(ptr_value + offset2 +j);
}
}
}
__syncthreads();
work_load = (H*W)/blockDim.x;
int offset = 0;
for(int i=0; i<work_load; i++){
int job_id = (work_id*work_load+i);
int hid = job_id/W;
int wid = job_id%W;
loop_reset<math_t>(register_bufferA, 1, BLOCK_DIM);
// loop_reset<scalar_t>((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM);
#pragma unroll
for(int k=0; k<K; k++){
// read weights to register
offset = bid*K*H*W*G + hid*W*K*G + wid*K*G + gid*K +k;
math_t weight = *(ptr_weights + offset);
// read deformables to register
offset = offset*2;
math_t x = *(ptr_deformables + offset) + wid;
math_t y = *(ptr_deformables + offset + 1) + hid;
int floor_x = x;
int floor_y = y;
int ceil_x = floor_x + 1;
int ceil_y = floor_y + 1;
// reset A buffer and top left
math_t tl_weight = (ceil_x - x)*(ceil_y - y)*weight;
if( (0<= floor_y) and (floor_y < H) and (0<= floor_x) and (floor_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
}
// load top right
math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
}
// load bottom left
math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
}
// load bottom right
math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
}
}
// loop_load<scalar_t, math_t>((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM);
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
#pragma unroll
for(int j=0; j<BLOCK_DIM; j++){
if((base_c+j) < C){
*(ptr_out + offset2 + j) = (scalar_t)register_bufferA[j];
}
}
}
__syncthreads();
}
// B, H, W, C, BLOCK_DIM must be multiple of C
template <typename math_t, typename scalar_t, int transfer_length, int K, int L, int BLOCK_DIM>
__global__ void dcn_forward_kernel_16(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
int work_id = threadIdx.x;
int bid = blockIdx.z;
int gid = blockIdx.y;
int G = gridDim.y;
int c_blockid = blockIdx.x;
int work_load = (H*W/blockDim.x);
__shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM
__shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM
math_t register_bufferA[BLOCK_DIM] = {0};
int base_c = c_blockid*BLOCK_DIM;
int num_transfers = BLOCK_DIM/transfer_length;
#pragma unroll
for(int i=0; i<work_load; i++){
int job_id = work_load*work_id + i;
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c)/transfer_length;
for(int j=0; j<num_transfers; j++){
if((base_c+j*transfer_length) < C){
__pipeline_memcpy_async((long*)(&math_buffer[job_id]) + j, (long *)ptr_value + offset2 + j, sizeof(long));
}
}
}
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
work_load = (H*W)/blockDim.x;
int offset = 0;
for(int i=0; i<work_load; i++){
int job_id = (work_id*work_load+i);
int hid = job_id/W;
int wid = job_id%W;
loop_reset<math_t>(register_bufferA, 1, BLOCK_DIM);
loop_reset<scalar_t>((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM);
#pragma unroll
for(int k=0; k<K; k++){
// read weights to register
offset = bid*K*H*W*G + hid*W*K*G + wid*K*G + gid*K +k;
math_t weight = *(ptr_weights + offset);
// read deformables to register
offset = offset*2;
math_t x = *(ptr_deformables + offset) + wid;
math_t y = *(ptr_deformables + offset + 1) + hid;
int floor_x = x;
int floor_y = y;
int ceil_x = floor_x + 1;
int ceil_y = floor_y + 1;
// reset A buffer and top left
math_t tl_weight = (ceil_x - x)*(ceil_y - y)*weight;
if( (0<= floor_y) and (floor_y < H) and (0<= floor_x) and (floor_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
}
// load top right
math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
}
// load bottom left
math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
}
// load bottom right
math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
}
}
loop_load<scalar_t, math_t>((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM);
}
__syncthreads();
#pragma unroll
for(int i=0; i<work_load; i++){
int job_id = work_load*work_id + i;
// int offset1 = job_id*num_transfers;
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c)/transfer_length;
#pragma unroll
for(int j=0; j<num_transfers; j++){
if((base_c+j*transfer_length) < C){
*((long *)ptr_out + offset2 + j) = *((long *)(&io_buffer[job_id]) +j);
}
}
}
}
template<int L, int C_BLOCK_DIM, int THREADS>
void dcn_forward(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
dim3 launch_threads_per_block(THREADS);
dim3 launch_blocks(NUM_C_BLOCK, G, B);
    switch (value.scalar_type()) {
case at::ScalarType::Half:
return dcn_forward_kernel_16<at::Half, at::Half, 4, 9, L, (C_BLOCK_DIM)><<<launch_blocks, launch_threads_per_block>>>(
H, W, C,
value.data_ptr<at::Half>(),
deformables.data_ptr<at::Half>(),
weights.data_ptr<at::Half>(),
out.data_ptr<at::Half>());
case at::ScalarType::BFloat16:
return dcn_forward_kernel_16<at::BFloat16, at::BFloat16, 4, 9, L, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block>>>(
H, W, C,
value.data_ptr<at::BFloat16>(),
deformables.data_ptr<at::BFloat16>(),
weights.data_ptr<at::BFloat16>(),
out.data_ptr<at::BFloat16>());
case at::ScalarType::Float:
return dcn_forward_kernel<at::Half, float, 2, 9, L, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block>>>(
H, W, C,
value.data_ptr<float>(),
deformables.data_ptr<float>(),
weights.data_ptr<float>(),
out.data_ptr<float>());
default:
        printf("dcn forward: unsupported dtype\n");
}
}
// PyBind11 bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward");
//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward");
m.def("dcn_forward_l256_c4", &dcn_forward<256, 4, 256>, "CUDA dcn forward");
m.def("dcn_forward_l256_c8", &dcn_forward<256, 8, 256>, "CUDA dcn forward");
m.def("dcn_forward_l256_c16", &dcn_forward<256, 16, 256>, "CUDA dcn forward");
// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward");
m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward");
m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward");
m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward");
// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward");
}
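
Analogously, a hedged sketch of calling the forward extension built from this file as dcn_cuda_forward; note the B, G, C, H, W argument order of the dcn_forward template and that the L template parameter (256 here) must cover H*W, so the sizes below are assumptions chosen to satisfy that.

# Sketch only; sizes are illustrative assumptions (H*W must fit the L=256 shared buffer, K is fixed to 9 in the kernel instantiation).
import torch
import dcn_cuda_forward

B, H, W, G, C, K = 2, 16, 16, 4, 64, 9        # H*W == 256
values      = torch.randn(B, H, W, G, C, device="cuda", dtype=torch.float16)
deformables = torch.randn(B, H, W, G, K, 2, device="cuda", dtype=torch.float16)
weights     = torch.randn(B, H, W, G, K, device="cuda", dtype=torch.float16)
out = torch.zeros_like(values)
dcn_cuda_forward.dcn_forward_l256_c8(B, G, C, H, W, values, deformables, weights, out)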

View File

@@ -0,0 +1,309 @@
#include <iostream>
#include <string>
#include <fstream>
#include <chrono>
#include <vector_types.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/OpMathType.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_pipeline_primitives.h>
#include <cuda_fp16.hpp>
#include <cuda_fp16.h>
#include <cooperative_groups.h>
#include <cooperative_groups/memcpy_async.h>
template <typename TA, typename TB>
__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = (TA)(*ptr_a + (*ptr_b) * weight);
ptr_a += stride_a;
ptr_b += stride_b;
}
}
template <typename TA, typename TB>
__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = (TA)((*ptr_b) * weight);
ptr_a += stride_a;
ptr_b += stride_b;
}
}
template <typename TA, typename TB>
__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = (TA)((*ptr_b));
ptr_a += stride_a;
ptr_b += stride_b;
}
}
template <typename TA>
__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){
#pragma unroll
for(int i=0; i<n; i++){
*ptr_a = 0;
ptr_a += stride;
}
}
// B, H, W, C, BLOCK_DIM must be multiple of C
template <typename math_t, typename scalar_t, int K, int BLOCK_DIM>
__global__ void dcn_forward_kernel_register(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
int work_id = threadIdx.x;
int bid = blockIdx.z;
int gid = blockIdx.y;
int G = gridDim.y;
int c_blockid = blockIdx.x;
int work_load = (H*W/blockDim.x);
extern __shared__ int shm[];
math_t* math_buffer = reinterpret_cast<math_t*>(shm);
math_t register_bufferA[BLOCK_DIM] = {0};
int base_c = c_blockid*BLOCK_DIM;
#pragma unroll
for(int i=0; i<work_load; i++){
int job_id = work_load*work_id + i;
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
for(int j=0; j<BLOCK_DIM; j++){
if((base_c+j) < C){
math_buffer[job_id*BLOCK_DIM +j] = (math_t)*(ptr_value + offset2 +j);
}
}
}
__syncthreads();
work_load = (H*W)/blockDim.x;
int offset = 0;
for(int i=0; i<work_load; i++){
int job_id = (work_id*work_load+i);
int hid = job_id/W;
int wid = job_id%W;
loop_reset<math_t>(register_bufferA, 1, BLOCK_DIM);
#pragma unroll
for(int k=0; k<K; k++){
// read weights to register
offset = bid*K*H*W*G + hid*W*K*G + wid*K*G + gid*K +k;
math_t weight = *(ptr_weights + offset);
// read deformables to register
offset = offset*2;
math_t x = *(ptr_deformables + offset) + wid;
math_t y = *(ptr_deformables + offset + 1) + hid;
int floor_x = x;
int floor_y = y;
int ceil_x = floor_x + 1;
int ceil_y = floor_y + 1;
// reset A buffer and top left
math_t tl_weight = (ceil_x - x)*(ceil_y - y)*weight;
if( (0<= floor_y) and (floor_y < H) and (0<= floor_x) and (floor_x < W)){
                loop_mul_add<math_t, math_t>(register_bufferA, &math_buffer[(floor_y*W+floor_x)*BLOCK_DIM], tl_weight, 1, 1, BLOCK_DIM);  // math_buffer is a flat [H*W][BLOCK_DIM] buffer
}
// load top right
math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
                loop_mul_add<math_t, math_t>(register_bufferA, &math_buffer[(floor_y*W + ceil_x)*BLOCK_DIM], tr_weight, 1, 1, BLOCK_DIM);
}
// load bottom left
math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
                loop_mul_add<math_t, math_t>(register_bufferA, &math_buffer[(ceil_y*W+floor_x)*BLOCK_DIM], bl_weight, 1, 1, BLOCK_DIM);
}
// load bottom right
math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
                loop_mul_add<math_t, math_t>(register_bufferA, &math_buffer[(ceil_y*W+ceil_x)*BLOCK_DIM], br_weight, 1, 1, BLOCK_DIM);
}
}
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
#pragma unroll
for(int j=0; j<BLOCK_DIM; j++){
if((base_c+j) < C){
*(ptr_out + offset2 + j) = (scalar_t)register_bufferA[j];
}
}
}
__syncthreads();
}
// B, H, W, C, BLOCK_DIM must be multiple of C
template <typename math_t, typename scalar_t, int transfer_length, int K, int BLOCK_DIM>
__global__ void dcn_forward_kernel_pipeline(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
int work_id = threadIdx.x;
int bid = blockIdx.z;
int gid = blockIdx.y;
int G = gridDim.y;
int c_blockid = blockIdx.x;
int work_load = (H*W/blockDim.x);
extern __shared__ int shm[];
math_t* math_buffer = reinterpret_cast<math_t*>(shm);
scalar_t* io_buffer = reinterpret_cast<scalar_t*>(shm) + H*W*BLOCK_DIM*sizeof(math_t)/sizeof(scalar_t);
math_t register_bufferA[BLOCK_DIM] = {0};
int base_c = c_blockid*BLOCK_DIM;
int num_transfers = BLOCK_DIM/transfer_length;
#pragma unroll
for(int i=0; i<work_load; i++){
int job_id = work_load*work_id + i;
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c)/transfer_length;
for(int j=0; j<num_transfers; j++){
if((base_c+j*transfer_length) < C){
                __pipeline_memcpy_async((long*)(&math_buffer[job_id*BLOCK_DIM]) + j, (long *)ptr_value + offset2 + j, sizeof(long));  // flat buffer: stride by BLOCK_DIM per spatial site
}
}
}
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
work_load = (H*W)/blockDim.x;
int offset = 0;
for(int i=0; i<work_load; i++){
int job_id = (work_id*work_load+i);
int hid = job_id/W;
int wid = job_id%W;
loop_reset<math_t>(register_bufferA, 1, BLOCK_DIM);
        loop_reset<scalar_t>(&io_buffer[(hid*W+wid)*BLOCK_DIM], 1, BLOCK_DIM);
#pragma unroll
for(int k=0; k<K; k++){
// read weights to register
offset = bid*K*H*W*G + hid*W*K*G + wid*K*G + gid*K +k;
math_t weight = *(ptr_weights + offset);
// read deformables to register
offset = offset*2;
math_t x = *(ptr_deformables + offset) + wid;
math_t y = *(ptr_deformables + offset + 1) + hid;
int floor_x = x;
int floor_y = y;
int ceil_x = floor_x + 1;
int ceil_y = floor_y + 1;
// reset A buffer and top left
math_t tl_weight = (ceil_x - x)*(ceil_y - y)*weight;
if( (0<= floor_y) and (floor_y < H) and (0<= floor_x) and (floor_x < W)){
                loop_mul_add<math_t, math_t>(register_bufferA, &math_buffer[(floor_y*W+floor_x)*BLOCK_DIM], tl_weight, 1, 1, BLOCK_DIM);
}
// load top right
math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
                loop_mul_add<math_t, math_t>(register_bufferA, &math_buffer[(floor_y*W + ceil_x)*BLOCK_DIM], tr_weight, 1, 1, BLOCK_DIM);
}
// load bottom left
math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
                loop_mul_add<math_t, math_t>(register_bufferA, &math_buffer[(ceil_y*W+floor_x)*BLOCK_DIM], bl_weight, 1, 1, BLOCK_DIM);
}
// load bottom right
math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
                loop_mul_add<math_t, math_t>(register_bufferA, &math_buffer[(ceil_y*W+ceil_x)*BLOCK_DIM], br_weight, 1, 1, BLOCK_DIM);
}
}
        loop_load<scalar_t, math_t>(&io_buffer[(hid*W+wid)*BLOCK_DIM], register_bufferA, 1, 1, BLOCK_DIM);
}
__syncthreads();
#pragma unroll
for(int i=0; i<work_load; i++){
int job_id = work_load*work_id + i;
// int offset1 = job_id*num_transfers;
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c)/transfer_length;
#pragma unroll
for(int j=0; j<num_transfers; j++){
if((base_c+j*transfer_length) < C){
                *((long *)ptr_out + offset2 + j) = *((long *)(&io_buffer[job_id*BLOCK_DIM]) + j);
}
}
}
}
template<int C_BLOCK_DIM, int THREADS>
void dcn_forward(const int B, const int G, const int C, const int H, const int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
dim3 launch_threads_per_block(THREADS);
dim3 launch_blocks(NUM_C_BLOCK, G, B);
int shm_size = H*W*C_BLOCK_DIM*sizeof(at::Half);
    switch (value.scalar_type()) {
case at::ScalarType::Half:
return dcn_forward_kernel_register<at::Half, at::Half, 9, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block, shm_size>>>(
H, W, C,
value.data_ptr<at::Half>(),
deformables.data_ptr<at::Half>(),
weights.data_ptr<at::Half>(),
out.data_ptr<at::Half>());
case at::ScalarType::Float:
return dcn_forward_kernel_register<at::Half, float, 9, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block, shm_size>>>(
H, W, C,
value.data_ptr<float>(),
deformables.data_ptr<float>(),
weights.data_ptr<float>(),
out.data_ptr<float>());
default:
            printf("dcn forward: unsupported dtype\n");
}
}
template<int C_BLOCK_DIM, int THREADS>
void dcn_forward_pipeline(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
dim3 launch_threads_per_block(THREADS);
dim3 launch_blocks(NUM_C_BLOCK, G, B);
int shm_size = 2*H*W*C_BLOCK_DIM*sizeof(at::Half);
    switch (value.scalar_type()) {
case at::ScalarType::Half:
return dcn_forward_kernel_pipeline<at::Half, at::Half, 4, 9, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block, shm_size>>>(
H, W, C,
value.data_ptr<at::Half>(),
deformables.data_ptr<at::Half>(),
weights.data_ptr<at::Half>(),
out.data_ptr<at::Half>());
case at::ScalarType::BFloat16:
return dcn_forward_kernel_pipeline<at::BFloat16, at::BFloat16, 4, 9, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block, shm_size>>>(
H, W, C,
value.data_ptr<at::BFloat16>(),
deformables.data_ptr<at::BFloat16>(),
weights.data_ptr<at::BFloat16>(),
out.data_ptr<at::BFloat16>());
default:
            printf("dcn forward: unsupported dtype\n");
}
}
// PyBind11 bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward");
//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward");
m.def("dcn_forward_l256_c4", &dcn_forward<4, 256>, "CUDA dcn forward");
m.def("dcn_forward_l256_c8", &dcn_forward<8, 256>, "CUDA dcn forward");
m.def("dcn_forward_l256_c16", &dcn_forward<16, 256>, "CUDA dcn forward");
m.def("dcn_forward_pipeline_l256_c4", &dcn_forward_pipeline<4, 256>, "CUDA dcn forward");
m.def("dcn_forward_pipeline_l256_c8", &dcn_forward_pipeline<8, 256>, "CUDA dcn forward");
m.def("dcn_forward_pipeline_l256_c16", &dcn_forward_pipeline<16, 256>, "CUDA dcn forward");
// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward");
// m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward");
// m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward");
// m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward");
// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward");
}
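
A small back-of-the-envelope sketch (not in the commit) of the dynamic shared-memory budget requested by the two launchers above; this matters because a block can only use more than 48 KB of dynamic shared memory after an explicit cudaFuncSetAttribute opt-in, which these launchers do not perform.

# Shared-memory sizing as computed in dcn_forward / dcn_forward_pipeline above (half = 2 bytes).
# dcn_shm_bytes is a hypothetical helper for illustration only.
def dcn_shm_bytes(H, W, c_block_dim, pipeline=False):
    factor = 2 if pipeline else 1      # the pipeline variant adds a separate io_buffer
    return factor * H * W * c_block_dim * 2

print(dcn_shm_bytes(16, 16, 16))                  # 8192 bytes for dcn_forward_l256_c16
print(dcn_shm_bytes(16, 16, 16, pipeline=True))   # 16384 bytes for the pipeline variant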

View File

@@ -0,0 +1,95 @@
import triton
import triton.language as tl
@triton.autotune(
configs=[
# triton.Config({'BLOCK_SIZE': 32,}, num_stages=1, num_warps=1),
triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1),
],
key=['B', 'H', 'W', 'G', 'C', 'K'],
)
@triton.jit
def forward_kernel(
B: tl.constexpr,
H: tl.constexpr, # image_size_h
W: tl.constexpr, # image_size_w
    G: tl.constexpr,  # num_groups
    C: tl.constexpr,  # num_channels_per_group
    K: tl.constexpr,  # kernel size
    input_ptr,  # input features [B, H, W, G, C]
    deformable_ptr,  # deformable offsets [B, H, W, G, K, 2]
weights_ptr, # weights [B, H, W, G, K]
out_ptr, # out [B, H, W, G, C]
BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
):
pid = tl.program_id(0)
wid = pid % W
hid = pid // W % H
gid = pid // (W * H) % G
bid = pid // (W * H * G)
id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
common_offset = bid*H*W*G + hid*W*G + wid*G + gid
batch_base = bid * H * W * G * C
for block_base in tl.static_range(0, C, BLOCK_SIZE):
buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
block_offset = tl.arange(0, BLOCK_SIZE) + block_base
block_mask = (block_offset < C) & id_mask
for k in tl.static_range(K):
deformable_offset = (common_offset * K + k) * 2
x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
floor_x = x.to(tl.int32)
floor_y = y.to(tl.int32)
ceil_x = floor_x + 1
ceil_y = floor_y + 1
# load top left
tl_weight = (ceil_x - x) * (ceil_y - y)
tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
# load top right
tr_weight = (x - floor_x) * (ceil_y - y)
tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
# load bottom left
bl_weight = (ceil_x - x) * (y - floor_y)
bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
# load bottom right
br_weight = (x - floor_x) * (y - floor_y)
br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
# load dynamic weight and mask
weights_offset = common_offset*K + k
weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
tl_block_input = tl_block_input * tl_weight
# load top right
tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
tr_block_input = tr_block_input * tr_weight
# load bottom left
bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
bl_block_input = bl_block_input * bl_weight
# load bottom right
br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
br_block_input = br_block_input * br_weight
# sampled
sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input
weighted_sampled_input = sampled_input * weight
buffer = buffer + weighted_sampled_input
# store to out_ptr
tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask)

View File

@@ -0,0 +1,126 @@
import time
import dcn_cuda_backward
import dcn_cuda_forward
import math
import torch
from typing import Any
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_fwd, custom_bwd
from .forward import forward_kernel
class DCNFunction(Function):
BP_FUNCS = [
dcn_cuda_backward.backward_p1_c2_tile16_thread128,
dcn_cuda_backward.backward_p1_c4_tile16_thread128,
dcn_cuda_backward.backward_p2_c2_tile16_thread128,
dcn_cuda_backward.backward_p1_c2_tile16_thread256,
dcn_cuda_backward.backward_p1_c4_tile16_thread256,
dcn_cuda_backward.backward_p2_c2_tile16_thread256,
dcn_cuda_backward.backward_p1_c2_tile16_thread384,
dcn_cuda_backward.backward_p1_c4_tile16_thread384,
dcn_cuda_backward.backward_p2_c2_tile16_thread384,
dcn_cuda_backward.backward_p1_c2_tile16_thread512,
dcn_cuda_backward.backward_p1_c4_tile16_thread512,
dcn_cuda_backward.backward_p2_c2_tile16_thread512,
dcn_cuda_backward.backward_p1_c2_tile16_thread768,
dcn_cuda_backward.backward_p1_c4_tile16_thread768,
dcn_cuda_backward.backward_p2_c2_tile16_thread768,
dcn_cuda_backward.backward_p1_c2_tile32_thread128,
dcn_cuda_backward.backward_p1_c2_tile32_thread256,
dcn_cuda_backward.backward_p1_c2_tile32_thread384,
dcn_cuda_backward.backward_p1_c2_tile32_thread512,
]
FW_FUNCS = [
dcn_cuda_forward.dcn_forward_l256_c4,
dcn_cuda_forward.dcn_forward_l256_c8,
dcn_cuda_forward.dcn_forward_l256_c16,
]
BP_TABLES = dict()
FW_TABLES = dict()
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
    def forward(ctx: Any, values, deformables, weights) -> Any:
        B, H, W, G, C = values.shape
        func = DCNFunction.find_fw_funcs(values, deformables, weights)
        out = torch.zeros_like(values)
        func(B, G, C, H, W, values, deformables, weights, out)
        # save inputs for the backward pass; backward() reads them from ctx.saved_tensors
        ctx.save_for_backward(values, deformables, weights)
        return out
@staticmethod
def find_fw_funcs(values, deformables, weights):
B, H, W, G, C = values.shape
B, H, W, G, K = weights.shape
hash_value = 10000 * B + 100 * H + W + 1000 * G
if hash_value in DCNFunction.FW_TABLES.keys():
return DCNFunction.FW_TABLES[hash_value]
print("missing")
        candidate_func = None
min_t = 999.0
outs = torch.zeros_like(values)
for func in DCNFunction.FW_FUNCS:
t = []
for i in range(100):
torch.cuda.synchronize()
start_t = time.time()
func(B, G, C, H, W, values, deformables, weights, outs)
torch.cuda.synchronize()
t.append(time.time() - start_t)
t = t[-50:]
t = sum(t) / len(t)
if t < min_t:
min_t = t
DCNFunction.FW_TABLES[hash_value] = func
                candidate_func = func
        assert candidate_func is not None
        print(candidate_func)
        return candidate_func
@staticmethod
def find_bp_funcs(values, deformables, weights, grad_out):
B, H, W, G, C = values.shape
B, H, W, G, K = weights.shape
hash_value = 10000 * B + 100 * H + W + 1000 * G
if hash_value in DCNFunction.BP_TABLES.keys():
return DCNFunction.BP_TABLES[hash_value]
print("missing")
        candidate_func = None
min_t = 999.0
grad_values = torch.zeros_like(values)
grad_deformables = torch.zeros_like(deformables)
grad_weights = torch.zeros_like(weights)
for func in DCNFunction.BP_FUNCS:
t = []
for i in range(100):
torch.cuda.synchronize()
start_t = time.time()
func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights)
torch.cuda.synchronize()
t.append(time.time() - start_t)
t = t[-50:]
t = sum(t) / len(t)
if t < min_t:
min_t = t
DCNFunction.BP_TABLES[hash_value] = func
                candidate_func = func
        assert candidate_func is not None
        print(candidate_func)
        return candidate_func
@staticmethod
@once_differentiable
@custom_bwd
def backward(ctx: Any, *grad_outputs: Any) -> Any:
grad_out = grad_outputs[0]
values, deformables, weights = ctx.saved_tensors
B, H, W, G, C = values.shape
B, H, W, G, K = weights.shape
func = DCNFunction.find_bp_funcs(values, deformables, weights, grad_out)
grad_values = torch.zeros_like(values)
grad_deformables = torch.zeros_like(deformables)
grad_weights = torch.zeros_like(weights)
func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights)
return grad_values, grad_deformables, grad_weights
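
A hedged usage sketch of the wrapper above (shapes are assumptions): the first call for a given (B, H, W, G) benchmarks every registered binding and caches the fastest one in FW_TABLES / BP_TABLES, so the first forward and backward are expected to be noticeably slower than steady state.

# Sketch only; requires the compiled dcn_cuda_forward / dcn_cuda_backward extensions.
import torch

B, H, W, G, C, K = 2, 16, 16, 4, 64, 9
values      = torch.randn(B, H, W, G, C, device="cuda", dtype=torch.float16, requires_grad=True)
deformables = torch.randn(B, H, W, G, K, 2, device="cuda", dtype=torch.float16, requires_grad=True)
weights     = torch.randn(B, H, W, G, K, device="cuda", dtype=torch.float16, requires_grad=True)
out = DCNFunction.apply(values, deformables, weights)   # first call triggers the kernel search
out.sum().backward()                                    # dispatches the cached backward binding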

View File

@@ -0,0 +1,59 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='dcn_cuda_forward',
ext_modules=[
CUDAExtension('dcn_cuda_forward', ['./forward.cu',],
extra_compile_args={'cxx': [], 'nvcc': [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
"--use_fast_math",
"-O3",
]}
),
],
cmdclass={
'build_ext': BuildExtension
}
)
setup(
name='dcn_cuda_backward',
ext_modules=[
CUDAExtension('dcn_cuda_backward', ['./backward.cu',],
extra_compile_args={'cxx': [], 'nvcc': [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
"--use_fast_math",
"-O3",
]}
),
],
cmdclass={
'build_ext': BuildExtension
}
)
# setup(
# name='mycuda',
# ext_modules=[
# CUDAExtension('mycuda', ['./backward.cu',],
# extra_compile_args={'cxx': [], 'nvcc': [
# "-O3",
# "-DCUDA_HAS_FP16=1",
# "-D__CUDA_NO_HALF_OPERATORS__",
# "-D__CUDA_NO_HALF_CONVERSIONS__",
# "-D__CUDA_NO_HALF2_OPERATORS__",
# ]}
# ),
# ],
# cmdclass={
# 'build_ext': BuildExtension
# }
# )
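
Build note (not part of the commit): running the usual python setup.py build_ext --inplace from the directory containing forward.cu and backward.cu executes both setup() calls in turn and should produce the dcn_cuda_forward and dcn_cuda_backward modules that functions.py imports; the trailing commented-out setup() block appears to be kept only for reference.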

View File

View File

@@ -0,0 +1,124 @@
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=1),
triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
],
key=['B', 'H', 'W', 'G', 'C', 'K'],
)
@triton.jit
def backward_kernel(
B: tl.constexpr,
H: tl.constexpr, # image_size_h
W: tl.constexpr, # image_size_w
G: tl.constexpr, # num_groups
C: tl.constexpr, # num_channels_per_group
K: tl.constexpr, # kernel size
input_ptr, # input features [B, H, W, G, C]
deformable_ptr, # deformable offsets [B, H, W, G, K, 2]
weights_ptr, # weights [B, H, W, G, K]
grad_ptr, # out [B, H, W, G, C]
grad_input_ptr, # input features [B, H, W, G, C]
grad_deformable_ptr, # deformable offsets [B, H, W, G, K, 2]
grad_weights_ptr, # weights [B, H, W, G, K]
BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
):
pid = tl.program_id(0)
wid = pid % W
hid = pid // W % H
gid = pid // (W * H) % G
bid = pid // (W * H * G)
id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
common_offset = bid*H*W*G + hid*W*G + wid*G + gid
batch_base = bid * H * W * G * C
for k in tl.static_range(K):
# load dynamic weight and mask
weights_offset = common_offset*K + k
weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
dodx = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
dody = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
dodw = tl.zeros((1,), dtype=grad_weights_ptr.type.element_ty)
deformable_offset = (common_offset * K + k)*2
x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
for block_base in tl.static_range(0, C, BLOCK_SIZE):
block_offset = tl.arange(0, BLOCK_SIZE) + block_base
block_mask = (block_offset < C) & id_mask
grad = tl.load(grad_ptr+common_offset*C + block_offset, mask=block_mask, other=0.0)
dods = weight*grad
floor_x = x.to(tl.int32)
floor_y = y.to(tl.int32)
ceil_x = floor_x + 1
ceil_y = floor_y + 1
# load top left
tl_weight = (ceil_x - x) * (ceil_y - y)
tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) + block_offset
tl_block_mask = ((floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H))
tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0)
tl_block_input_dot_grad = tl.sum(tl_block_input*grad, axis=0)
dodx = dodx + -1 * tl_block_input_dot_grad * (ceil_y - y)
dody = dody + -1 * tl_block_input_dot_grad * (ceil_x - x)
dodw = dodw + tl_block_input_dot_grad * tl_weight
dodtl = dods * tl_weight
tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dodtl)
# load top right
tr_weight = (x - floor_x) * (ceil_y - y)
tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
tr_block_mask = ((floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0))
tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0)
tr_block_input_dot_grad = tl.sum(tr_block_input*grad, axis=0)
dodx = dodx + 1 * tr_block_input_dot_grad * (ceil_y - y)
dody = dody + -1 * tr_block_input_dot_grad * (x - floor_x)
dodw = dodw + tr_block_input_dot_grad*tr_weight
dodtr = dods * tr_weight
tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dodtr)
# load bottom left
bl_weight = (ceil_x - x) * (y - floor_y)
bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) + block_offset
bl_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0))
bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0)
bl_block_input_dot_grad = tl.sum(bl_block_input*grad, axis=0)
dodx = dodx + -1 * bl_block_input_dot_grad * (y - floor_y)
dody = dody + 1 * bl_block_input_dot_grad * (ceil_x - x)
dodw = dodw + bl_block_input_dot_grad*bl_weight
dodbl = dods * bl_weight
tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dodbl)
# load bottom right
br_weight = (x - floor_x) * (y - floor_y)
br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
br_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0))
br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0)
br_block_input_dot_grad = tl.sum(br_block_input*grad, axis=0)*br_block_mask
dodx = dodx + 1 * br_block_input_dot_grad * (y - floor_y)
dody = dody + 1 * br_block_input_dot_grad * (x - floor_x)
dodw = dodw + br_block_input_dot_grad*br_weight
dodbr = dods * br_weight
tl.atomic_add(grad_input_ptr + br_block_offset, mask=br_block_mask & block_mask, val=dodbr)
dodx = dodx * weight
dody = dody * weight
tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask)
tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask)
tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, mask=id_mask)

View File

@@ -0,0 +1,94 @@
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
# triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1),
],
key=['B', 'H', 'W', 'G', 'C', 'K'],
)
@triton.jit
def forward_kernel(
B: tl.constexpr,
H: tl.constexpr, # image_size_h
W: tl.constexpr, # image_size_w
    G: tl.constexpr,  # num_groups
    C: tl.constexpr,  # num_channels_per_group
    K: tl.constexpr,  # kernel size
    input_ptr,  # input features [B, H, W, G, C]
    deformable_ptr,  # deformable offsets [B, H, W, G, K, 2]
weights_ptr, # weights [B, H, W, G, K]
out_ptr, # out [B, H, W, G, C]
BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
):
pid = tl.program_id(0)
wid = pid % W
hid = pid // W % H
gid = pid // (W * H) % G
bid = pid // (W * H * G)
id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
common_offset = bid*H*W*G + hid*W*G + wid*G + gid
batch_base = bid * H * W * G * C
for block_base in tl.static_range(0, C, BLOCK_SIZE):
buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
block_offset = tl.arange(0, BLOCK_SIZE) + block_base
block_mask = (block_offset < C) & id_mask
for k in tl.static_range(K):
deformable_offset = (common_offset * K + k) * 2
x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
floor_x = x.to(tl.int32)
floor_y = y.to(tl.int32)
ceil_x = floor_x + 1
ceil_y = floor_y + 1
# load top left
tl_weight = (ceil_x - x) * (ceil_y - y)
tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
# load top right
tr_weight = (x - floor_x) * (ceil_y - y)
tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
# load bottom left
bl_weight = (ceil_x - x) * (y - floor_y)
bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
# load bottom right
br_weight = (x - floor_x) * (y - floor_y)
br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
# load dynamic weight and mask
weights_offset = common_offset*K + k
weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
tl_block_input = tl_block_input * tl_weight
# load top right
tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
tr_block_input = tr_block_input * tr_weight
# load bottom left
bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
bl_block_input = bl_block_input * bl_weight
# load bottom right
br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
br_block_input = br_block_input * br_weight
# sampled
sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input
weighted_sampled_input = sampled_input * weight
buffer = buffer + weighted_sampled_input
# store to out_ptr
tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask)

View File

@@ -0,0 +1,48 @@
import torch
import triton
from typing import Any
from torch.autograd import Function
from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd
from .forward import forward_kernel
from .backward import backward_kernel
class DCNFunction(Function):
@staticmethod
@custom_fwd
def forward(ctx: Any, inputs, deformables, weights) -> Any:
B, H, W, G, C = inputs.shape
_, _, _, _, K, _ = deformables.shape
out = torch.zeros_like(inputs)
grid = lambda META: (B * H * W * G,)
forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out)
ctx.save_for_backward(inputs, deformables, weights)
return out
@staticmethod
@custom_bwd
def backward(ctx: Any, *grad_outputs: Any) -> Any:
grad_output = grad_outputs[0].contiguous()
inputs, deformables, weights = ctx.saved_tensors
B, H, W, G, C = inputs.shape
_, _, _, _, K, _ = deformables.shape
grad_inputs = torch.zeros_like(inputs)
grad_deformables = torch.zeros_like(deformables)
grad_weights = torch.zeros_like(weights)
grid = lambda META: (B * H * W * G,)
backward_kernel[grid](
B, H, W, G, C, K,
inputs,
deformables,
weights,
grad_output,
grad_inputs,
grad_deformables,
grad_weights,
)
return (grad_inputs, grad_deformables, grad_weights)
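
A hedged end-to-end sketch of the Triton-backed autograd path defined in this file; shapes follow the layout comments in the kernels, and the zero offsets plus uniform weights are illustrative assumptions chosen so every tap samples its own site.

# Sketch only; sizes are illustrative assumptions (float32 keeps the atomic adds in the backward kernel simple).
import torch

B, H, W, G, C, K = 2, 16, 16, 4, 64, 9
inputs      = torch.randn(B, H, W, G, C, device="cuda", requires_grad=True)
deformables = torch.zeros(B, H, W, G, K, 2, device="cuda", requires_grad=True)   # zero offsets: each tap stays at its own (h, w)
weights     = torch.full((B, H, W, G, K), 1.0 / K, device="cuda", requires_grad=True)
out = DCNFunction.apply(inputs, deformables, weights)
out.mean().backward()
print(inputs.grad.shape, deformables.grad.shape, weights.grad.shape)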

0
src/plugins/__init__.py Normal file
View File

70
src/plugins/bd_env.py Normal file
View File

@@ -0,0 +1,70 @@
import torch
import os
import socket
from typing_extensions import override
from lightning.fabric.utilities.rank_zero import rank_zero_only
from lightning.fabric.plugins.environments.lightning import LightningEnvironment
class BDEnvironment(LightningEnvironment):
pass
# def __init__(self) -> None:
# super().__init__()
# self._global_rank: int = 0
# self._world_size: int = 1
#
# @property
# @override
# def creates_processes_externally(self) -> bool:
# """Returns whether the cluster creates the processes or not.
#
# If at least :code:`LOCAL_RANK` is available as environment variable, Lightning assumes the user acts as the
# process launcher/job scheduler and Lightning will not launch new processes.
#
# """
# return "LOCAL_RANK" in os.environ
#
# @staticmethod
# @override
# def detect() -> bool:
# assert "ARNOLD_WORKER_0_HOST" in os.environ.keys()
# assert "ARNOLD_WORKER_0_PORT" in os.environ.keys()
# return True
#
# @override
# def world_size(self) -> int:
# return self._world_size
#
# @override
# def set_world_size(self, size: int) -> None:
# self._world_size = size
#
# @override
# def global_rank(self) -> int:
# return self._global_rank
#
# @override
# def set_global_rank(self, rank: int) -> None:
# self._global_rank = rank
# rank_zero_only.rank = rank
#
# @override
# def local_rank(self) -> int:
# return int(os.environ.get("LOCAL_RANK", 0))
#
# @override
# def node_rank(self) -> int:
# return int(os.environ.get("ARNOLD_ID"))
#
# @override
# def teardown(self) -> None:
# if "WORLD_SIZE" in os.environ:
# del os.environ["WORLD_SIZE"]
#
# @property
# def main_address(self) -> str:
# return os.environ.get("ARNOLD_WORKER_0_HOST")
#
# @property
# def main_port(self) -> int:
# return int(os.environ.get("ARNOLD_WORKER_0_PORT"))

0
src/utils/__init__.py Normal file
View File

13
src/utils/copy.py Normal file
View File

@@ -0,0 +1,13 @@
import torch
@torch.no_grad()
def copy_params(src_model, dst_model):
for src_param, dst_param in zip(src_model.parameters(), dst_model.parameters()):
dst_param.data.copy_(src_param.data)
@torch.no_grad()
def swap_tensors(tensor1, tensor2):
tmp = torch.empty_like(tensor1)
tmp.copy_(tensor1)
tensor1.copy_(tensor2)
tensor2.copy_(tmp)
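
A hedged example (not part of the commit) of how these helpers compose, e.g. syncing a frozen copy of a module and exchanging two buffers in place; the Linear modules below are stand-ins.

import torch

src = torch.nn.Linear(8, 8)
dst = torch.nn.Linear(8, 8)
copy_params(src, dst)                      # dst now mirrors src's parameters
assert torch.equal(src.weight, dst.weight)

a, b = torch.zeros(3), torch.ones(3)
swap_tensors(a, b)                         # in-place exchange of the two tensors
assert a.sum() == 3 and b.sum() == 0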

Some files were not shown because too many files have changed in this diff.