submit code
.idea/.gitignore (new file, generated, vendored)
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
.idea/DDT.iml (new file, generated)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="inheritedJdk" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
</module>
.idea/inspectionProfiles/profiles_settings.xml (new file, generated)
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
.idea/modules.xml (new file, generated)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/DDT.iml" filepath="$PROJECT_DIR$/.idea/DDT.iml" />
    </modules>
  </component>
</project>
.idea/vcs.xml (new file, generated)
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="" vcs="Git" />
  </component>
</project>
configs/repa_flatten_condit22_fixt_xl.yaml (new file)
@@ -0,0 +1,108 @@
# lightning.pytorch==2.4.0
seed_everything: true
tags:
  exp: &exp repa_flatten_condit22_dit6_fixt_xl
torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
huggingface_cache_dir: null
trainer:
  default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: bf16-mixed
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: universal_flow
      name: *exp
  num_sanity_val_steps: 0
  max_steps: 4000000
  val_check_interval: 4000000
  check_val_every_n_epoch: null
  log_every_n_steps: 50
  deterministic: null
  inference_mode: true
  use_distributed_sampler: false
  callbacks:
    - class_path: src.callbacks.model_checkpoint.CheckpointHook
      init_args:
        every_n_train_steps: 10000
        save_top_k: -1
        save_last: true
    - class_path: src.callbacks.save_images.SaveImagesHook
      init_args:
        save_dir: val
  plugins:
    - src.plugins.bd_env.BDEnvironment
model:
  vae:
    class_path: src.models.vae.LatentVAE
    init_args:
      precompute: true
      weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
  denoiser:
    class_path: src.models.denoiser.flatten_condit_dit_fixt.FlattenConDiT
    init_args:
      in_channels: 4
      patch_size: 2
      num_groups: 16
      hidden_size: &hidden_dim 1152
      num_blocks: 28
      num_cond_blocks: 22
      num_classes: 1000
  conditioner:
    class_path: src.models.conditioner.LabelConditioner
    init_args:
      null_class: 1000
  diffusion_trainer:
    class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
    init_args:
      lognorm_t: true
      encoder_weight_path: dinov2_vitb14
      align_layer: 8
      proj_denoiser_dim: *hidden_dim
      proj_hidden_dim: *hidden_dim
      proj_encoder_dim: 768
      scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
  diffusion_sampler:
    class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
    init_args:
      num_steps: 250
      guidance: 2.0
      timeshift: 1.5
      state_refresh_rate: 1
      guidance_interval_min: 0.3
      guidance_interval_max: 1.0
      scheduler: *scheduler
      w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
      guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
      last_step: 0.04
      step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
  ema_tracker:
    class_path: src.callbacks.simple_ema.SimpleEMA
    init_args:
      decay: 0.9999
  optimizer:
    class_path: torch.optim.AdamW
    init_args:
      lr: 1e-4
      betas:
        - 0.9
        - 0.95
      weight_decay: 0.0
data:
  train_dataset: imagenet256
  train_root: /mnt/bn/wangshuai6/data/ImageNet/train
  train_image_size: 256
  train_batch_size: 16
  eval_max_num_instances: 50000
  pred_batch_size: 64
  pred_num_workers: 4
  pred_seeds: null
  pred_selected_classes: null
  num_classes: 1000
  latent_shape:
    - 4
    - 32
    - 32
configs/repa_flatten_condit22_fixt_xl512.yaml (new file)
@@ -0,0 +1,108 @@
# lightning.pytorch==2.4.0
seed_everything: true
tags:
  exp: &exp res512_fromscratch_repa_flatten_condit22_dit6_fixt_xl
torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
huggingface_cache_dir: null
trainer:
  default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: bf16-mixed
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: universal_flow
      name: *exp
  num_sanity_val_steps: 0
  max_steps: 4000000
  val_check_interval: 4000000
  check_val_every_n_epoch: null
  log_every_n_steps: 50
  deterministic: null
  inference_mode: true
  use_distributed_sampler: false
  callbacks:
    - class_path: src.callbacks.model_checkpoint.CheckpointHook
      init_args:
        every_n_train_steps: 10000
        save_top_k: -1
        save_last: true
    - class_path: src.callbacks.save_images.SaveImagesHook
      init_args:
        save_dir: val
  plugins:
    - src.plugins.bd_env.BDEnvironment
model:
  vae:
    class_path: src.models.vae.LatentVAE
    init_args:
      precompute: true
      weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
  denoiser:
    class_path: src.models.denoiser.flatten_condit_dit_fixt.FlattenConDiT
    init_args:
      in_channels: 4
      patch_size: 2
      num_groups: 16
      hidden_size: &hidden_dim 1152
      num_blocks: 28
      num_cond_blocks: 22
      num_classes: 1000
  conditioner:
    class_path: src.models.conditioner.LabelConditioner
    init_args:
      null_class: 1000
  diffusion_trainer:
    class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
    init_args:
      lognorm_t: true
      encoder_weight_path: dinov2_vitb14
      align_layer: 8
      proj_denoiser_dim: *hidden_dim
      proj_hidden_dim: *hidden_dim
      proj_encoder_dim: 768
      scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
  diffusion_sampler:
    class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
    init_args:
      num_steps: 250
      guidance: 3.0
      state_refresh_rate: 1
      guidance_interval_min: 0.3
      guidance_interval_max: 1.0
      timeshift: 1.0
      last_step: 0.04
      scheduler: *scheduler
      w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
      guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
      step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
  ema_tracker:
    class_path: src.callbacks.simple_ema.SimpleEMA
    init_args:
      decay: 0.9999
  optimizer:
    class_path: torch.optim.AdamW
    init_args:
      lr: 1e-4
      betas:
        - 0.9
        - 0.95
      weight_decay: 0.0
data:
  train_dataset: imagenet512
  train_root: /mnt/bn/wangshuai6/data/ImageNet/train
  train_image_size: 512
  train_batch_size: 16
  eval_max_num_instances: 50000
  pred_batch_size: 32
  pred_num_workers: 4
  pred_seeds: null
  pred_selected_classes: null
  num_classes: 1000
  latent_shape:
    - 4
    - 64
    - 64
configs/repa_flatten_dit_fixt_large.yaml (new file)
@@ -0,0 +1,99 @@
# lightning.pytorch==2.4.0
seed_everything: true
tags:
  exp: &exp repa_flatten_dit_fixt_large
torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
huggingface_cache_dir: null
trainer:
  default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: bf16-mixed
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: universal_flow
      name: *exp
  num_sanity_val_steps: 0
  max_steps: 400000
  val_check_interval: 100000
  check_val_every_n_epoch: null
  log_every_n_steps: 50
  deterministic: null
  inference_mode: true
  use_distributed_sampler: false
  callbacks:
    - class_path: src.callbacks.model_checkpoint.CheckpointHook
      init_args:
        every_n_train_steps: 10000
        save_top_k: -1
        save_last: true
    - class_path: src.callbacks.save_images.SaveImagesHook
      init_args:
        save_dir: val
  plugins:
    - src.plugins.bd_env.BDEnvironment
model:
  vae:
    class_path: src.models.vae.LatentVAE
    init_args:
      precompute: true
      weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
  denoiser:
    class_path: src.models.denoiser.flatten_dit_fixt.FlattenDiT
    init_args:
      in_channels: 4
      patch_size: 2
      num_groups: 16
      hidden_size: &hidden_dim 1024
      num_blocks: 24
      num_classes: 1000
  conditioner:
    class_path: src.models.conditioner.LabelConditioner
    init_args:
      null_class: 1000
  diffusion_trainer:
    class_path: src.diffusion.flow_matching.training_repa.REPATrainer
    init_args:
      lognorm_t: true
      encoder_weight_path: dinov2_vitb14
      align_layer: 8
      proj_denoiser_dim: *hidden_dim
      proj_hidden_dim: *hidden_dim
      proj_encoder_dim: 768
      scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler
  diffusion_sampler:
    class_path: src.diffusion.flow_matching.sampling.EulerSampler
    init_args:
      num_steps: 250
      guidance: 1.00
      scheduler: *scheduler
      w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler
      guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
      step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn
  ema_tracker:
    class_path: src.callbacks.simple_ema.SimpleEMA
    init_args:
      decay: 0.9999
  optimizer:
    class_path: torch.optim.AdamW
    init_args:
      lr: 1e-4
      weight_decay: 0.0
data:
  train_dataset: imagenet256
  train_root: /mnt/bn/wangshuai6/data/ImageNet/train
  train_image_size: 256
  train_batch_size: 32
  eval_max_num_instances: 50000
  pred_batch_size: 64
  pred_num_workers: 4
  pred_seeds: null
  pred_selected_classes: null
  num_classes: 1000
  latent_shape:
    - 4
    - 32
    - 32
configs/repa_flatten_dit_fixt_xl.yaml (new file)
@@ -0,0 +1,99 @@
# lightning.pytorch==2.4.0
seed_everything: true
tags:
  exp: &exp repa_flatten_dit_fixt_xl
torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
huggingface_cache_dir: null
trainer:
  default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: bf16-mixed
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: universal_flow
      name: *exp
  num_sanity_val_steps: 0
  max_steps: 400000
  val_check_interval: 100000
  check_val_every_n_epoch: null
  log_every_n_steps: 50
  deterministic: null
  inference_mode: true
  use_distributed_sampler: false
  callbacks:
    - class_path: src.callbacks.model_checkpoint.CheckpointHook
      init_args:
        every_n_train_steps: 10000
        save_top_k: -1
        save_last: true
    - class_path: src.callbacks.save_images.SaveImagesHook
      init_args:
        save_dir: val
  plugins:
    - src.plugins.bd_env.BDEnvironment
model:
  vae:
    class_path: src.models.vae.LatentVAE
    init_args:
      precompute: true
      weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
  denoiser:
    class_path: src.models.denoiser.flatten_dit_fixt.FlattenDiT
    init_args:
      in_channels: 4
      patch_size: 2
      num_groups: 16
      hidden_size: &hidden_dim 1152
      num_blocks: 28
      num_classes: 1000
  conditioner:
    class_path: src.models.conditioner.LabelConditioner
    init_args:
      null_class: 1000
  diffusion_trainer:
    class_path: src.diffusion.flow_matching.training_repa.REPATrainer
    init_args:
      lognorm_t: true
      encoder_weight_path: dinov2_vitb14
      align_layer: 8
      proj_denoiser_dim: *hidden_dim
      proj_hidden_dim: *hidden_dim
      proj_encoder_dim: 768
      scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler
  diffusion_sampler:
    class_path: src.diffusion.flow_matching.sampling.EulerSampler
    init_args:
      num_steps: 250
      guidance: 1.00
      scheduler: *scheduler
      w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler
      guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
      step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn
  ema_tracker:
    class_path: src.callbacks.simple_ema.SimpleEMA
    init_args:
      decay: 0.9999
  optimizer:
    class_path: torch.optim.AdamW
    init_args:
      lr: 1e-4
      weight_decay: 0.0
data:
  train_dataset: imagenet256
  train_root: /mnt/bn/wangshuai6/data/ImageNet/train
  train_image_size: 256
  train_batch_size: 32
  eval_max_num_instances: 50000
  pred_batch_size: 64
  pred_num_workers: 4
  pred_seeds: null
  pred_selected_classes: null
  num_classes: 1000
  latent_shape:
    - 4
    - 32
    - 32
main.py (new file)
@@ -0,0 +1,86 @@
import time
from typing import Any, Union

import pylab as pl

from src.utils.patch_bugs import *

import os
import torch
from lightning import Trainer, LightningModule
from src.lightning_data import DataModule
from src.lightning_model import LightningModel
from lightning.pytorch.cli import LightningCLI, LightningArgumentParser, SaveConfigCallback

import logging
logger = logging.getLogger("lightning.pytorch")
# log_path = os.path.join( f"log.txt")
# logger.addHandler(logging.FileHandler(log_path))


class ReWriteRootSaveConfigCallback(SaveConfigCallback):
    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        # Save the resolved config into the run directory, stamped with the subcommand and time.
        stamp = time.strftime('%y%m%d%H%M')
        file_path = os.path.join(trainer.default_root_dir, f"config-{stage}-{stamp}.yaml")
        self.parser.save(
            self.config, file_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
        )


class ReWriteRootDirCli(LightningCLI):
    def before_instantiate_classes(self) -> None:
        super().before_instantiate_classes()
        config_trainer = self._get(self.config, "trainer", default={})

        # predict path & logger check
        if self.subcommand == "predict":
            config_trainer.logger = None

    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        class TagsClass:
            def __init__(self, exp: str):
                ...
        parser.add_class_arguments(TagsClass, nested_key="tags")

    def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        super().add_default_arguments_to_parser(parser)
        parser.add_argument("--torch_hub_dir", type=str, default=None, help="torch hub dir")
        parser.add_argument("--huggingface_cache_dir", type=str, default=None, help="huggingface hub dir")

    def instantiate_trainer(self, **kwargs: Any) -> Trainer:
        config_trainer = self._get(self.config_init, "trainer", default={})
        default_root_dir = config_trainer.get("default_root_dir", None)

        if default_root_dir is None:
            default_root_dir = os.path.join(os.getcwd(), "workdirs")

        # Build the run-directory name from the `tags` section of the config.
        dirname = ""
        for k, v in self._get(self.config, "tags", default={}).items():
            dirname += f"{k}_{v}"
        default_root_dir = os.path.join(default_root_dir, dirname)
        is_resume = self._get(self.config_init, "ckpt_path", default=None)
        if os.path.exists(default_root_dir) and "debug" not in default_root_dir:
            if os.listdir(default_root_dir) and self.subcommand != "predict" and not is_resume:
                raise FileExistsError(f"{default_root_dir} already exists")

        config_trainer.default_root_dir = default_root_dir
        trainer = super().instantiate_trainer(**kwargs)
        if trainer.is_global_zero:
            os.makedirs(default_root_dir, exist_ok=True)
        return trainer

    def instantiate_classes(self) -> None:
        torch_hub_dir = self._get(self.config, "torch_hub_dir")
        huggingface_cache_dir = self._get(self.config, "huggingface_cache_dir")
        if huggingface_cache_dir is not None:
            os.environ["HUGGINGFACE_HUB_CACHE"] = huggingface_cache_dir
        if torch_hub_dir is not None:
            os.environ["TORCH_HOME"] = torch_hub_dir
            torch.hub.set_dir(torch_hub_dir)
        super().instantiate_classes()


if __name__ == "__main__":
    cli = ReWriteRootDirCli(LightningModel, DataModule,
                            auto_configure_optimizers=False,
                            save_config_callback=ReWriteRootSaveConfigCallback,
                            save_config_kwargs={"overwrite": True})
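A quick sketch of the run-directory naming that instantiate_trainer() performs (the values below are examples, not part of the commit): every key/value pair under the `tags:` block is concatenated into the experiment folder name, which is why each config sets `tags.exp`.

import os

tags = {"exp": "repa_flatten_dit_fixt_xl"}              # taken from a config's `tags:` block
root = "/mnt/bn/wangshuai6/universal_flow_workdirs"     # trainer.default_root_dir in the configs
dirname = "".join(f"{k}_{v}" for k, v in tags.items())
print(os.path.join(root, dirname))
# -> /mnt/bn/wangshuai6/universal_flow_workdirs/exp_repa_flatten_dit_fixt_xl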
main.sh (new file)
@@ -0,0 +1,10 @@
#!/bin/bash

export NCCL_HOSTID=${MY_POD_NAME}
export MASTER_ADDR=${ARNOLD_WORKER_0_HOST}
export MASTER_PORT=${ARNOLD_WORKER_0_PORT}
export NODE_RANK=${ARNOLD_ID}
export NUM_NODES=${ARNOLD_WORKER_NUM}

python3 main.py fit -c $1 --trainer.num_nodes $NUM_NODES
# for pid in $(ps -ef | grep "yaml" | grep -v "grep" | awk '{print $2}'); do kill -9 $pid; done
requirements.txt (new file)
@@ -0,0 +1,3 @@
lightning==2.5.0.post0
omegaconf==2.3.0
jsonargparse[signatures]>=4.27.7
src/__init__.py (new file, empty)
src/callbacks/__init__.py (new file, empty)
src/callbacks/grad.py (new file)
@@ -0,0 +1,22 @@
import torch
import lightning.pytorch as pl
from lightning.pytorch.utilities import grad_norm
from torch.optim import Optimizer


class GradientMonitor(pl.Callback):
    """Logs the gradient norm"""

    def __init__(self, norm_type: int = 2):
        norm_type = float(norm_type)
        if norm_type <= 0:
            raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
        self.norm_type = norm_type

    def on_before_optimizer_step(
        self, trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        optimizer: Optimizer
    ) -> None:
        norms = grad_norm(pl_module, norm_type=self.norm_type)
        max_grad = torch.tensor([v for k, v in norms.items() if k != f"grad_{self.norm_type}_norm_total"]).max()
        pl_module.log_dict({'train/grad/max': max_grad, 'train/grad/total': norms[f"grad_{self.norm_type}_norm_total"]})
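GradientMonitor is not referenced by the configs in this commit; a minimal wiring sketch (the `model` and `data` objects are assumed to exist, so the fit call is left commented):

import lightning.pytorch as pl
from src.callbacks.grad import GradientMonitor

trainer = pl.Trainer(callbacks=[GradientMonitor(norm_type=2)], max_steps=10)
# trainer.fit(model, datamodule=data)  # logs train/grad/max and train/grad/total before each optimizer step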
src/callbacks/model_checkpoint.py (new file)
@@ -0,0 +1,25 @@
import os.path
from typing import Optional, Dict, Any

import lightning.pytorch as pl
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint


class CheckpointHook(ModelCheckpoint):
    """Save checkpoint with only the incremental part of the model"""
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        self.dirpath = trainer.default_root_dir
        self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt")
        pl_module.strict_loading = False

    def on_save_checkpoint(
        self, trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        checkpoint: Dict[str, Any]
    ) -> None:
        # Drop callback state (e.g. EMA bookkeeping) from the saved checkpoint.
        del checkpoint["callbacks"]

    # def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
    #     if not "debug" in self.exception_ckpt_path:
    #         trainer.save_checkpoint(self.exception_ckpt_path)
src/callbacks/save_images.py (new file)
@@ -0,0 +1,105 @@
import lightning.pytorch as pl
from lightning.pytorch import Callback

import os.path
import numpy
from PIL import Image
from typing import Sequence, Any, Dict
from concurrent.futures import ThreadPoolExecutor

from lightning.pytorch.utilities.types import STEP_OUTPUT
from lightning_utilities.core.rank_zero import rank_zero_info


def process_fn(image, path):
    Image.fromarray(image).save(path)


class SaveImagesHook(Callback):
    def __init__(self, save_dir="val", max_save_num=0, compressed=True):
        self.save_dir = save_dir
        self.max_save_num = max_save_num
        self.compressed = compressed

    def save_start(self, target_dir):
        self.target_dir = target_dir
        self.executor_pool = ThreadPoolExecutor(max_workers=8)
        if not os.path.exists(self.target_dir):
            os.makedirs(self.target_dir, exist_ok=True)
        else:
            if os.listdir(target_dir) and "debug" not in str(target_dir):
                raise FileExistsError(f'{self.target_dir} already exists and not empty!')
        self.samples = []
        self._have_saved_num = 0
        rank_zero_info(f"Save images to {self.target_dir}")

    def save_image(self, images, filenames):
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        for sample, filename in zip(images, filenames):
            if isinstance(filename, Sequence):
                filename = filename[0]
            path = f'{self.target_dir}/{filename}'
            if self._have_saved_num >= self.max_save_num:
                break
            self.executor_pool.submit(process_fn, sample, path)
            self._have_saved_num += 1

    def process_batch(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: STEP_OUTPUT,
        batch: Any,
    ) -> None:
        b, c, h, w = samples.shape
        xT, y, metadata = batch
        all_samples = pl_module.all_gather(samples).view(-1, c, h, w)
        self.save_image(samples, metadata)
        if trainer.is_global_zero:
            all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy()
            self.samples.append(all_samples)

    def save_end(self):
        if self.compressed and len(self.samples) > 0:
            samples = numpy.concatenate(self.samples)
            numpy.savez(f'{self.target_dir}/output.npz', arr_0=samples)
        self.executor_pool.shutdown(wait=True)
        self.samples = []
        self.target_dir = None
        self._have_saved_num = 0
        self.executor_pool = None

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
        self.save_start(target_dir)

    def on_validation_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, outputs, batch)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict")
        self.save_start(target_dir)

    def on_predict_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        samples: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        return self.process_batch(trainer, pl_module, samples, batch)

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.save_end()
src/callbacks/simple_ema.py (new file)
@@ -0,0 +1,79 @@
from typing import Any, Dict

import torch
import torch.nn as nn
import lightning.pytorch as pl
from lightning.pytorch import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT

from src.utils.copy import swap_tensors


class SimpleEMA(Callback):
    def __init__(self, net: nn.Module, ema_net: nn.Module,
                 decay: float = 0.9999,
                 every_n_steps: int = 1,
                 eval_original_model: bool = False
                 ):
        super().__init__()
        self.decay = decay
        self.every_n_steps = every_n_steps
        self.eval_original_model = eval_original_model
        self._stream = torch.cuda.Stream()

        self.net_params = list(net.parameters())
        self.ema_params = list(ema_net.parameters())

    def swap_model(self):
        for ema_p, p in zip(self.ema_params, self.net_params):
            swap_tensors(ema_p, p)

    def ema_step(self):
        @torch.no_grad()
        def ema_update(ema_model_tuple, current_model_tuple, decay):
            torch._foreach_mul_(ema_model_tuple, decay)
            torch._foreach_add_(
                ema_model_tuple, current_model_tuple, alpha=(1.0 - decay),
            )

        # Run the EMA update on a side CUDA stream so it can overlap with training work.
        if self._stream is not None:
            self._stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._stream):
            ema_update(self.ema_params, self.net_params, self.decay)

    def on_train_batch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
    ) -> None:
        if trainer.global_step % self.every_n_steps == 0:
            self.ema_step()

    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()

    def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()

    def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.eval_original_model:
            self.swap_model()

    def state_dict(self) -> Dict[str, Any]:
        return {
            "decay": self.decay,
            "every_n_steps": self.every_n_steps,
            "eval_original_model": self.eval_original_model,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.decay = state_dict["decay"]
        self.every_n_steps = state_dict["every_n_steps"]
        self.eval_original_model = state_dict["eval_original_model"]
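A minimal sketch of how SimpleEMA pairs a network with its EMA copy (toy modules for illustration; in this repo the callback is built from the `ema_tracker` entry of the configs, and a CUDA device is required because the callback allocates a CUDA stream):

import copy
import torch.nn as nn
from src.callbacks.simple_ema import SimpleEMA

net = nn.Linear(4, 4)                            # stand-in for the denoiser
ema_net = copy.deepcopy(net).requires_grad_(False)
ema = SimpleEMA(net, ema_net, decay=0.9999, every_n_steps=1)
# ema.ema_step()  # after an optimizer step: ema_net <- decay * ema_net + (1 - decay) * net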
src/data/__init__.py (new file, one blank line)
@@ -0,0 +1 @@
src/data/dataset/__init__.py (new file, empty)
src/data/dataset/celeba.py (new file)
@@ -0,0 +1,11 @@
from typing import Callable
from torchvision.datasets import CelebA


class LocalDataset(CelebA):
    def __init__(self, root: str):
        super(LocalDataset, self).__init__(root, "train")

    def __getitem__(self, idx):
        data = super().__getitem__(idx)
        return data
src/data/dataset/imagenet.py (new file)
@@ -0,0 +1,82 @@
import torch
from PIL import Image
from torchvision.datasets import ImageFolder
from torchvision.transforms.functional import to_tensor
from torchvision.transforms import Normalize

from src.data.dataset.metric_dataset import CenterCrop


class LocalCachedDataset(ImageFolder):
    def __init__(self, root, resolution=256):
        super().__init__(root)
        self.transform = CenterCrop(resolution)
        self.cache_root = None

    def load_latent(self, latent_path):
        # Re-sample a latent from the cached VAE posterior (mean / logvar).
        pk_data = torch.load(latent_path)
        mean = pk_data['mean'].to(torch.float32)
        logvar = pk_data['logvar'].to(torch.float32)
        logvar = torch.clamp(logvar, -30.0, 20.0)
        std = torch.exp(0.5 * logvar)
        latent = mean + torch.randn_like(mean) * std
        return latent

    def __getitem__(self, idx: int):
        image_path, target = self.samples[idx]
        latent_path = image_path.replace(self.root, self.cache_root) + ".pt"

        raw_image = Image.open(image_path).convert('RGB')
        raw_image = self.transform(raw_image)
        raw_image = to_tensor(raw_image)
        if self.cache_root is not None:
            latent = self.load_latent(latent_path)
        else:
            latent = raw_image
        return raw_image, latent, target


class ImageNet256(LocalCachedDataset):
    def __init__(self, root):
        super().__init__(root, 256)
        self.cache_root = root + "_256_latent"


class ImageNet512(LocalCachedDataset):
    def __init__(self, root):
        super().__init__(root, 512)
        self.cache_root = root + "_512_latent"


class PixImageNet(ImageFolder):
    def __init__(self, root, resolution=256):
        super().__init__(root)
        self.transform = CenterCrop(resolution)
        self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    def __getitem__(self, idx: int):
        image_path, target = self.samples[idx]
        raw_image = Image.open(image_path).convert('RGB')
        raw_image = self.transform(raw_image)
        raw_image = to_tensor(raw_image)

        normalized_image = self.normalize(raw_image)
        return raw_image, normalized_image, target


class PixImageNet64(PixImageNet):
    def __init__(self, root):
        super().__init__(root, 64)


class PixImageNet128(PixImageNet):
    def __init__(self, root):
        super().__init__(root, 128)


class PixImageNet256(PixImageNet):
    def __init__(self, root):
        super().__init__(root, 256)


class PixImageNet512(PixImageNet):
    def __init__(self, root):
        super().__init__(root, 512)
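ImageNet256/512 expect a latent cache that mirrors the image tree: for `<root>/<class>/<img>.JPEG` there should be a `<root>_256_latent/<class>/<img>.JPEG.pt` file holding 'mean' and 'logvar' tensors. A small sketch of writing one such entry with dummy tensors (the real cache would come from a VAE precompute step that is not part of this file; the file name below is only an example):

import torch

entry = {
    "mean": torch.zeros(4, 32, 32),     # latent_shape for the 256-resolution configs
    "logvar": torch.zeros(4, 32, 32),
}
torch.save(entry, "n01440764_10026.JPEG.pt")  # placed under <train_root>_256_latent/n01440764/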
src/data/dataset/metric_dataset.py (new file)
@@ -0,0 +1,82 @@
import pathlib

import torch
import random
import numpy as np
from PIL import Image
from torchvision.io.image import read_image
import torchvision.transforms as tvtf
from torch.utils.data import Dataset


class CenterCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        def center_crop_arr(pil_image, image_size):
            """
            Center cropping implementation from ADM.
            https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
            """
            while min(*pil_image.size) >= 2 * image_size:
                pil_image = pil_image.resize(
                    tuple(x // 2 for x in pil_image.size), resample=Image.BOX
                )

            scale = image_size / min(*pil_image.size)
            pil_image = pil_image.resize(
                tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
            )

            arr = np.array(pil_image)
            crop_y = (arr.shape[0] - image_size) // 2
            crop_x = (arr.shape[1] - image_size) // 2
            return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])

        return center_crop_arr(image, self.size)


IMG_EXTENSIONS = (
    "*.png",
    "*.JPEG",
    "*.jpeg",
    "*.jpg"
)


def test_collate(batch):
    return torch.stack(batch)


class ImageDataset(Dataset):
    def __init__(self, root, image_size=(224, 224)):
        self.root = pathlib.Path(root)
        images = []
        for ext in IMG_EXTENSIONS:
            images.extend(self.root.rglob(ext))
        random.shuffle(images)
        self.images = list(map(lambda x: str(x), images))
        self.transform = tvtf.Compose(
            [
                CenterCrop(image_size[0]),
                tvtf.ToTensor(),
                tvtf.Lambda(lambda x: (x*255).to(torch.uint8)),
                tvtf.Lambda(lambda x: x.expand(3, -1, -1))
            ]
        )
        self.size = image_size

    def __getitem__(self, idx):
        try:
            image = Image.open(self.images[idx])
            image = self.transform(image)
        except Exception as e:
            print(self.images[idx])
            image = torch.zeros(3, self.size[0], self.size[1], dtype=torch.uint8)

        # print(image)
        metadata = dict(
            path=self.images[idx],
            root=self.root,
        )
        return image  # , metadata

    def __len__(self):
        return len(self.images)
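A small usage sketch of CenterCrop on a synthetic PIL image (sizes are arbitrary):

from PIL import Image
from src.data.dataset.metric_dataset import CenterCrop

img = Image.new("RGB", (640, 480))
crop = CenterCrop(256)
print(crop(img).size)  # (256, 256): shorter side scaled to 256, then center-cropped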
src/data/dataset/randn.py (new file)
@@ -0,0 +1,41 @@
import os.path
import random

import torch
from torch.utils.data import Dataset


class RandomNDataset(Dataset):
    def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, selected_classes: list = None, seeds=None, max_num_instances=50000, ):
        self.selected_classes = selected_classes
        if selected_classes is not None:
            num_classes = len(selected_classes)
            max_num_instances = 10*num_classes
        self.num_classes = num_classes
        self.seeds = seeds
        if seeds is not None:
            self.max_num_instances = len(seeds)*num_classes
            self.num_seeds = len(seeds)
        else:
            self.num_seeds = (max_num_instances + num_classes - 1) // num_classes
            self.max_num_instances = self.num_seeds*num_classes

        self.latent_shape = latent_shape

    def __getitem__(self, idx):
        label = idx // self.num_seeds
        if self.selected_classes:
            label = self.selected_classes[label]
        seed = random.randint(0, 1 << 31)  # idx % self.num_seeds
        if self.seeds is not None:
            seed = self.seeds[idx % self.num_seeds]

        # cls_dir = os.path.join(self.root, f"{label}")
        filename = f"{label}_{seed}.png",  # trailing comma makes this a 1-tuple; SaveImagesHook unwraps it
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)
        return latent, label, filename

    def __len__(self):
        return self.max_num_instances
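RandomNDataset pairs every class with a fixed number of noise seeds; a short sketch of drawing one sample (the latent shape matches the 256-resolution configs):

from src.data.dataset.randn import RandomNDataset

ds = RandomNDataset(latent_shape=(4, 32, 32), num_classes=1000, max_num_instances=50000)
latent, label, filename = ds[0]
print(latent.shape, label, filename)  # torch.Size([4, 32, 32]) 0 ('0_<seed>.png',)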
src/data/var_training.py (new file)
@@ -0,0 +1,145 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor
from typing import List
from PIL import Image
import random
import numpy as np
import copy
import torchvision.transforms.functional as tvtf
from src.models.vae import uint82fp


def center_crop_arr(pil_image, width, height):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    while pil_image.size[0] >= 2 * width and pil_image.size[1] >= 2 * height:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = max(width / pil_image.size[0], height / pil_image.size[1])
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )
    arr = np.array(pil_image)
    crop_y = random.randint(0, (arr.shape[0] - height))
    crop_x = random.randint(0, (arr.shape[1] - width))
    return Image.fromarray(arr[crop_y: crop_y + height, crop_x: crop_x + width])


def process_fn(width, height, data, hflip=0.5):
    image, label = data
    if random.uniform(0, 1) > hflip:  # hflip
        image = tvtf.hflip(image)
    image = center_crop_arr(image, width, height)  # crop
    image = np.array(image).transpose(2, 0, 1)
    return image, label


class VARCandidate:
    def __init__(self, aspect_ratio, width, height, buffer, max_buffer_size=1024):
        self.aspect_ratio = aspect_ratio
        self.width = int(width)
        self.height = int(height)
        self.buffer = buffer
        self.max_buffer_size = max_buffer_size

    def add_sample(self, data):
        self.buffer.append(data)
        self.buffer = self.buffer[-self.max_buffer_size:]

    def ready(self, batch_size):
        return len(self.buffer) >= batch_size

    def get_batch(self, batch_size):
        batch = self.buffer[:batch_size]
        self.buffer = self.buffer[batch_size:]
        batch = [copy.deepcopy(b.result()) for b in batch]
        x, y = zip(*batch)
        x = torch.stack([torch.from_numpy(im).cuda() for im in x], dim=0)
        x = list(map(uint82fp, x))
        return x, y


class VARTransformEngine:
    def __init__(self,
                 base_image_size,
                 num_aspect_ratios,
                 min_aspect_ratio,
                 max_aspect_ratio,
                 num_workers=8,
                 ):
        self.base_image_size = base_image_size
        self.num_aspect_ratios = num_aspect_ratios
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        self.aspect_ratios = np.linspace(self.min_aspect_ratio, self.max_aspect_ratio, self.num_aspect_ratios)
        self.aspect_ratios = self.aspect_ratios.tolist()
        self.candidates_pool = []
        for i in range(self.num_aspect_ratios):
            candidate = VARCandidate(
                aspect_ratio=self.aspect_ratios[i],
                width=int(self.base_image_size * self.aspect_ratios[i] ** 0.5 // 16 * 16),
                height=int(self.base_image_size * self.aspect_ratios[i] ** -0.5 // 16 * 16),
                buffer=[],
                max_buffer_size=1024
            )
            self.candidates_pool.append(candidate)
        self.default_candidate = VARCandidate(
            aspect_ratio=1.0,
            width=self.base_image_size,
            height=self.base_image_size,
            buffer=[],
            max_buffer_size=1024,
        )
        self.executor_pool = ProcessPoolExecutor(max_workers=num_workers)
        self._prefill_count = 100

    def find_candidate(self, data):
        image = data[0]
        aspect_ratio = image.size[0] / image.size[1]
        min_distance = 1000000
        min_candidate = None
        for candidate in self.candidates_pool:
            dis = abs(aspect_ratio - candidate.aspect_ratio)
            if dis < min_distance:
                min_distance = dis
                min_candidate = candidate
        return min_candidate

    def __call__(self, batch_data):
        self._prefill_count -= 1
        if isinstance(batch_data[0], torch.Tensor):
            batch_data[0] = batch_data[0].unbind(0)

        batch_data = list(zip(*batch_data))
        for data in batch_data:
            candidate = self.find_candidate(data)
            future = self.executor_pool.submit(process_fn, candidate.width, candidate.height, data)
            candidate.add_sample(future)
            if self._prefill_count >= 0:
                future = self.executor_pool.submit(process_fn,
                                                   self.default_candidate.width,
                                                   self.default_candidate.height,
                                                   data)
                self.default_candidate.add_sample(future)

        batch_size = len(batch_data)
        random.shuffle(self.candidates_pool)
        for candidate in self.candidates_pool:
            if candidate.ready(batch_size=batch_size):
                return candidate.get_batch(batch_size=batch_size)

        # fallback to default 256
        for data in batch_data:
            future = self.executor_pool.submit(process_fn,
                                               self.default_candidate.width,
                                               self.default_candidate.height,
                                               data)
            self.default_candidate.add_sample(future)
        return self.default_candidate.get_batch(batch_size=batch_size)
src/diffusion/__init__.py (new file, empty)
src/diffusion/base/guidance.py (new file)
@@ -0,0 +1,60 @@
import torch

def simple_guidance_fn(out, cfg):
    uncondition, condition = out.chunk(2, dim=0)
    out = uncondition + cfg * (condition - uncondition)
    return out

def c3_guidance_fn(out, cfg):
    # guidance function in DiT/SiT, seems like a bug not a feature?
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :3] = uncondition[:, :3] + cfg * (condition[:, :3] - uncondition[:, :3])
    return out

def c4_guidance_fn(out, cfg):
    # guidance function in DiT/SiT, seems like a bug not a feature?
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:])
    return out

def c4_p05_guidance_fn(out, cfg):
    # guidance function in DiT/SiT, seems like a bug not a feature?
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:])
    return out

def c4_p10_guidance_fn(out, cfg):
    # guidance function in DiT/SiT, seems like a bug not a feature?
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.10 * (condition[:, 4:] - uncondition[:, 4:])
    return out

def c4_p15_guidance_fn(out, cfg):
    # guidance function in DiT/SiT, seems like a bug not a feature?
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.15 * (condition[:, 4:] - uncondition[:, 4:])
    return out

def c4_p20_guidance_fn(out, cfg):
    # guidance function in DiT/SiT, seems like a bug not a feature?
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
    out[:, 4:] = uncondition[:, 4:] + 1.20 * (condition[:, 4:] - uncondition[:, 4:])
    return out

def p4_guidance_fn(out, cfg):
    # guidance function in DiT/SiT, seems like a bug not a feature?
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    out[:, 4:] = uncondition[:, 4:] + cfg * (condition[:, 4:] - uncondition[:, 4:])
    return out
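simple_guidance_fn assumes the batch is ordered [unconditional half, conditional half], matching the `torch.cat([uncondition, condition], dim=0)` used by the samplers. A tiny worked example:

import torch
from src.diffusion.base.guidance import simple_guidance_fn

uncond = torch.zeros(2, 4, 8, 8)
cond = torch.ones(2, 4, 8, 8)
out = simple_guidance_fn(torch.cat([uncond, cond], dim=0), cfg=2.0)
print(out.shape, out.unique())  # torch.Size([2, 4, 8, 8]) tensor([2.])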
src/diffusion/base/sampling.py (new file)
@@ -0,0 +1,31 @@
from typing import Union, List

import torch
import torch.nn as nn
from typing import Callable
from src.diffusion.base.scheduling import BaseScheduler


class BaseSampler(nn.Module):
    def __init__(self,
                 scheduler: BaseScheduler = None,
                 guidance_fn: Callable = None,
                 num_steps: int = 250,
                 guidance: Union[float, List[float]] = 1.0,
                 *args,
                 **kwargs
                 ):
        super(BaseSampler, self).__init__()
        self.num_steps = num_steps
        self.guidance = guidance
        self.guidance_fn = guidance_fn
        self.scheduler = scheduler

    def _impl_sampling(self, net, noise, condition, uncondition):
        raise NotImplementedError

    def __call__(self, net, noise, condition, uncondition):
        denoised = self._impl_sampling(net, noise, condition, uncondition)
        return denoised
src/diffusion/base/scheduling.py (new file)
@@ -0,0 +1,32 @@
import torch
from torch import Tensor


class BaseScheduler:
    def alpha(self, t) -> Tensor:
        ...
    def sigma(self, t) -> Tensor:
        ...

    def dalpha(self, t) -> Tensor:
        ...
    def dsigma(self, t) -> Tensor:
        ...

    def dalpha_over_alpha(self, t) -> Tensor:
        return self.dalpha(t) / self.alpha(t)

    def dsigma_mul_sigma(self, t) -> Tensor:
        return self.dsigma(t)*self.sigma(t)

    def drift_coefficient(self, t):
        alpha, sigma = self.alpha(t), self.sigma(t)
        dalpha, dsigma = self.dalpha(t), self.dsigma(t)
        return dalpha/(alpha + 1e-6)

    def diffuse_coefficient(self, t):
        alpha, sigma = self.alpha(t), self.sigma(t)
        dalpha, dsigma = self.dalpha(t), self.dsigma(t)
        return dsigma*sigma - dalpha/(alpha + 1e-6)*sigma**2

    def w(self, t):
        return self.sigma(t)
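The configs reference a flow-matching LinearScheduler that is not included in this commit. As a rough sketch of how this interface could be filled in (an assumption for illustration, not the repo's actual class), a rectified-flow style schedule would use alpha(t) = 1 - t and sigma(t) = t:

import torch
from torch import Tensor
from src.diffusion.base.scheduling import BaseScheduler


class LinearSchedulerSketch(BaseScheduler):
    # Hypothetical example; the broadcastable (-1, 1, 1, 1) view follows the DDPM/VP schedulers below.
    def alpha(self, t) -> Tensor:
        return (1 - t).view(-1, 1, 1, 1)

    def sigma(self, t) -> Tensor:
        return t.view(-1, 1, 1, 1)

    def dalpha(self, t) -> Tensor:
        return -torch.ones_like(t).view(-1, 1, 1, 1)

    def dsigma(self, t) -> Tensor:
        return torch.ones_like(t).view(-1, 1, 1, 1)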
src/diffusion/base/training.py (new file)
@@ -0,0 +1,29 @@
import time

import torch
import torch.nn as nn


class BaseTrainer(nn.Module):
    def __init__(self,
                 null_condition_p=0.1,
                 log_var=False,
                 ):
        super(BaseTrainer, self).__init__()
        self.null_condition_p = null_condition_p
        self.log_var = log_var

    def preprocess(self, raw_images, x, condition, uncondition):
        # Randomly replace conditions with the null condition for classifier-free guidance training.
        bsz = x.shape[0]
        if self.null_condition_p > 0:
            mask = torch.rand((bsz,), device=condition.device) < self.null_condition_p
            mask = mask.expand_as(condition)
            condition[mask] = uncondition[mask]
        return raw_images, x, condition

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        raise NotImplementedError

    def __call__(self, net, ema_net, raw_images, x, condition, uncondition):
        raw_images, x, condition = self.preprocess(raw_images, x, condition, uncondition)
        return self._impl_trainstep(net, ema_net, raw_images, x, condition)
src/diffusion/ddpm/ddim_sampling.py (new file)
@@ -0,0 +1,40 @@
import torch
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

from typing import Callable

import logging
logger = logging.getLogger(__name__)


class DDIMSampler(BaseSampler):
    def __init__(
        self,
        train_num_steps=1000,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.train_num_steps = train_num_steps
        assert self.scheduler is not None

    def _impl_sampling(self, net, noise, condition, uncondition):
        batch_size = noise.shape[0]
        steps = torch.linspace(0.0, self.train_num_steps-1, self.num_steps, device=noise.device)
        steps = torch.flip(steps, dims=[0])
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = x0 = noise
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            t_cur = t_cur.repeat(batch_size)
            t_next = t_next.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            alpha = self.scheduler.alpha(t_cur)
            sigma_next = self.scheduler.sigma(t_next)
            alpha_next = self.scheduler.alpha(t_next)
            cfg_x = torch.cat([x, x], dim=0)
            t = t_cur.repeat(2)
            out = net(cfg_x, t, cfg_condition)
            out = self.guidance_fn(out, self.guidance)
            x0 = (x - sigma * out) / alpha
            x = alpha_next * x0 + sigma_next * out
        return x0
src/diffusion/ddpm/scheduling.py (new file)
@@ -0,0 +1,102 @@
import math
import torch
from src.diffusion.base.scheduling import *


class DDPMScheduler(BaseScheduler):
    def __init__(
        self,
        beta_min=0.0001,
        beta_max=0.02,
        num_steps=1000,
    ):
        super().__init__()
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.num_steps = num_steps

        self.betas_table = torch.linspace(self.beta_min, self.beta_max, self.num_steps, device="cuda")
        self.alphas_table = torch.cumprod(1-self.betas_table, dim=0)
        self.sigmas_table = 1-self.alphas_table

    def beta(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.betas_table[t].view(-1, 1, 1, 1)

    def alpha(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.alphas_table[t].view(-1, 1, 1, 1)**0.5

    def sigma(self, t) -> Tensor:
        t = t.to(torch.long)
        return self.sigmas_table[t].view(-1, 1, 1, 1)**0.5

    def dsigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha_over_alpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dsigma_mul_sigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def drift_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def diffuse_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def w(self, t):
        raise NotImplementedError("wrong usage")


class VPScheduler(BaseScheduler):
    def __init__(
        self,
        beta_min=0.1,
        beta_max=20,
    ):
        super().__init__()
        self.beta_min = beta_min
        self.beta_d = beta_max - beta_min

    def beta(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1)

    def sigma(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        inter_beta: Tensor = 0.5*self.beta_d*t**2 + self.beta_min*t
        return (1-torch.exp_(-inter_beta)).sqrt().view(-1, 1, 1, 1)

    def dsigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha_over_alpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dsigma_mul_sigma(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def dalpha(self, t) -> Tensor:
        raise NotImplementedError("wrong usage")

    def alpha(self, t) -> Tensor:
        t = torch.clamp(t, min=1e-3, max=1)
        inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t
        return torch.exp(-0.5*inter_beta).view(-1, 1, 1, 1)

    def drift_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def diffuse_coefficient(self, t):
        raise NotImplementedError("wrong usage")

    def w(self, t):
        return self.diffuse_coefficient(t)
src/diffusion/ddpm/training.py (new file)
@@ -0,0 +1,83 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler

def inverse_sigma(alpha, sigma):
    return 1/sigma**2
def snr(alpha, sigma):
    return alpha/sigma
def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)
def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)
def constant(alpha, sigma):
    return 1


class VPTrainer(BaseTrainer):
    def __init__(
        self,
        scheduler: BaseScheduler,
        loss_weight_fn: Callable = constant,
        train_max_t=1000,
        lognorm_t=False,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.train_max_t = train_max_t

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        sigma = self.scheduler.sigma(t)
        x_t = alpha * x + noise * sigma
        out = net(x_t, t*self.train_max_t, y)
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight*(out - noise)**2

        out = dict(
            loss=loss.mean(),
        )
        return out


class DDPMTrainer(BaseTrainer):
    def __init__(
        self,
        scheduler: BaseScheduler,
        loss_weight_fn: Callable = constant,
        train_max_t=1000,
        lognorm_t=False,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.train_max_t = train_max_t

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        t = torch.randint(0, self.train_max_t, (batch_size,))
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        sigma = self.scheduler.sigma(t)
        x_t = alpha * x + noise * sigma
        out = net(x_t, t, y)
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight * (out - noise) ** 2

        out = dict(
            loss=loss.mean(),
        )
        return out
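A sketch of driving one DDPMTrainer step through BaseTrainer.__call__ (toy tensors and a toy epsilon predictor; a CUDA device is needed because DDPMScheduler builds its tables on "cuda", and the real denoiser, labels and EMA model come from the Lightning model, which is not part of this file):

import torch
from src.diffusion.ddpm.scheduling import DDPMScheduler
from src.diffusion.ddpm.training import DDPMTrainer

net = lambda x_t, t, y: torch.zeros_like(x_t)        # toy eps-predictor
trainer = DDPMTrainer(scheduler=DDPMScheduler(), null_condition_p=0.0)
x = torch.randn(2, 4, 32, 32, device="cuda")          # toy latents
y = torch.zeros(2, 1000, device="cuda")               # toy condition embedding
out = trainer(net, None, raw_images=None, x=x, condition=y, uncondition=torch.zeros_like(y))
print(out["loss"])                                     # scalar MSE between prediction and noise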
59
src/diffusion/ddpm/vp_sampling.py
Normal file
59
src/diffusion/ddpm/vp_sampling.py
Normal file
@@ -0,0 +1,59 @@
import torch

from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *
from typing import Callable


def ode_step_fn(x, eps, beta, sigma, dt):
    return x + (-0.5*beta*x + 0.5*eps*beta/sigma)*dt

def sde_step_fn(x, eps, beta, sigma, dt):
    return x + (-0.5*beta*x + eps*beta/sigma)*dt + torch.sqrt(dt.abs()*beta)*torch.randn_like(x)


import logging
logger = logging.getLogger(__name__)


class VPEulerSampler(BaseSampler):
    def __init__(
            self,
            train_max_t=1000,
            guidance_fn: Callable = None,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.guidance_fn = guidance_fn
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.train_max_t = train_max_t

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps
        assert self.last_step > 0.0
        assert self.scheduler is not None

    def _impl_sampling(self, net, noise, condition, uncondition):
        batch_size = noise.shape[0]
        steps = torch.linspace(1.0, self.last_step, self.num_steps, device=noise.device)
        steps = torch.cat([steps, torch.tensor([0.0], device=noise.device)], dim=0)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            beta = self.scheduler.beta(t_cur)
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            out = net(cfg_x, cfg_t*self.train_max_t, cfg_condition)
            eps = self.guidance_fn(out, self.guidance)
            if i < self.num_steps - 1:
                x0 = self.last_step_fn(x, eps, beta, sigma, -t_cur[0])
                x = self.step_fn(x, eps, beta, sigma, dt)
            else:
                x = x0 = self.last_step_fn(x, eps, beta, sigma, -self.last_step)
        return x
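The ODE and SDE step functions above share the drift -0.5*beta*x + 0.5*beta*eps/sigma; the SDE variant doubles the score term and adds Gaussian noise of matching scale. A toy numeric step with made-up scalars in place of the scheduler outputs, just to show the update shape:

import torch

def ode_step(x, eps, beta, sigma, dt):
    # probability-flow (ODE) Euler update, same form as ode_step_fn above
    return x + (-0.5 * beta * x + 0.5 * eps * beta / sigma) * dt

x = torch.randn(2, 3, 8, 8)
eps = torch.randn_like(x)          # stands in for the network's epsilon prediction
beta, sigma = 0.2, 0.8             # illustrative scalars, not real schedule values
dt = torch.tensor(-1.0 / 50)       # integrating from t=1 toward t=0
x_next = ode_step(x, eps, beta, sigma, dt)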
107
src/diffusion/flow_matching/adam_sampling.py
Normal file
107
src/diffusion/flow_matching/adam_sampling.py
Normal file
@@ -0,0 +1,107 @@
import math
import torch
from src.diffusion.base.sampling import *
from src.diffusion.base.scheduling import *
from src.diffusion.pre_integral import *

from typing import Callable, List, Tuple


def ode_step_fn(x, v, dt, s, w):
    return x + v * dt


def t2snr(t):
    if isinstance(t, torch.Tensor):
        return t.clip(min=1e-8)/(1-t + 1e-8)
    if isinstance(t, List) or isinstance(t, Tuple):
        return [t2snr(t) for t in t]
    t = max(t, 1e-8)
    return t/(1-t + 1e-8)


def t2logsnr(t):
    if isinstance(t, torch.Tensor):
        return torch.log(t.clip(min=1e-3)/(1-t + 1e-3))
    if isinstance(t, List) or isinstance(t, Tuple):
        return [t2logsnr(t) for t in t]
    t = max(t, 1e-3)
    return math.log(t/(1-t + 1e-3))


def t2isnr(t):
    return 1/t2snr(t)


def nop(t):
    return t


def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)


import logging
logger = logging.getLogger(__name__)


class AdamLMSampler(BaseSampler):
    def __init__(
            self,
            order: int = 2,
            timeshift: float = 1.0,
            lms_transform_fn: Callable = nop,
            w_scheduler: BaseScheduler = None,
            step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.w_scheduler = w_scheduler

        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        self.order = order
        self.lms_transform_fn = lms_transform_fn

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, timeshift)
        self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
        self._reparameterize_coeffs()

    def _reparameterize_coeffs(self):
        solver_coeffs = [[] for _ in range(self.num_steps)]
        for i in range(0, self.num_steps):
            pre_vs = [1.0, ]*(i+1)
            pre_ts = self.lms_transform_fn(self.timesteps[:i+1])
            int_t_start = self.lms_transform_fn(self.timesteps[i])
            int_t_end = self.lms_transform_fn(self.timesteps[i+1])

            order_annealing = self.order  # self.num_steps - i
            order = min(self.order, i + 1, order_annealing)

            _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end)
            solver_coeffs[i] = coeffs
        self.solver_coeffs = solver_coeffs

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        Sampling process of the Adams-style linear multistep sampler.
        """
        batch_size = noise.shape[0]
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = x0 = noise
        pred_trajectory = []
        t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype)
        timedeltas = self.timedeltas
        solver_coeffs = self.solver_coeffs
        for i in range(self.num_steps):
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            out = net(cfg_x, cfg_t, cfg_condition)
            out = self.guidance_fn(out, self.guidances[i])
            pred_trajectory.append(out)
            out = torch.zeros_like(out)
            order = len(self.solver_coeffs[i])
            for j in range(order):
                out += solver_coeffs[i][j] * pred_trajectory[-order:][j]
            v = out
            dt = timedeltas[i]
            x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0)
            x = self.step_fn(x, v, dt, s=0, w=0)
            t_cur += dt
        return x
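shift_respace_fn is the timestep warp applied to the uniform grid before the coefficients are precomputed: t is mapped to t / (t + (1 - t) * shift). A quick standalone look at what a shift of 3.0 does to a uniform grid; values are pulled toward t = 0, so more of the trajectory is spent near the noise end:

import torch

def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)

uniform = torch.linspace(0.0, 1.0, 6)
print(shift_respace_fn(uniform, shift=3.0))
# tensor([0.0000, 0.0769, 0.1818, 0.3333, 0.5714, 1.0000])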
179
src/diffusion/flow_matching/sampling.py
Normal file
179
src/diffusion/flow_matching/sampling.py
Normal file
@@ -0,0 +1,179 @@
import torch

from src.diffusion.base.guidance import *
from src.diffusion.base.scheduling import *
from src.diffusion.base.sampling import *

from typing import Callable


def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)

def ode_step_fn(x, v, dt, s, w):
    return x + v * dt

def sde_mean_step_fn(x, v, dt, s, w):
    return x + v * dt + s * w * dt

def sde_step_fn(x, v, dt, s, w):
    return x + v*dt + s*w*dt + torch.sqrt(2*w*dt)*torch.randn_like(x)

def sde_preserve_step_fn(x, v, dt, s, w):
    return x + v*dt + 0.5*s*w*dt + torch.sqrt(w*dt)*torch.randn_like(x)


import logging
logger = logging.getLogger(__name__)


class EulerSampler(BaseSampler):
    def __init__(
            self,
            w_scheduler: BaseScheduler = None,
            timeshift=1.0,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
        self.timeshift = timeshift

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, self.timeshift)

        assert self.last_step > 0.0
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None:
            if self.step_fn == ode_step_fn:
                logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        Sampling process of the Euler sampler.
        """
        batch_size = noise.shape[0]
        steps = self.timesteps.to(noise.device)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur)
            dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
            if self.w_scheduler:
                w = self.w_scheduler.w(t_cur)
            else:
                w = 0.0

            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            out = net(cfg_x, cfg_t, cfg_condition)
            out = self.guidance_fn(out, self.guidance)
            v = out
            s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma)
            if i < self.num_steps - 1:
                x = self.step_fn(x, v, dt, s=s, w=w)
            else:
                x = self.last_step_fn(x, v, dt, s=s, w=w)
        return x


class HeunSampler(BaseSampler):
    def __init__(
            self,
            scheduler: BaseScheduler = None,
            w_scheduler: BaseScheduler = None,
            exact_henu=False,
            timeshift=1.0,
            step_fn: Callable = ode_step_fn,
            last_step=None,
            last_step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.scheduler = scheduler
        self.exact_henu = exact_henu
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
        self.timeshift = timeshift

        if self.last_step is None or self.num_steps == 1:
            self.last_step = 1.0 / self.num_steps

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, self.timeshift)

        assert self.last_step > 0.0
        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None:
            if self.step_fn == ode_step_fn:
                logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        Sampling process of the Heun sampler.
        """
        batch_size = noise.shape[0]
        steps = self.timesteps.to(noise.device)
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = noise
        v_hat, s_hat = 0.0, 0.0
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            dt = t_next - t_cur
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            alpha_over_dalpha = 1/self.scheduler.dalpha_over_alpha(t_cur)
            dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
            t_hat = t_next
            t_hat = t_hat.repeat(batch_size)
            sigma_hat = self.scheduler.sigma(t_hat)
            alpha_over_dalpha_hat = 1 / self.scheduler.dalpha_over_alpha(t_hat)
            dsigma_mul_sigma_hat = self.scheduler.dsigma_mul_sigma(t_hat)

            if self.w_scheduler:
                w = self.w_scheduler.w(t_cur)
            else:
                w = 0.0
            if i == 0 or self.exact_henu:
                cfg_x = torch.cat([x, x], dim=0)
                cfg_t_cur = t_cur.repeat(2)
                out = net(cfg_x, cfg_t_cur, cfg_condition)
                out = self.guidance_fn(out, self.guidance)
                v = out
                s = (alpha_over_dalpha*v - x)/(sigma**2 - alpha_over_dalpha*dsigma_mul_sigma)
            else:
                v = v_hat
                s = s_hat
            x_hat = self.step_fn(x, v, dt, s=s, w=w)
            # Heun correction step
            if i < self.num_steps - 1:
                cfg_x_hat = torch.cat([x_hat, x_hat], dim=0)
                cfg_t_hat = t_hat.repeat(2)
                out = net(cfg_x_hat, cfg_t_hat, cfg_condition)
                out = self.guidance_fn(out, self.guidance)
                v_hat = out
                s_hat = (alpha_over_dalpha_hat*v_hat - x_hat) / (sigma_hat**2 - alpha_over_dalpha_hat*dsigma_mul_sigma_hat)
                v = (v + v_hat) / 2
                s = (s + s_hat) / 2
                x = self.step_fn(x, v, dt, s=s, w=w)
            else:
                x = self.last_step_fn(x, v, dt, s=s, w=w)
        return x
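The s = ... line converts the predicted velocity into a score estimate, s = ((alpha/dalpha)·v − x) / (sigma² − (alpha/dalpha)·dsigma·sigma). On the linear path (alpha = t, sigma = 1 − t) this can be checked against the closed form −noise/sigma when v is the exact velocity x0 − noise. A small sanity check under that assumption:

import torch

torch.manual_seed(0)
x0 = torch.randn(2, 3, 4, 4)
noise = torch.randn_like(x0)
t = torch.full((2, 1, 1, 1), 0.3)

alpha, dalpha = t, torch.ones_like(t)        # linear path: alpha = t
sigma, dsigma = 1 - t, -torch.ones_like(t)   # sigma = 1 - t
x_t = alpha * x0 + sigma * noise
v = x0 - noise                               # exact velocity on this path

alpha_over_dalpha = alpha / dalpha
s = (alpha_over_dalpha * v - x_t) / (sigma**2 - alpha_over_dalpha * dsigma * sigma)
print(torch.allclose(s, -noise / sigma, atol=1e-5))  # True: score of the Gaussian perturbation kernel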
39
src/diffusion/flow_matching/scheduling.py
Normal file
39
src/diffusion/flow_matching/scheduling.py
Normal file
@@ -0,0 +1,39 @@
import math
import torch
from src.diffusion.base.scheduling import *


class LinearScheduler(BaseScheduler):
    def alpha(self, t) -> Tensor:
        return t.view(-1, 1, 1, 1)

    def sigma(self, t) -> Tensor:
        return (1-t).view(-1, 1, 1, 1)

    def dalpha(self, t) -> Tensor:
        return torch.full_like(t, 1.0).view(-1, 1, 1, 1)

    def dsigma(self, t) -> Tensor:
        return torch.full_like(t, -1.0).view(-1, 1, 1, 1)


# SoTA for ImageNet!
class GVPScheduler(BaseScheduler):
    def alpha(self, t) -> Tensor:
        return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)

    def sigma(self, t) -> Tensor:
        return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)

    def dalpha(self, t) -> Tensor:
        return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)

    def dsigma(self, t) -> Tensor:
        return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)

    def w(self, t):
        return torch.sin(t)**2


class ConstScheduler(BaseScheduler):
    def w(self, t):
        return torch.ones(1, 1, 1, 1).to(t.device, t.dtype)


from src.diffusion.ddpm.scheduling import VPScheduler

class VPBetaScheduler(VPScheduler):
    def w(self, t):
        return self.beta(t).view(-1, 1, 1, 1)
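The two interpolation paths obey different identities: the linear path keeps alpha + sigma = 1, while the GVP (cosine) path keeps alpha² + sigma² = 1. A quick check of both, with the formulas re-declared locally so the snippet stands alone:

import math
import torch

t = torch.rand(5)
lin_alpha, lin_sigma = t, 1 - t
gvp_alpha = torch.cos(t * (math.pi / 2))
gvp_sigma = torch.sin(t * (math.pi / 2))

print(torch.allclose(lin_alpha + lin_sigma, torch.ones_like(t)))        # linear: alpha + sigma = 1
print(torch.allclose(gvp_alpha**2 + gvp_sigma**2, torch.ones_like(t)))  # GVP: alpha^2 + sigma^2 = 1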
55
src/diffusion/flow_matching/training.py
Normal file
55
src/diffusion/flow_matching/training.py
Normal file
@@ -0,0 +1,55 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1/sigma**2

def snr(alpha, sigma):
    return alpha/sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class FlowMatchingTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise
        out = net(x_t, t, y)

        weight = self.loss_weight_fn(alpha, sigma)

        loss = weight*(out - v_t)**2

        out = dict(
            loss=loss.mean(),
        )
        return out
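With the linear scheduler (alpha = t, sigma = 1 − t), the regression target v_t = dalpha * x + dsigma * noise collapses to x − noise; the trainer itself is agnostic to that choice and only sees the scheduler's derivatives. A minimal check of that reduction:

import torch

x = torch.randn(2, 4, 8, 8)
noise = torch.randn_like(x)
t = torch.rand(2).view(-1, 1, 1, 1)

dalpha = torch.ones_like(t)    # d/dt of alpha = t
dsigma = -torch.ones_like(t)   # d/dt of sigma = 1 - t
v_t = dalpha * x + dsigma * noise
print(torch.allclose(v_t, x - noise))  # True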
59
src/diffusion/flow_matching/training_cos.py
Normal file
59
src/diffusion/flow_matching/training_cos.py
Normal file
@@ -0,0 +1,59 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1/sigma**2

def snr(alpha, sigma):
    return alpha/sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class COSTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise
        out = net(x_t, t, y)

        weight = self.loss_weight_fn(alpha, sigma)

        fm_loss = weight*(out - v_t)**2
        cos_sim = torch.nn.functional.cosine_similarity(out, v_t, dim=1)
        cos_loss = 1 - cos_sim

        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + cos_loss.mean(),
        )
        return out
68
src/diffusion/flow_matching/training_pyramid.py
Normal file
68
src/diffusion/flow_matching/training_pyramid.py
Normal file
@@ -0,0 +1,68 @@
import torch
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1/sigma**2

def snr(alpha, sigma):
    return alpha/sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class PyramidTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        output_pyramid = []
        def feature_hook(module, input, output):
            output_pyramid.extend(output)
        handle = net.decoder.register_forward_hook(feature_hook)
        net(x_t, t, y)
        handle.remove()

        loss = 0.0
        out_dict = dict()

        cur_v_t = v_t
        for i in range(len(output_pyramid)):
            cur_out = output_pyramid[i]
            loss_i = (cur_v_t - cur_out) ** 2
            loss += loss_i.mean()
            out_dict["loss_{}".format(i)] = loss_i.mean()
            cur_v_t = torch.nn.functional.interpolate(cur_v_t, scale_factor=0.5, mode='bilinear', align_corners=False)
        out_dict["loss"] = loss
        return out_dict
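The pyramid loss compares each hooked decoder output against a progressively downsampled copy of the full-resolution velocity target. A standalone sketch of how such a target pyramid is built (the number of levels and shapes here are illustrative, not taken from the repo's decoder):

import torch
import torch.nn.functional as F

v_t = torch.randn(2, 4, 32, 32)   # full-resolution velocity target
num_levels = 3                     # hypothetical number of hooked outputs

targets = [v_t]
for _ in range(num_levels - 1):
    targets.append(F.interpolate(targets[-1], scale_factor=0.5,
                                 mode='bilinear', align_corners=False))
print([tuple(t.shape[-2:]) for t in targets])  # [(32, 32), (16, 16), (8, 8)]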
142
src/diffusion/flow_matching/training_repa.py
Normal file
142
src/diffusion/flow_matching/training_repa.py
Normal file
@@ -0,0 +1,142 @@
import torch
import copy
import timm
from torch.nn import Parameter

from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1/sigma**2

def snr(alpha, sigma):
    return alpha/sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load(
            '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
            weight_path,
            source="local",
            skip_validation=True
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature


class REPATrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            encoder_weight_path=None,
            align_layer=8,
            proj_denoiser_dim=256,
            proj_hidden_dim=256,
            proj_encoder_dim=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.encoder = DINOv2(encoder_weight_path)
        no_grad(self.encoder)

        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        src_feature = []
        def forward_hook(net, input, output):
            src_feature.append(output)
        handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)

        out = net(x_t, t, y)
        src_feature = self.proj(src_feature[0])
        handle.remove()

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = weight*(out - v_t)**2

        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
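The REPA alignment term hinges on grabbing the activations of one intermediate denoiser block with a forward hook and projecting them into the encoder's feature space. A sketch of the hook mechanics in isolation, on a throwaway stack of linear layers standing in for the real denoiser blocks:

import torch
import torch.nn as nn

blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])  # stand-in for denoiser blocks
align_layer = 2

captured = []
handle = blocks[align_layer - 1].register_forward_hook(
    lambda module, inputs, output: captured.append(output))

h = torch.randn(8, 16)
for block in blocks:
    h = block(h)
handle.remove()

print(captured[0].shape)  # features of block `align_layer`, ready for the projector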
152
src/diffusion/flow_matching/training_repa_mask.py
Normal file
152
src/diffusion/flow_matching/training_repa_mask.py
Normal file
@@ -0,0 +1,152 @@
import torch
import copy
import timm
from torch.nn import Parameter

from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1/sigma**2

def snr(alpha, sigma):
    return alpha/sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load(
            '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
            weight_path,
            source="local",
            skip_validation=True
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature


class REPATrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            mask_ratio=0.0,
            mask_patch_size=2,
            encoder_weight_path=None,
            align_layer=8,
            proj_denoiser_dim=256,
            proj_hidden_dim=256,
            proj_encoder_dim=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.mask_ratio = mask_ratio
        self.mask_patch_size = mask_patch_size
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.encoder = DINOv2(encoder_weight_path)
        no_grad(self.encoder)

        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        patch_mask = torch.rand((batch_size, 1, height//self.mask_patch_size, width//self.mask_patch_size), device=x.device)
        patch_mask = (patch_mask < self.mask_ratio).float()
        mask = torch.nn.functional.interpolate(patch_mask, size=(height, width), mode='nearest')
        masked_x = x*(1-mask)  # + torch.randn_like(x)*(mask)

        x_t = alpha*masked_x + sigma*noise
        v_t = dalpha*x + dsigma*noise

        src_feature = []
        def forward_hook(net, input, output):
            src_feature.append(output)
        handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)

        v_t_out, x0_out = net(x_t, t, y)
        src_feature = self.proj(src_feature[0])
        handle.remove()

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = (1-mask)*weight*(v_t_out - v_t)**2/(1-mask.mean())
        mask_loss = mask*weight*(x0_out - x)**2/(mask.mean())
        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            mask_loss=mask_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + mask_loss.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
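The masked variant drops whole patches before noising: a coarse Bernoulli grid is drawn at mask_patch_size granularity and upsampled with nearest-neighbour interpolation so every pixel of a patch shares the same mask value. The same construction in isolation, with illustrative sizes:

import torch
import torch.nn.functional as F

batch_size, height, width = 2, 32, 32
mask_patch_size, mask_ratio = 4, 0.3    # illustrative values

patch_mask = torch.rand(batch_size, 1, height // mask_patch_size, width // mask_patch_size)
patch_mask = (patch_mask < mask_ratio).float()
mask = F.interpolate(patch_mask, size=(height, width), mode='nearest')

x = torch.randn(batch_size, 4, height, width)
masked_x = x * (1 - mask)               # masked patches are zeroed before the noising step
print(mask.mean().item())               # roughly mask_ratio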
143
src/diffusion/pre_integral.py
Normal file
143
src/diffusion/pre_integral.py
Normal file
@@ -0,0 +1,143 @@
import torch


# Lagrange interpolation
def lagrange_preint_o1(t1, v1, int_t_start, int_t_end):
    '''
    Lagrange interpolation of order 1.
    Args:
        t1: timestep 1
        v1: value field at t1
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value and normalized coefficients
    '''
    int1 = (int_t_end-int_t_start)
    return int1*v1, (int1/int1, )


def lagrange_preint_o2(t1, t2, v1, v2, int_t_start, int_t_end):
    '''
    Lagrange interpolation of order 2.
    Args:
        t1: timestep 1
        t2: timestep 2
        v1: value field at t1
        v2: value field at t2
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value and normalized coefficients
    '''
    int1 = 0.5/(t1-t2)*((int_t_end-t2)**2 - (int_t_start-t2)**2)
    int2 = 0.5/(t2-t1)*((int_t_end-t1)**2 - (int_t_start-t1)**2)
    int_sum = int1+int2
    return int1*v1 + int2*v2, (int1/int_sum, int2/int_sum)


def lagrange_preint_o3(t1, t2, t3, v1, v2, v3, int_t_start, int_t_end):
    '''
    Lagrange interpolation of order 3.
    Args:
        t1: timestep 1
        t2: timestep 2
        t3: timestep 3
        v1: value field at t1
        v2: value field at t2
        v3: value field at t3
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value and normalized coefficients
    '''
    int1_denom = (t1-t2)*(t1-t3)
    int1_end = 1/3*(int_t_end)**3 - 1/2*(t2+t3)*(int_t_end)**2 + (t2*t3)*int_t_end
    int1_start = 1/3*(int_t_start)**3 - 1/2*(t2+t3)*(int_t_start)**2 + (t2*t3)*int_t_start
    int1 = (int1_end - int1_start)/int1_denom
    int2_denom = (t2-t1)*(t2-t3)
    int2_end = 1/3*(int_t_end)**3 - 1/2*(t1+t3)*(int_t_end)**2 + (t1*t3)*int_t_end
    int2_start = 1/3*(int_t_start)**3 - 1/2*(t1+t3)*(int_t_start)**2 + (t1*t3)*int_t_start
    int2 = (int2_end - int2_start)/int2_denom
    int3_denom = (t3-t1)*(t3-t2)
    int3_end = 1/3*(int_t_end)**3 - 1/2*(t1+t2)*(int_t_end)**2 + (t1*t2)*int_t_end
    int3_start = 1/3*(int_t_start)**3 - 1/2*(t1+t2)*(int_t_start)**2 + (t1*t2)*int_t_start
    int3 = (int3_end - int3_start)/int3_denom
    int_sum = int1+int2+int3
    return int1*v1 + int2*v2 + int3*v3, (int1/int_sum, int2/int_sum, int3/int_sum)


def larange_preint_o4(t1, t2, t3, t4, v1, v2, v3, v4, int_t_start, int_t_end):
    '''
    Lagrange interpolation of order 4.
    Args:
        t1: timestep 1
        t2: timestep 2
        t3: timestep 3
        t4: timestep 4
        v1: value field at t1
        v2: value field at t2
        v3: value field at t3
        v4: value field at t4
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value and normalized coefficients
    '''
    int1_denom = (t1-t2)*(t1-t3)*(t1-t4)
    int1_end = 1/4*(int_t_end)**4 - 1/3*(t2+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_end**2 - t2*t3*t4*int_t_end
    int1_start = 1/4*(int_t_start)**4 - 1/3*(t2+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_start**2 - t2*t3*t4*int_t_start
    int1 = (int1_end - int1_start)/int1_denom
    int2_denom = (t2-t1)*(t2-t3)*(t2-t4)
    int2_end = 1/4*(int_t_end)**4 - 1/3*(t1+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_end**2 - t1*t3*t4*int_t_end
    int2_start = 1/4*(int_t_start)**4 - 1/3*(t1+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_start**2 - t1*t3*t4*int_t_start
    int2 = (int2_end - int2_start)/int2_denom
    int3_denom = (t3-t1)*(t3-t2)*(t3-t4)
    int3_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t4)*(int_t_end)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_end**2 - t1*t2*t4*int_t_end
    int3_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t4)*(int_t_start)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_start**2 - t1*t2*t4*int_t_start
    int3 = (int3_end - int3_start)/int3_denom
    int4_denom = (t4-t1)*(t4-t2)*(t4-t3)
    int4_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t3)*(int_t_end)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_end**2 - t1*t2*t3*int_t_end
    int4_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t3)*(int_t_start)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_start**2 - t1*t2*t3*int_t_start
    int4 = (int4_end - int4_start)/int4_denom
    int_sum = int1+int2+int3+int4
    return int1*v1 + int2*v2 + int3*v3 + int4*v4, (int1/int_sum, int2/int_sum, int3/int_sum, int4/int_sum)


def lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end):
    '''
    Lagrange interpolation.
    Args:
        order: order of interpolation
        pre_vs: value field at pre_ts
        pre_ts: timesteps
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value and normalized coefficients
    '''
    order = min(order, len(pre_vs), len(pre_ts))
    if order == 1:
        return lagrange_preint_o1(pre_ts[-1], pre_vs[-1], int_t_start, int_t_end)
    elif order == 2:
        return lagrange_preint_o2(pre_ts[-2], pre_ts[-1], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
    elif order == 3:
        return lagrange_preint_o3(pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
    elif order == 4:
        return larange_preint_o4(pre_ts[-4], pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-4], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
    else:
        raise ValueError('Invalid order')


def polynomial_integral(coeffs, int_t_start, int_t_end):
    '''
    Polynomial integral.
    Args:
        coeffs: coefficients of the polynomial
        int_t_start: integration start time
        int_t_end: integration end time
    Returns:
        integrated value
    '''
    orders = len(coeffs)
    int_val = 0
    for o in range(orders):
        int_val += coeffs[o]/(o+1)*(int_t_end**(o+1)-int_t_start**(o+1))
    return int_val
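lagrange_preint returns both the integrated value and the normalized per-sample coefficients that the Adams-style samplers reuse as linear-multistep weights. For a velocity that is exactly linear in t the order-2 rule is exact, which makes a convenient sanity check (assuming the repository root is on PYTHONPATH so the module committed above is importable):

from src.diffusion.pre_integral import lagrange_preint

# Velocity v(t) = 2t sampled at two past timesteps.
pre_ts = [0.0, 0.5]
pre_vs = [0.0, 1.0]

integral, coeffs = lagrange_preint(2, pre_vs, pre_ts, int_t_start=0.5, int_t_end=1.0)
print(integral)  # 0.75, the exact value of the integral of 2t from 0.5 to 1.0
print(coeffs)    # (-0.5, 1.5): normalized weights applied to the stored predictions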
112
src/diffusion/stateful_flow_matching/adam_sampling.py
Normal file
112
src/diffusion/stateful_flow_matching/adam_sampling.py
Normal file
@@ -0,0 +1,112 @@
import math
import torch
from src.diffusion.base.sampling import *
from src.diffusion.base.scheduling import *
from src.diffusion.pre_integral import *

from typing import Callable, List, Tuple


def ode_step_fn(x, v, dt, s, w):
    return x + v * dt


def t2snr(t):
    if isinstance(t, torch.Tensor):
        return t.clip(min=1e-8)/(1-t + 1e-8)
    if isinstance(t, List) or isinstance(t, Tuple):
        return [t2snr(t) for t in t]
    t = max(t, 1e-8)
    return t/(1-t + 1e-8)


def t2logsnr(t):
    if isinstance(t, torch.Tensor):
        return torch.log(t.clip(min=1e-3)/(1-t + 1e-3))
    if isinstance(t, List) or isinstance(t, Tuple):
        return [t2logsnr(t) for t in t]
    t = max(t, 1e-3)
    return math.log(t/(1-t + 1e-3))


def t2isnr(t):
    return 1/t2snr(t)


def nop(t):
    return t


def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)


import logging
logger = logging.getLogger(__name__)


class AdamLMSampler(BaseSampler):
    def __init__(
            self,
            order: int = 2,
            timeshift: float = 1.0,
            state_refresh_rate: int = 1,
            lms_transform_fn: Callable = nop,
            w_scheduler: BaseScheduler = None,
            step_fn: Callable = ode_step_fn,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.step_fn = step_fn
        self.w_scheduler = w_scheduler
        self.state_refresh_rate = state_refresh_rate

        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        self.order = order
        self.lms_transform_fn = lms_transform_fn

        timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
        timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
        self.timesteps = shift_respace_fn(timesteps, timeshift)
        self.timedeltas = self.timesteps[1:] - self.timesteps[:-1]
        self._reparameterize_coeffs()

    def _reparameterize_coeffs(self):
        solver_coeffs = [[] for _ in range(self.num_steps)]
        for i in range(0, self.num_steps):
            pre_vs = [1.0, ]*(i+1)
            pre_ts = self.lms_transform_fn(self.timesteps[:i+1])
            int_t_start = self.lms_transform_fn(self.timesteps[i])
            int_t_end = self.lms_transform_fn(self.timesteps[i+1])

            order_annealing = self.order  # self.num_steps - i
            order = min(self.order, i + 1, order_annealing)

            _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end)
            solver_coeffs[i] = coeffs
        self.solver_coeffs = solver_coeffs

    def _impl_sampling(self, net, noise, condition, uncondition):
        """
        Sampling process of the stateful Adams-style linear multistep sampler.
        """
        batch_size = noise.shape[0]
        cfg_condition = torch.cat([uncondition, condition], dim=0)
        x = x0 = noise
        state = None
        pred_trajectory = []
        t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype)
        timedeltas = self.timedeltas
        solver_coeffs = self.solver_coeffs
        for i in range(self.num_steps):
            cfg_x = torch.cat([x, x], dim=0)
            cfg_t = t_cur.repeat(2)
            if i % self.state_refresh_rate == 0:
                state = None
            out, state = net(cfg_x, cfg_t, cfg_condition, state)
            out = self.guidance_fn(out, self.guidances[i])
            pred_trajectory.append(out)
            out = torch.zeros_like(out)
            order = len(self.solver_coeffs[i])
            for j in range(order):
                out += solver_coeffs[i][j] * pred_trajectory[-order:][j]
            v = out
            dt = timedeltas[i]
            x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0)
            x = self.step_fn(x, v, dt, s=0, w=0)
            t_cur += dt
        return x
122
src/diffusion/stateful_flow_matching/bak/training_adv.py
Normal file
122
src/diffusion/stateful_flow_matching/bak/training_adv.py
Normal file
@@ -0,0 +1,122 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips


def inverse_sigma(alpha, sigma):
    return 1/sigma**2

def snr(alpha, sigma):
    return alpha/sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class Discriminator(nn.Module):
    def __init__(self, in_channels, hidden_size):
        super().__init__()
        self.head = nn.Sequential(
            nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1),  # 16x16 -> 8x8
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 8x8 -> 4x4
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 4x4 -> 2x2
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0),  # 1x1 -> 1x1
        )

    def forward(self, feature):
        B, L, C = feature.shape
        H = W = int(math.sqrt(L))
        feature = feature.permute(0, 2, 1)
        feature = feature.view(B, C, H, W)
        out = self.head(feature).sigmoid().clamp(0.01, 0.99)
        return out


class AdvTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            adv_weight=1.0,
            adv_encoder_layer=4,
            adv_in_channels=768,
            adv_hidden_size=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.adv_weight = adv_weight
        self.adv_encoder_layer = adv_encoder_layer

        self.dis_head = Discriminator(
            in_channels=adv_in_channels,
            hidden_size=adv_hidden_size,
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        adv_feature = []
        def forward_hook(net, input, output):
            adv_feature.append(output)
        handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook)

        out, _ = net(x_t, t, y)
        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight*(out - v_t)**2

        pred_x0 = (x_t + out * sigma)
        pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma
        real_feature = adv_feature.pop()
        net(pred_xt, t, y, classify_layer=self.adv_encoder_layer)
        fake_feature = adv_feature.pop()
        handle.remove()

        real_score_gan = self.dis_head(real_feature.detach())
        fake_score_gan = self.dis_head(fake_feature.detach())
        fake_score = self.dis_head(fake_feature)

        loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
        acc_real = (real_score_gan > 0.5).float()
        acc_fake = (fake_score_gan < 0.5).float()
        loss_adv = -torch.log(fake_score)
        loss_adv_hack = torch.log(fake_score_gan)

        out = dict(
            adv_loss=loss_adv.mean(),
            gan_loss=loss_gan.mean(),
            fm_loss=loss.mean(),
            loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
            acc_real=acc_real.mean(),
            acc_fake=acc_fake.mean(),
        )
        return out
127
src/diffusion/stateful_flow_matching/bak/training_adv_x0.py
Normal file
127
src/diffusion/stateful_flow_matching/bak/training_adv_x0.py
Normal file
@@ -0,0 +1,127 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips


def inverse_sigma(alpha, sigma):
    return 1/sigma**2

def snr(alpha, sigma):
    return alpha/sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class Discriminator(nn.Module):
    def __init__(self, in_channels, hidden_size):
        super().__init__()
        self.head = nn.Sequential(
            nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1),  # 16x16 -> 8x8
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 8x8 -> 4x4
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 4x4 -> 2x2
            nn.GroupNorm(num_groups=32, num_channels=hidden_size),
            nn.SiLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0),  # 1x1 -> 1x1
        )

    def forward(self, feature):
        B, L, C = feature.shape
        H = W = int(math.sqrt(L))
        feature = feature.permute(0, 2, 1)
        feature = feature.view(B, C, H, W)
        out = self.head(feature).sigmoid().clamp(0.01, 0.99)
        return out


class AdvTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            adv_weight=1.0,
            lpips_weight=1.0,
            adv_encoder_layer=4,
            adv_in_channels=768,
            adv_hidden_size=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.adv_weight = adv_weight
        self.lpips_weight = lpips_weight
        self.adv_encoder_layer = adv_encoder_layer

        self.dis_head = Discriminator(
            in_channels=adv_in_channels,
            hidden_size=adv_hidden_size,
        )

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size = x.shape[0]
        if self.lognorm_t:
            t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
        else:
            t = torch.rand(batch_size).to(x.device, x.dtype)
        clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype)

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)
        w = self.scheduler.w(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        out, _ = net(x_t, t, y)
        pred_x0 = (x_t + out * sigma)

        weight = self.loss_weight_fn(alpha, sigma)
        loss = weight*(out - v_t)**2
        with torch.no_grad():
            _, real_features = net(x, clean_t, y, classify_layer=self.adv_encoder_layer)
        _, fake_features = net(pred_x0, clean_t, y, classify_layer=self.adv_encoder_layer)

        real_score_gan = self.dis_head(real_features[-1].detach())
        fake_score_gan = self.dis_head(fake_features[-1].detach())
        fake_score = self.dis_head(fake_features[-1])

        loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
        acc_real = (real_score_gan > 0.5).float()
        acc_fake = (fake_score_gan < 0.5).float()
        loss_adv = -torch.log(fake_score)
        loss_adv_hack = torch.log(fake_score_gan)

        lpips_loss = []
        for r, f in zip(real_features, fake_features):
            r = torch.nn.functional.normalize(r, dim=-1)
            f = torch.nn.functional.normalize(f, dim=-1)
            lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
        lpips_loss = sum(lpips_loss)

        out = dict(
            adv_loss=loss_adv.mean(),
            gan_loss=loss_gan.mean(),
            lpips_loss=lpips_loss.mean(),
            fm_loss=loss.mean(),
            loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean() + self.lpips_weight*lpips_loss.mean(),
            acc_real=acc_real.mean(),
            acc_fake=acc_fake.mean(),
        )
        return out
159
src/diffusion/stateful_flow_matching/bak/training_mask_repa.py
Normal file
159
src/diffusion/stateful_flow_matching/bak/training_mask_repa.py
Normal file
@@ -0,0 +1,159 @@
import random

import torch
import copy
import timm
from torch.nn import Parameter

from src.utils.no_grad import no_grad
from typing import Callable, Iterator, Tuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler


def inverse_sigma(alpha, sigma):
    return 1/sigma**2

def snr(alpha, sigma):
    return alpha/sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load(
            '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
            weight_path,
            source="local",
            skip_validation=True
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        return feature


class MaskREPATrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            feat_loss_weight: float = 0.5,
            lognorm_t=False,
            encoder_weight_path=None,
            mask_groups=4,
            align_layer=8,
            proj_denoiser_dim=256,
            proj_hidden_dim=256,
            proj_encoder_dim=256,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.lognorm_t = lognorm_t
        self.scheduler = scheduler
        self.loss_weight_fn = loss_weight_fn
        self.feat_loss_weight = feat_loss_weight
        self.align_layer = align_layer
        self.mask_groups = mask_groups
        self.encoder = DINOv2(encoder_weight_path)
        no_grad(self.encoder)

        self.proj = nn.Sequential(
            nn.Sequential(
                nn.Linear(proj_denoiser_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_hidden_dim),
                nn.SiLU(),
                nn.Linear(proj_hidden_dim, proj_encoder_dim),
            )
        )

    def fetch_mask(self, length=256, groups=4, device=torch.device('cuda')):
        mask = torch.zeros(1, length, length, device=device, dtype=torch.bool)
        random_seq = torch.randperm(length, device=device)
        for i in range(groups):
            group_start = (length+groups-1)//groups*i
            group_end = (length+groups-1)//groups*(i+1)
            group_random_seq = random_seq[group_start:group_end]
            y, x = torch.meshgrid(group_random_seq, group_random_seq)
            mask[:, y, x] = True
        return mask

    def _impl_trainstep(self, net, ema_net, raw_images, x, y):
        batch_size, c, height, width = x.shape
        if self.lognorm_t:
            base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
        else:
            base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
        t = base_t

        noise = torch.randn_like(x)
        alpha = self.scheduler.alpha(t)
        dalpha = self.scheduler.dalpha(t)
        sigma = self.scheduler.sigma(t)
        dsigma = self.scheduler.dsigma(t)

        x_t = alpha * x + noise * sigma
        v_t = dalpha * x + dsigma * noise

        src_feature = []
        def forward_hook(net, input, output):
            src_feature.append(output)
        handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
        mask_groups = random.randint(1, self.mask_groups)
        mask = self.fetch_mask(length=256, groups=mask_groups, device=x.device)
        out, _ = net(x_t, t, y, mask=mask)
        src_feature = self.proj(src_feature[0])
        handle.remove()

        with torch.no_grad():
            dst_feature = self.encoder(raw_images)

        cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
        cos_loss = 1 - cos_sim

        weight = self.loss_weight_fn(alpha, sigma)
        fm_loss = weight*(out - v_t)**2

        out = dict(
            fm_loss=fm_loss.mean(),
            cos_loss=cos_loss.mean(),
            loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
        )
        return out

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        return self.proj.state_dict(
            destination=destination,
            prefix=prefix + "proj.",
            keep_vars=keep_vars)
179
src/diffusion/stateful_flow_matching/bak/training_patch_adv.py
Normal file
179
src/diffusion/stateful_flow_matching/bak/training_patch_adv.py
Normal file
@@ -0,0 +1,179 @@
import torch
import math
from typing import Callable
from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.utils.no_grad import no_grad
from torchmetrics.image.lpip import _NoTrainLpips


def inverse_sigma(alpha, sigma):
    return 1/sigma**2

def snr(alpha, sigma):
    return alpha/sigma

def minsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, min=threshold)

def maxsnr(alpha, sigma, threshold=5):
    return torch.clip(alpha/sigma, max=threshold)

def constant(alpha, sigma):
    return 1


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, mul=1000):
        t_freq = self.timestep_embedding(t * mul, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class BatchNormWithTimeEmbedding(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        # self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.bn = nn.GroupNorm(16, num_features, affine=False)
        # self.bn = nn.SyncBatchNorm(num_features, affine=False)
        self.embedder = TimestepEmbedder(num_features * 2)
        # nn.init.zeros_(self.embedder.mlp[-1].weight)
        nn.init.trunc_normal_(self.embedder.mlp[-1].weight, std=0.01)
        nn.init.zeros_(self.embedder.mlp[-1].bias)

    def forward(self, x, t):
        embed = self.embedder(t)
        embed = embed[:, :, None, None]
        gamma, beta = embed.chunk(2, dim=1)
        gamma = 1.0 + gamma
        normed = self.bn(x)
        out = normed * gamma + beta
        return out


class DisBlock(nn.Module):
    def __init__(self, in_channels, hidden_size):
        super().__init__()
        self.conv = nn.Conv2d(
            kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=4, padding=0
        )
        self.norm = BatchNormWithTimeEmbedding(hidden_size)
        self.act = nn.SiLU()

    def forward(self, x, t):
        x = self.conv(x)
        x = self.norm(x, t)
        x = self.act(x)
        return x


class Discriminator(nn.Module):
    def __init__(self, num_blocks, in_channels, hidden_size):
        super().__init__()
        self.blocks = nn.ModuleList()
        for i in range(num_blocks):
            self.blocks.append(
                DisBlock(
                    in_channels=in_channels,
                    hidden_size=hidden_size,
                )
            )
            in_channels = hidden_size
        self.classifier = nn.Conv2d(
            kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=1
        )

    def forward(self, feature, t):
        B, C, H, W = feature.shape
        for block in self.blocks:
            feature = block(feature, t)
        out = self.classifier(feature).view(B, -1)
        out = out.sigmoid().clamp(0.01, 0.99)
        return out


class AdvTrainer(BaseTrainer):
    def __init__(
            self,
            scheduler: BaseScheduler,
            loss_weight_fn: Callable = constant,
            lognorm_t=False,
            adv_weight=1.0,
            adv_blocks=3,
            adv_in_channels=3,
|
||||
adv_hidden_size=256,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lognorm_t = lognorm_t
|
||||
self.scheduler = scheduler
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
self.adv_weight = adv_weight
|
||||
|
||||
self.discriminator = Discriminator(
|
||||
num_blocks=adv_blocks,
|
||||
in_channels=adv_in_channels*2,
|
||||
hidden_size=adv_hidden_size,
|
||||
)
|
||||
|
||||
|
||||
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||
batch_size = x.shape[0]
|
||||
if self.lognorm_t:
|
||||
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
|
||||
else:
|
||||
t = torch.rand(batch_size).to(x.device, x.dtype)
|
||||
noise = torch.randn_like(x)
|
||||
alpha = self.scheduler.alpha(t)
|
||||
dalpha = self.scheduler.dalpha(t)
|
||||
sigma = self.scheduler.sigma(t)
|
||||
dsigma = self.scheduler.dsigma(t)
|
||||
w = self.scheduler.w(t)
|
||||
|
||||
x_t = alpha * x + noise * sigma
|
||||
v_t = dalpha * x + dsigma * noise
|
||||
|
||||
out, _ = net(x_t, t, y)
|
||||
pred_x0 = x_t + sigma * out
|
||||
weight = self.loss_weight_fn(alpha, sigma)
|
||||
loss = weight*(out - v_t)**2
|
||||
|
||||
real_feature = torch.cat([x_t, x], dim=1)
|
||||
fake_feature = torch.cat([x_t, pred_x0], dim=1)
|
||||
|
||||
real_score_gan = self.discriminator(real_feature.detach(), t)
|
||||
fake_score_gan = self.discriminator(fake_feature.detach(), t)
|
||||
fake_score = self.discriminator(fake_feature, t)
|
||||
|
||||
loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
|
||||
acc_real = (real_score_gan > 0.5).float()
|
||||
acc_fake = (fake_score_gan < 0.5).float()
|
||||
loss_adv = -torch.log(fake_score)
|
||||
loss_adv_hack = torch.log(fake_score_gan)
|
||||
|
||||
out = dict(
|
||||
adv_loss=loss_adv.mean(),
|
||||
gan_loss=loss_gan.mean(),
|
||||
fm_loss=loss.mean(),
|
||||
loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
|
||||
acc_real=acc_real.mean(),
|
||||
acc_fake=acc_fake.mean(),
|
||||
)
|
||||
return out
|
||||
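As a quick sanity check on the adversarial terms used by AdvTrainer above (a hedged numeric sketch, not code from this repository): with discriminator outputs clamped to (0.01, 0.99), loss_gan trains the discriminator to score the real pair high and the fake pair low, while loss_adv is the non-saturating generator term that pushes the fake score up; loss_adv_hack is evaluated on the detached fake score, so it only modifies the discriminator's gradients, not the generator's.

import torch

real_score = torch.tensor([0.9])   # D on the detached (x_t, x) pair
fake_score = torch.tensor([0.2])   # D on the (x_t, pred_x0) pair

loss_gan = -torch.log(1 - fake_score) - torch.log(real_score)  # discriminator objective
loss_adv = -torch.log(fake_score)                              # non-saturating generator objective
print(loss_gan.item(), loss_adv.item())  # loss_adv grows as D becomes confident the sample is fake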
154
src/diffusion/stateful_flow_matching/bak/training_repa_jit.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import torch
|
||||
import copy
|
||||
import timm
|
||||
from torch.nn import Parameter
|
||||
|
||||
from src.utils.no_grad import no_grad
|
||||
from typing import Callable, Iterator, Tuple
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from torchvision.transforms import Normalize
|
||||
from src.diffusion.base.training import *
|
||||
from src.diffusion.base.scheduling import BaseScheduler
|
||||
|
||||
def inverse_sigma(alpha, sigma):
|
||||
return 1/sigma**2
|
||||
def snr(alpha, sigma):
|
||||
return alpha/sigma
|
||||
def minsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, min=threshold)
|
||||
def maxsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, max=threshold)
|
||||
def constant(alpha, sigma):
|
||||
return 1
|
||||
|
||||
|
||||
class DINOv2(nn.Module):
|
||||
def __init__(self, weight_path:str):
|
||||
super(DINOv2, self).__init__()
|
||||
self.encoder = torch.hub.load(
|
||||
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
|
||||
weight_path,
|
||||
source="local",
|
||||
skip_validation=True
|
||||
)
|
||||
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
|
||||
self.encoder.head = torch.nn.Identity()
|
||||
self.patch_size = self.encoder.patch_embed.patch_size
|
||||
self.precomputed_pos_embed = dict()
|
||||
|
||||
def fetch_pos(self, h, w):
|
||||
key = (h, w)
|
||||
if key in self.precomputed_pos_embed:
|
||||
return self.precomputed_pos_embed[key]
|
||||
value = timm.layers.pos_embed.resample_abs_pos_embed(
|
||||
self.pos_embed.data, [h, w],
|
||||
)
|
||||
self.precomputed_pos_embed[key] = value
|
||||
return value
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
||||
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
|
||||
b, c, h, w = x.shape
|
||||
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
|
||||
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
|
||||
self.encoder.pos_embed.data = pos_embed_data
|
||||
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
|
||||
return feature
|
||||
|
||||
|
||||
class REPAJiTTrainer(BaseTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: BaseScheduler,
|
||||
loss_weight_fn:Callable=constant,
|
||||
feat_loss_weight: float=0.5,
|
||||
lognorm_t=False,
|
||||
jit_deltas=0.01,
|
||||
encoder_weight_path=None,
|
||||
align_layer=8,
|
||||
proj_denoiser_dim=256,
|
||||
proj_hidden_dim=256,
|
||||
proj_encoder_dim=256,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lognorm_t = lognorm_t
|
||||
self.scheduler = scheduler
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
self.feat_loss_weight = feat_loss_weight
|
||||
self.align_layer = align_layer
|
||||
self.jit_deltas = jit_deltas
|
||||
self.encoder = DINOv2(encoder_weight_path)
|
||||
no_grad(self.encoder)
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
nn.Sequential(
|
||||
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(proj_hidden_dim, proj_hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(proj_hidden_dim, proj_encoder_dim),
|
||||
)
|
||||
)
|
||||
|
||||
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||
batch_size, c, height, width = x.shape
|
||||
if self.lognorm_t:
|
||||
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
|
||||
else:
|
||||
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
|
||||
t = base_t
|
||||
|
||||
noise = torch.randn_like(x)
|
||||
|
||||
alpha = self.scheduler.alpha(t)
|
||||
dalpha = self.scheduler.dalpha(t)
|
||||
sigma = self.scheduler.sigma(t)
|
||||
dsigma = self.scheduler.dsigma(t)
|
||||
x_t = alpha * x + noise * sigma
|
||||
v_t = dalpha * x + dsigma * noise
|
||||
|
||||
t2 = base_t + (torch.rand_like(base_t)-0.5) * self.jit_deltas
|
||||
t2 = torch.clip(t2, 0, 1)
|
||||
alpha = self.scheduler.alpha(t2)
|
||||
dalpha = self.scheduler.dalpha(t2)
|
||||
sigma = self.scheduler.sigma(t2)
|
||||
dsigma = self.scheduler.dsigma(t2)
|
||||
x_t2 = alpha * x + noise * sigma
|
||||
v_t2 = dalpha * x + dsigma * noise
|
||||
|
||||
src_feature = []
|
||||
def forward_hook(net, input, output):
|
||||
src_feature.append(output)
|
||||
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
|
||||
|
||||
_, s = net(x_t, t, y, only_s=True)
|
||||
out, _ = net(x_t2, t2, y, s)
|
||||
src_feature = self.proj(src_feature[0])
|
||||
handle.remove()
|
||||
|
||||
with torch.no_grad():
|
||||
dst_feature = self.encoder(raw_images)
|
||||
|
||||
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
|
||||
cos_loss = 1 - cos_sim
|
||||
|
||||
weight = self.loss_weight_fn(alpha, sigma)
|
||||
fm_loss = weight*(out - v_t2)**2
|
||||
|
||||
out = dict(
|
||||
fm_loss=fm_loss.mean(),
|
||||
cos_loss=cos_loss.mean(),
|
||||
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
|
||||
)
|
||||
return out
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
|
||||
return self.proj.state_dict(
|
||||
destination=destination,
|
||||
prefix=prefix + "proj.",
|
||||
keep_vars=keep_vars)
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
import torch
|
||||
import math
|
||||
from typing import Callable
|
||||
from src.diffusion.base.training import *
|
||||
from src.diffusion.base.scheduling import BaseScheduler
|
||||
from src.utils.no_grad import no_grad
|
||||
from torchmetrics.image.lpip import _NoTrainLpips
|
||||
|
||||
def inverse_sigma(alpha, sigma):
|
||||
return 1/sigma**2
|
||||
def snr(alpha, sigma):
|
||||
return alpha/sigma
|
||||
def minsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, min=threshold)
|
||||
def maxsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, max=threshold)
|
||||
def constant(alpha, sigma):
|
||||
return 1
|
||||
|
||||
|
||||
class SelfConsistentTrainer(BaseTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: BaseScheduler,
|
||||
loss_weight_fn:Callable=constant,
|
||||
lognorm_t=False,
|
||||
lpips_weight=1.0,
|
||||
lpips_encoder_layer=4,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lognorm_t = lognorm_t
|
||||
self.scheduler = scheduler
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
self.lpips_encoder_layer = lpips_encoder_layer
|
||||
self.lpips_weight = lpips_weight
|
||||
|
||||
|
||||
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||
batch_size = x.shape[0]
|
||||
if self.lognorm_t:
|
||||
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
|
||||
else:
|
||||
t = torch.rand(batch_size).to(x.device, x.dtype)
|
||||
|
||||
noise = torch.randn_like(x)
|
||||
alpha = self.scheduler.alpha(t)
|
||||
dalpha = self.scheduler.dalpha(t)
|
||||
sigma = self.scheduler.sigma(t)
|
||||
dsigma = self.scheduler.dsigma(t)
|
||||
w = self.scheduler.w(t)
|
||||
|
||||
x_t = alpha * x + noise * sigma
|
||||
v_t = dalpha * x + dsigma * noise
|
||||
|
||||
real_features = []
|
||||
def forward_hook(net, input, output):
|
||||
real_features.append(output)
|
||||
handles = []
|
||||
for i in range(self.lpips_encoder_layer):
|
||||
handle = net.encoder.blocks[i].register_forward_hook(forward_hook)
|
||||
handles.append(handle)
|
||||
|
||||
out, _ = net(x_t, t, y)
|
||||
|
||||
for handle in handles:
|
||||
handle.remove()
|
||||
|
||||
pred_x0 = (x_t + out * sigma)
|
||||
pred_xt = alpha * pred_x0 + noise * sigma
|
||||
weight = self.loss_weight_fn(alpha, sigma)
|
||||
loss = weight*(out - v_t)**2
|
||||
|
||||
_, fake_features = net(pred_xt, t, y, classify_layer=self.lpips_encoder_layer)
|
||||
|
||||
lpips_loss = []
|
||||
for r, f in zip(real_features, fake_features):
|
||||
r = torch.nn.functional.normalize(r, dim=-1)
|
||||
f = torch.nn.functional.normalize(f, dim=-1)
|
||||
lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
|
||||
lpips_loss = sum(lpips_loss)
|
||||
|
||||
|
||||
out = dict(
|
||||
lpips_loss=lpips_loss.mean(),
|
||||
fm_loss=loss.mean(),
|
||||
loss=loss.mean() + self.lpips_weight*lpips_loss.mean(),
|
||||
)
|
||||
return out
|
||||
@@ -0,0 +1,81 @@
|
||||
import torch
|
||||
import math
|
||||
from typing import Callable
|
||||
from src.diffusion.base.training import *
|
||||
from src.diffusion.base.scheduling import BaseScheduler
|
||||
from src.utils.no_grad import no_grad
|
||||
from torchmetrics.image.lpip import _NoTrainLpips
|
||||
|
||||
def inverse_sigma(alpha, sigma):
|
||||
return 1/sigma**2
|
||||
def snr(alpha, sigma):
|
||||
return alpha/sigma
|
||||
def minsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, min=threshold)
|
||||
def maxsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, max=threshold)
|
||||
def constant(alpha, sigma):
|
||||
return 1
|
||||
|
||||
|
||||
class SelfLPIPSTrainer(BaseTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: BaseScheduler,
|
||||
loss_weight_fn:Callable=constant,
|
||||
lognorm_t=False,
|
||||
lpips_weight=1.0,
|
||||
lpips_encoder_layer=4,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lognorm_t = lognorm_t
|
||||
self.scheduler = scheduler
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
self.lpips_encoder_layer = lpips_encoder_layer
|
||||
self.lpips_weight = lpips_weight
|
||||
|
||||
|
||||
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||
batch_size = x.shape[0]
|
||||
if self.lognorm_t:
|
||||
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
|
||||
else:
|
||||
t = torch.rand(batch_size).to(x.device, x.dtype)
|
||||
clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype)
|
||||
|
||||
noise = torch.randn_like(x)
|
||||
alpha = self.scheduler.alpha(t)
|
||||
dalpha = self.scheduler.dalpha(t)
|
||||
sigma = self.scheduler.sigma(t)
|
||||
dsigma = self.scheduler.dsigma(t)
|
||||
w = self.scheduler.w(t)
|
||||
|
||||
x_t = alpha * x + noise * sigma
|
||||
v_t = dalpha * x + dsigma * noise
|
||||
|
||||
out, _ = net(x_t, t, y)
|
||||
pred_x0 = (x_t + out * sigma)
|
||||
pred_xt = alpha * pred_x0 + noise * sigma
|
||||
weight = self.loss_weight_fn(alpha, sigma)
|
||||
loss = weight*(out - v_t)**2
|
||||
with torch.no_grad():
|
||||
_, real_features = net(x, clean_t, y, classify_layer=self.lpips_encoder_layer)
|
||||
_, fake_features = net(pred_x0, clean_t, y, classify_layer=self.lpips_encoder_layer)
|
||||
|
||||
|
||||
lpips_loss = []
|
||||
for r, f in zip(real_features, fake_features):
|
||||
r = torch.nn.functional.normalize(r, dim=-1)
|
||||
f = torch.nn.functional.normalize(f, dim=-1)
|
||||
lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
|
||||
lpips_loss = sum(lpips_loss)
|
||||
|
||||
|
||||
out = dict(
|
||||
lpips_loss=lpips_loss.mean(),
|
||||
fm_loss=loss.mean(),
|
||||
loss=loss.mean() + self.lpips_weight*lpips_loss.mean(),
|
||||
)
|
||||
return out
|
||||
78
src/diffusion/stateful_flow_matching/cm_sampling.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import torch
|
||||
|
||||
from src.diffusion.base.guidance import *
|
||||
from src.diffusion.base.scheduling import *
|
||||
from src.diffusion.base.sampling import *
|
||||
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def shift_respace_fn(t, shift=3.0):
|
||||
return t / (t + (1 - t) * shift)
|
||||
|
||||
def ode_step_fn(x, v, dt, s, w):
|
||||
return x + v * dt
|
||||
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CMSampler(BaseSampler):
|
||||
def __init__(
|
||||
self,
|
||||
w_scheduler: BaseScheduler = None,
|
||||
timeshift=1.0,
|
||||
guidance_interval_min: float = 0.0,
|
||||
guidance_interval_max: float = 1.0,
|
||||
state_refresh_rate=1,
|
||||
last_step=None,
|
||||
step_fn=None,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.last_step = last_step
|
||||
self.timeshift = timeshift
|
||||
self.state_refresh_rate = state_refresh_rate
|
||||
self.guidance_interval_min = guidance_interval_min
|
||||
self.guidance_interval_max = guidance_interval_max
|
||||
|
||||
if self.last_step is None or self.num_steps == 1:
|
||||
self.last_step = 1.0 / self.num_steps
|
||||
|
||||
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
||||
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
||||
self.timesteps = shift_respace_fn(timesteps, self.timeshift)
|
||||
|
||||
assert self.last_step > 0.0
|
||||
assert self.scheduler is not None
|
||||
|
||||
|
||||
def _impl_sampling(self, net, noise, condition, uncondition):
|
||||
"""
|
||||
sampling process of Euler sampler
|
||||
-
|
||||
"""
|
||||
batch_size = noise.shape[0]
|
||||
steps = self.timesteps.to(noise.device)
|
||||
cfg_condition = torch.cat([uncondition, condition], dim=0)
|
||||
x = noise
|
||||
state = None
|
||||
for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
|
||||
cfg_t = t_cur.repeat(batch_size*2)
|
||||
cfg_x = torch.cat([x, x], dim=0)
|
||||
if i % self.state_refresh_rate == 0:
|
||||
state = None
|
||||
out, state = net(cfg_x, cfg_t, cfg_condition, state)
|
||||
if t_cur > self.guidance_interval_min and t_cur < self.guidance_interval_max:
|
||||
out = self.guidance_fn(out, self.guidance)
|
||||
else:
|
||||
out = self.guidance_fn(out, 1.0)
|
||||
v = out
|
||||
|
||||
x0 = x + v * (1-t_cur)
|
||||
alpha_next = self.scheduler.alpha(t_next)
|
||||
sigma_next = self.scheduler.sigma(t_next)
|
||||
x = alpha_next * x0 + sigma_next * torch.randn_like(x)
|
||||
# print(alpha_next, sigma_next)
|
||||
return x
|
||||
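Below is a hedged, self-contained sketch (not code from this repository) of the x0 extrapolation and re-noising that one CMSampler step above performs, written out for the linear schedule alpha(t)=t, sigma(t)=1-t, where x0 = x + v*(1-t) holds exactly.

import torch

x0_true = torch.randn(2, 4)
eps = torch.randn_like(x0_true)
t_cur, t_next = 0.3, 0.6

x_t = t_cur * x0_true + (1 - t_cur) * eps           # current noisy sample
v = x0_true - eps                                   # exact velocity along this path
x0_hat = x_t + v * (1 - t_cur)                      # the sampler's x0 prediction
x_next = t_next * x0_hat + (1 - t_next) * torch.randn_like(x0_hat)  # re-noise to t_next
print(torch.allclose(x0_hat, x0_true, atol=1e-6))   # True: the extrapolation is exact here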
103
src/diffusion/stateful_flow_matching/sampling.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import torch
|
||||
|
||||
from src.diffusion.base.guidance import *
|
||||
from src.diffusion.base.scheduling import *
|
||||
from src.diffusion.base.sampling import *
|
||||
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def shift_respace_fn(t, shift=3.0):
|
||||
return t / (t + (1 - t) * shift)
|
||||
|
||||
def ode_step_fn(x, v, dt, s, w):
|
||||
return x + v * dt
|
||||
|
||||
def sde_mean_step_fn(x, v, dt, s, w):
|
||||
return x + v * dt + s * w * dt
|
||||
|
||||
def sde_step_fn(x, v, dt, s, w):
|
||||
return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x)
|
||||
|
||||
def sde_preserve_step_fn(x, v, dt, s, w):
|
||||
return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x)
|
||||
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EulerSampler(BaseSampler):
|
||||
def __init__(
|
||||
self,
|
||||
w_scheduler: BaseScheduler = None,
|
||||
timeshift=1.0,
|
||||
guidance_interval_min: float = 0.0,
|
||||
guidance_interval_max: float = 1.0,
|
||||
state_refresh_rate=1,
|
||||
step_fn: Callable = ode_step_fn,
|
||||
last_step=None,
|
||||
last_step_fn: Callable = ode_step_fn,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.step_fn = step_fn
|
||||
self.last_step = last_step
|
||||
self.last_step_fn = last_step_fn
|
||||
self.w_scheduler = w_scheduler
|
||||
self.timeshift = timeshift
|
||||
self.state_refresh_rate = state_refresh_rate
|
||||
self.guidance_interval_min = guidance_interval_min
|
||||
self.guidance_interval_max = guidance_interval_max
|
||||
|
||||
if self.last_step is None or self.num_steps == 1:
|
||||
self.last_step = 1.0 / self.num_steps
|
||||
|
||||
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
||||
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
||||
self.timesteps = shift_respace_fn(timesteps, self.timeshift)
|
||||
|
||||
assert self.last_step > 0.0
|
||||
assert self.scheduler is not None
|
||||
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
|
||||
if self.w_scheduler is not None:
|
||||
if self.step_fn == ode_step_fn:
|
||||
logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
|
||||
|
||||
def _impl_sampling(self, net, noise, condition, uncondition):
|
||||
"""
|
||||
sampling process of Euler sampler
|
||||
-
|
||||
"""
|
||||
batch_size = noise.shape[0]
|
||||
steps = self.timesteps.to(noise.device)
|
||||
cfg_condition = torch.cat([uncondition, condition], dim=0)
|
||||
x = noise
|
||||
state = None
|
||||
for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
|
||||
dt = t_next - t_cur
|
||||
t_cur = t_cur.repeat(batch_size)
|
||||
sigma = self.scheduler.sigma(t_cur)
|
||||
dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur)
|
||||
dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
|
||||
if self.w_scheduler:
|
||||
w = self.w_scheduler.w(t_cur)
|
||||
else:
|
||||
w = 0.0
|
||||
|
||||
cfg_x = torch.cat([x, x], dim=0)
|
||||
cfg_t = t_cur.repeat(2)
|
||||
if i % self.state_refresh_rate == 0:
|
||||
state = None
|
||||
out, state = net(cfg_x, cfg_t, cfg_condition, state)
|
||||
if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
|
||||
out = self.guidance_fn(out, self.guidance)
|
||||
else:
|
||||
out = self.guidance_fn(out, 1.0)
|
||||
v = out
|
||||
s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma)
|
||||
if i < self.num_steps -1 :
|
||||
x = self.step_fn(x, v, dt, s=s, w=w)
|
||||
else:
|
||||
x = self.last_step_fn(x, v, dt, s=s, w=w)
|
||||
return x
|
||||
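The score conversion inside the sampling loop above can be checked numerically: for x_t = alpha*x0 + sigma*eps and v = dalpha*x0 + dsigma*eps, the expression s = ((alpha/dalpha)*v - x_t) / (sigma**2 - (alpha/dalpha)*sigma*dsigma) reduces to the Gaussian score -eps/sigma. The following hedged sketch (not repository code) verifies this for the linear schedule alpha(t)=t, sigma(t)=1-t.

import torch

t = torch.tensor(0.7)
alpha, dalpha = t, torch.tensor(1.0)        # alpha(t) = t
sigma, dsigma = 1 - t, torch.tensor(-1.0)   # sigma(t) = 1 - t

x0 = torch.randn(3)
eps = torch.randn(3)
x_t = alpha * x0 + sigma * eps
v = dalpha * x0 + dsigma * eps

dalpha_over_alpha = dalpha / alpha
s = ((1 / dalpha_over_alpha) * v - x_t) / (sigma**2 - (1 / dalpha_over_alpha) * sigma * dsigma)
print(torch.allclose(s, -eps / sigma, atol=1e-5))   # True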
39
src/diffusion/stateful_flow_matching/scheduling.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import math
|
||||
import torch
|
||||
from src.diffusion.base.scheduling import *
|
||||
|
||||
|
||||
class LinearScheduler(BaseScheduler):
|
||||
def alpha(self, t) -> Tensor:
|
||||
return (t).view(-1, 1, 1, 1)
|
||||
def sigma(self, t) -> Tensor:
|
||||
return (1-t).view(-1, 1, 1, 1)
|
||||
def dalpha(self, t) -> Tensor:
|
||||
return torch.full_like(t, 1.0).view(-1, 1, 1, 1)
|
||||
def dsigma(self, t) -> Tensor:
|
||||
return torch.full_like(t, -1.0).view(-1, 1, 1, 1)
|
||||
|
||||
# SoTA for ImageNet!
|
||||
class GVPScheduler(BaseScheduler):
|
||||
def alpha(self, t) -> Tensor:
|
||||
return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
|
||||
def sigma(self, t) -> Tensor:
|
||||
return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
|
||||
def dalpha(self, t) -> Tensor:
|
||||
return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
|
||||
def dsigma(self, t) -> Tensor:
|
||||
return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
|
||||
def w(self, t):
|
||||
return torch.sin(t)**2
|
||||
|
||||
class ConstScheduler(BaseScheduler):
|
||||
def w(self, t):
|
||||
return torch.ones(1, 1, 1, 1).to(t.device, t.dtype)
|
||||
|
||||
from src.diffusion.ddpm.scheduling import VPScheduler
|
||||
class VPBetaScheduler(VPScheduler):
|
||||
def w(self, t):
|
||||
return self.beta(t).view(-1, 1, 1, 1)
|
||||
|
||||
|
||||
|
||||
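A small hedged check (not repository code) of the GVP schedule above: alpha(t)=cos(pi*t/2) and sigma(t)=sin(pi*t/2) satisfy alpha**2 + sigma**2 = 1, and the exact time derivatives carry a pi/2 factor that the dalpha/dsigma methods above omit; since both derivatives are scaled by the same constant, this only rescales the flow-matching target velocity.

import math
import torch

t = torch.tensor([0.3])
h = 1e-4
dalpha_fd = (torch.cos((t + h) * math.pi / 2) - torch.cos((t - h) * math.pi / 2)) / (2 * h)
dalpha_exact = -(math.pi / 2) * torch.sin(t * math.pi / 2)
print(dalpha_fd.item(), dalpha_exact.item())        # both approximately -0.713
print((torch.cos(t * math.pi / 2)**2 + torch.sin(t * math.pi / 2)**2).item())  # 1.0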
149
src/diffusion/stateful_flow_matching/sharing_sampling.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import torch
|
||||
|
||||
from src.diffusion.base.guidance import *
|
||||
from src.diffusion.base.scheduling import *
|
||||
from src.diffusion.base.sampling import *
|
||||
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def shift_respace_fn(t, shift=3.0):
|
||||
return t / (t + (1 - t) * shift)
|
||||
|
||||
def ode_step_fn(x, v, dt, s, w):
|
||||
return x + v * dt
|
||||
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EulerSampler(BaseSampler):
|
||||
def __init__(
|
||||
self,
|
||||
w_scheduler: BaseScheduler = None,
|
||||
timeshift=1.0,
|
||||
guidance_interval_min: float = 0.0,
|
||||
guidance_interval_max: float = 1.0,
|
||||
state_refresh_rate=1,
|
||||
step_fn: Callable = ode_step_fn,
|
||||
last_step=None,
|
||||
last_step_fn: Callable = ode_step_fn,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.step_fn = step_fn
|
||||
self.last_step = last_step
|
||||
self.last_step_fn = last_step_fn
|
||||
self.w_scheduler = w_scheduler
|
||||
self.timeshift = timeshift
|
||||
self.state_refresh_rate = state_refresh_rate
|
||||
self.guidance_interval_min = guidance_interval_min
|
||||
self.guidance_interval_max = guidance_interval_max
|
||||
|
||||
if self.last_step is None or self.num_steps == 1:
|
||||
self.last_step = 1.0 / self.num_steps
|
||||
|
||||
timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
|
||||
timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
|
||||
self.timesteps = shift_respace_fn(timesteps, self.timeshift)
|
||||
|
||||
assert self.last_step > 0.0
|
||||
assert self.scheduler is not None
|
||||
assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
|
||||
if self.w_scheduler is not None:
|
||||
if self.step_fn == ode_step_fn:
|
||||
logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
|
||||
|
||||
# init recompute
|
||||
self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
|
||||
self.recompute_timesteps = list(range(self.num_steps))
|
||||
|
||||
def sharing_dp(self, net, noise, condition, uncondition):
|
||||
_, C, H, W = noise.shape
|
||||
B = 8
|
||||
template_noise = torch.randn((B, C, H, W), generator=torch.Generator("cuda").manual_seed(0), device=noise.device)
|
||||
template_condition = torch.randint(0, 1000, (B,), generator=torch.Generator("cuda").manual_seed(0), device=condition.device)
|
||||
template_uncondition = torch.full((B, ), 1000, device=condition.device)
|
||||
_, state_list = self._impl_sampling(net, template_noise, template_condition, template_uncondition)
|
||||
states = torch.stack(state_list)
|
||||
N, B, L, C = states.shape
|
||||
states = states.view(N, B*L, C )
|
||||
states = states.permute(1, 0, 2)
|
||||
states = torch.nn.functional.normalize(states, dim=-1)
|
||||
with torch.autocast(device_type="cuda", dtype=torch.float64):
|
||||
sim = torch.bmm(states, states.transpose(1, 2))
|
||||
sim = torch.mean(sim, dim=0).cpu()
|
||||
error_map = (1-sim).tolist()
|
||||
|
||||
# init cum-error
|
||||
for i in range(1, self.num_steps):
|
||||
for j in range(0, i):
|
||||
error_map[i][j] = error_map[i-1][j] + error_map[i][j]
|
||||
|
||||
# init dp and force 0 start
|
||||
C = [[0.0, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
|
||||
P = [[-1, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
|
||||
for i in range(1, self.num_steps+1):
|
||||
C[1][i] = error_map[i - 1][0]
|
||||
P[1][i] = 0
|
||||
|
||||
# dp state
|
||||
for step in range(2, self.num_recompute_timesteps+1):
|
||||
for i in range(step, self.num_steps+1):
|
||||
min_value = 99999
|
||||
min_index = -1
|
||||
for j in range(step-1, i):
|
||||
value = C[step-1][j] + error_map[i-1][j]
|
||||
if value < min_value:
|
||||
min_value = value
|
||||
min_index = j
|
||||
C[step][i] = min_value
|
||||
P[step][i] = min_index
|
||||
|
||||
# trace back
|
||||
timesteps = [self.num_steps,]
|
||||
for i in range(self.num_recompute_timesteps, 0, -1):
|
||||
idx = timesteps[-1]
|
||||
timesteps.append(P[i][idx])
|
||||
timesteps.reverse()
|
||||
|
||||
print("recompute timesteps solved by DP: ", timesteps)
|
||||
return timesteps[:-1]
|
||||
|
||||
def _impl_sampling(self, net, noise, condition, uncondition):
|
||||
"""
|
||||
sampling process of Euler sampler
|
||||
-
|
||||
"""
|
||||
batch_size = noise.shape[0]
|
||||
steps = self.timesteps.to(noise.device)
|
||||
cfg_condition = torch.cat([uncondition, condition], dim=0)
|
||||
x = noise
|
||||
state = None
|
||||
pooled_state_list = []
|
||||
for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
|
||||
dt = t_next - t_cur
|
||||
t_cur = t_cur.repeat(batch_size)
|
||||
cfg_x = torch.cat([x, x], dim=0)
|
||||
cfg_t = t_cur.repeat(2)
|
||||
if i in self.recompute_timesteps:
|
||||
state = None
|
||||
out, state = net(cfg_x, cfg_t, cfg_condition, state)
|
||||
if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
|
||||
out = self.guidance_fn(out, self.guidance)
|
||||
else:
|
||||
out = self.guidance_fn(out, 1.0)
|
||||
v = out
|
||||
if i < self.num_steps -1 :
|
||||
x = self.step_fn(x, v, dt, s=0.0, w=0.0)
|
||||
else:
|
||||
x = self.last_step_fn(x, v, dt, s=0.0, w=0.0)
|
||||
pooled_state_list.append(state)
|
||||
return x, pooled_state_list
|
||||
|
||||
def __call__(self, net, noise, condition, uncondition):
|
||||
if len(self.recompute_timesteps) != self.num_recompute_timesteps:
|
||||
self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
|
||||
denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
|
||||
return denoised
|
||||
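The dynamic program in sharing_dp above picks which timesteps get a fresh state so that each recomputed state serves a contiguous run of steps at minimal accumulated feature drift. The standalone sketch below (hypothetical solve_recompute_steps, not repository code) reproduces the same table-filling and trace-back structure with a toy cost, (i - j)**2, standing in for the cumulative similarity error.

def solve_recompute_steps(num_steps, num_recompute):
    # cost[i][j]: toy price of reusing the state computed at step j through step i
    cost = [[(i - j) ** 2 for j in range(num_steps)] for i in range(num_steps)]
    INF = float("inf")
    C = [[INF] * (num_steps + 1) for _ in range(num_recompute + 1)]
    P = [[-1] * (num_steps + 1) for _ in range(num_recompute + 1)]
    for i in range(1, num_steps + 1):           # one recompute: it must happen at step 0
        C[1][i], P[1][i] = cost[i - 1][0], 0
    for step in range(2, num_recompute + 1):    # fill the table
        for i in range(step, num_steps + 1):
            for j in range(step - 1, i):
                value = C[step - 1][j] + cost[i - 1][j]
                if value < C[step][i]:
                    C[step][i], P[step][i] = value, j
    timesteps = [num_steps]                     # trace back the chosen recompute steps
    for step in range(num_recompute, 0, -1):
        timesteps.append(P[step][timesteps[-1]])
    timesteps.reverse()
    return timesteps[:-1]

print(solve_recompute_steps(num_steps=8, num_recompute=4))  # [0, 2, 4, 6] for this toy cost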
55
src/diffusion/stateful_flow_matching/training.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch
|
||||
from typing import Callable
|
||||
from src.diffusion.base.training import *
|
||||
from src.diffusion.base.scheduling import BaseScheduler
|
||||
|
||||
def inverse_sigma(alpha, sigma):
|
||||
return 1/sigma**2
|
||||
def snr(alpha, sigma):
|
||||
return alpha/sigma
|
||||
def minsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, min=threshold)
|
||||
def maxsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, max=threshold)
|
||||
def constant(alpha, sigma):
|
||||
return 1
|
||||
|
||||
class FlowMatchingTrainer(BaseTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: BaseScheduler,
|
||||
loss_weight_fn:Callable=constant,
|
||||
lognorm_t=False,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lognorm_t = lognorm_t
|
||||
self.scheduler = scheduler
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
|
||||
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||
batch_size = x.shape[0]
|
||||
if self.lognorm_t:
|
||||
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
|
||||
else:
|
||||
t = torch.rand(batch_size).to(x.device, x.dtype)
|
||||
noise = torch.randn_like(x)
|
||||
alpha = self.scheduler.alpha(t)
|
||||
dalpha = self.scheduler.dalpha(t)
|
||||
sigma = self.scheduler.sigma(t)
|
||||
dsigma = self.scheduler.dsigma(t)
|
||||
w = self.scheduler.w(t)
|
||||
|
||||
x_t = alpha * x + noise * sigma
|
||||
v_t = dalpha * x + dsigma * noise
|
||||
out, _ = net(x_t, t, y)
|
||||
|
||||
weight = self.loss_weight_fn(alpha, sigma)
|
||||
|
||||
loss = weight*(out - v_t)**2
|
||||
|
||||
out = dict(
|
||||
loss=loss.mean(),
|
||||
)
|
||||
return out
|
||||
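For orientation, the core objective implemented by FlowMatchingTrainer above reduces to regressing the velocity v_t = dalpha*x + dsigma*noise from the interpolant x_t = alpha*x + sigma*noise. The hedged sketch below (toy MLP, linear schedule, none of these names exist in the repository) runs one such training step end to end.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(5, 64), nn.SiLU(), nn.Linear(64, 4))  # input: [x_t, t]
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(16, 4)                     # "clean" data
t = torch.rand(16, 1)
noise = torch.randn_like(x)
x_t = t * x + (1 - t) * noise              # alpha*x + sigma*noise with alpha=t, sigma=1-t
v_t = x - noise                            # dalpha*x + dsigma*noise

pred = model(torch.cat([x_t, t], dim=-1))
loss = ((pred - v_t) ** 2).mean()
opt.zero_grad()
loss.backward()
opt.step()
print(float(loss))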
122
src/diffusion/stateful_flow_matching/training_adv.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import torch
|
||||
import math
|
||||
from typing import Callable
|
||||
from src.diffusion.base.training import *
|
||||
from src.diffusion.base.scheduling import BaseScheduler
|
||||
from src.utils.no_grad import no_grad
|
||||
from torchmetrics.image.lpip import _NoTrainLpips
|
||||
|
||||
def inverse_sigma(alpha, sigma):
|
||||
return 1/sigma**2
|
||||
def snr(alpha, sigma):
|
||||
return alpha/sigma
|
||||
def minsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, min=threshold)
|
||||
def maxsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, max=threshold)
|
||||
def constant(alpha, sigma):
|
||||
return 1
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, in_channels, hidden_size):
|
||||
super().__init__()
|
||||
self.head = nn.Sequential(
|
||||
nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
|
||||
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
|
||||
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
|
||||
nn.SiLU(),
|
||||
nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),  # 4x4 -> 2x2
|
||||
nn.GroupNorm(num_groups=32, num_channels=hidden_size),
|
||||
nn.SiLU(),
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
|
||||
)
|
||||
|
||||
def forward(self, feature):
|
||||
B, L, C = feature.shape
|
||||
H = W = int(math.sqrt(L))
|
||||
feature = feature.permute(0, 2, 1)
|
||||
feature = feature.view(B, C, H, W)
|
||||
out = self.head(feature).sigmoid().clamp(0.01, 0.99)
|
||||
return out
|
||||
|
||||
class AdvTrainer(BaseTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: BaseScheduler,
|
||||
loss_weight_fn:Callable=constant,
|
||||
lognorm_t=False,
|
||||
adv_weight=1.0,
|
||||
adv_encoder_layer=4,
|
||||
adv_in_channels=768,
|
||||
adv_hidden_size=256,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lognorm_t = lognorm_t
|
||||
self.scheduler = scheduler
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
self.adv_weight = adv_weight
|
||||
self.adv_encoder_layer = adv_encoder_layer
|
||||
|
||||
self.dis_head = Discriminator(
|
||||
in_channels=adv_in_channels,
|
||||
hidden_size=adv_hidden_size,
|
||||
)
|
||||
|
||||
|
||||
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||
batch_size = x.shape[0]
|
||||
if self.lognorm_t:
|
||||
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
|
||||
else:
|
||||
t = torch.rand(batch_size).to(x.device, x.dtype)
|
||||
noise = torch.randn_like(x)
|
||||
alpha = self.scheduler.alpha(t)
|
||||
dalpha = self.scheduler.dalpha(t)
|
||||
sigma = self.scheduler.sigma(t)
|
||||
dsigma = self.scheduler.dsigma(t)
|
||||
w = self.scheduler.w(t)
|
||||
|
||||
x_t = alpha * x + noise * sigma
|
||||
v_t = dalpha * x + dsigma * noise
|
||||
|
||||
adv_feature = []
|
||||
def forward_hook(net, input, output):
|
||||
adv_feature.append(output)
|
||||
handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook)
|
||||
|
||||
out, _ = net(x_t, t, y)
|
||||
weight = self.loss_weight_fn(alpha, sigma)
|
||||
loss = weight*(out - v_t)**2
|
||||
|
||||
pred_x0 = (x_t + out * sigma)
|
||||
pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma
|
||||
real_feature = adv_feature.pop()
|
||||
net(pred_xt, t, y, classify_layer=self.adv_encoder_layer)
|
||||
fake_feature = adv_feature.pop()
|
||||
handle.remove()
|
||||
|
||||
|
||||
real_score_gan = self.dis_head(real_feature.detach())
|
||||
fake_score_gan = self.dis_head(fake_feature.detach())
|
||||
fake_score = self.dis_head(fake_feature)
|
||||
|
||||
loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
|
||||
acc_real = (real_score_gan > 0.5).float()
|
||||
acc_fake = (fake_score_gan < 0.5).float()
|
||||
loss_adv = -torch.log(fake_score)
|
||||
loss_adv_hack = torch.log(fake_score_gan)
|
||||
|
||||
out = dict(
|
||||
adv_loss=loss_adv.mean(),
|
||||
gan_loss=loss_gan.mean(),
|
||||
fm_loss=loss.mean(),
|
||||
loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
|
||||
acc_real=acc_real.mean(),
|
||||
acc_fake=acc_fake.mean(),
|
||||
)
|
||||
return out
|
||||
141
src/diffusion/stateful_flow_matching/training_distill_dino.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import torch
|
||||
import copy
|
||||
import timm
|
||||
from torch.nn import Parameter
|
||||
|
||||
from src.utils.no_grad import no_grad
|
||||
from typing import Callable, Iterator, Tuple
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from torchvision.transforms import Normalize
|
||||
from src.diffusion.base.training import *
|
||||
from src.diffusion.base.scheduling import BaseScheduler
|
||||
|
||||
def inverse_sigma(alpha, sigma):
|
||||
return 1/sigma**2
|
||||
def snr(alpha, sigma):
|
||||
return alpha/sigma
|
||||
def minsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, min=threshold)
|
||||
def maxsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, max=threshold)
|
||||
def constant(alpha, sigma):
|
||||
return 1
|
||||
|
||||
|
||||
class DINOv2(nn.Module):
|
||||
def __init__(self, weight_path:str):
|
||||
super(DINOv2, self).__init__()
|
||||
self.encoder = torch.hub.load(
|
||||
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
|
||||
weight_path,
|
||||
source="local",
|
||||
skip_validation=True
|
||||
)
|
||||
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
|
||||
self.encoder.head = torch.nn.Identity()
|
||||
self.patch_size = self.encoder.patch_embed.patch_size
|
||||
self.precomputed_pos_embed = dict()
|
||||
|
||||
def fetch_pos(self, h, w):
|
||||
key = (h, w)
|
||||
if key in self.precomputed_pos_embed:
|
||||
return self.precomputed_pos_embed[key]
|
||||
value = timm.layers.pos_embed.resample_abs_pos_embed(
|
||||
self.pos_embed.data, [h, w],
|
||||
)
|
||||
self.precomputed_pos_embed[key] = value
|
||||
return value
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
||||
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bilinear')
|
||||
b, c, h, w = x.shape
|
||||
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
|
||||
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
|
||||
self.encoder.pos_embed.data = pos_embed_data
|
||||
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
|
||||
return feature
|
||||
|
||||
|
||||
class DistillDINOTrainer(BaseTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: BaseScheduler,
|
||||
loss_weight_fn:Callable=constant,
|
||||
feat_loss_weight: float=0.5,
|
||||
lognorm_t=False,
|
||||
encoder_weight_path=None,
|
||||
align_layer=8,
|
||||
proj_denoiser_dim=256,
|
||||
proj_hidden_dim=256,
|
||||
proj_encoder_dim=256,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lognorm_t = lognorm_t
|
||||
self.scheduler = scheduler
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
self.feat_loss_weight = feat_loss_weight
|
||||
self.align_layer = align_layer
|
||||
self.encoder = DINOv2(encoder_weight_path)
|
||||
self.proj_encoder_dim = proj_encoder_dim
|
||||
no_grad(self.encoder)
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
nn.Sequential(
|
||||
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(proj_hidden_dim, proj_hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(proj_hidden_dim, proj_encoder_dim),
|
||||
)
|
||||
)
|
||||
|
||||
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||
batch_size, c, height, width = x.shape
|
||||
if self.lognorm_t:
|
||||
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
|
||||
else:
|
||||
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
|
||||
t = base_t
|
||||
|
||||
noise = torch.randn_like(x)
|
||||
alpha = self.scheduler.alpha(t)
|
||||
sigma = self.scheduler.sigma(t)
|
||||
|
||||
x_t = alpha * x + noise * sigma
|
||||
|
||||
_, s = net(x_t, t, y)
|
||||
src_feature = self.proj(s)
|
||||
|
||||
with torch.no_grad():
|
||||
dst_feature = self.encoder(raw_images)
|
||||
|
||||
if dst_feature.shape[1] != src_feature.shape[1]:
|
||||
dst_length = dst_feature.shape[1]
|
||||
rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
|
||||
dst_height = (dst_length)**0.5 * (height/width)**0.5
|
||||
dst_width = (dst_length)**0.5 * (width/height)**0.5
|
||||
dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
|
||||
dst_feature = dst_feature.permute(0, 3, 1, 2)
|
||||
dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
|
||||
dst_feature = dst_feature.permute(0, 2, 3, 1)
|
||||
dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
|
||||
|
||||
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
|
||||
cos_loss = 1 - cos_sim
|
||||
|
||||
out = dict(
|
||||
cos_loss=cos_loss.mean(),
|
||||
loss=cos_loss.mean(),
|
||||
)
|
||||
return out
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
|
||||
return self.proj.state_dict(
|
||||
destination=destination,
|
||||
prefix=prefix + "proj.",
|
||||
keep_vars=keep_vars)
|
||||
|
||||
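When the encoder and denoiser token counts differ, the branch above reshapes the DINOv2 tokens to a 2D grid, bilinearly resamples it, and flattens back. A hedged standalone sketch of that resampling (square 16x16 to 32x32 grids assumed, not repository code):

import torch

batch, dim = 2, 256
dst_feature = torch.randn(batch, 16 * 16, dim)   # encoder tokens
src_tokens = 32 * 32                             # denoiser tokens

rescale_ratio = (src_tokens / dst_feature.shape[1]) ** 0.5   # -> 2.0
grid = dst_feature.view(batch, 16, 16, dim).permute(0, 3, 1, 2)
grid = torch.nn.functional.interpolate(grid, scale_factor=rescale_ratio,
                                       mode="bilinear", align_corners=False)
dst_resized = grid.permute(0, 2, 3, 1).reshape(batch, -1, dim)
print(dst_resized.shape)  # torch.Size([2, 1024, 256])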
71
src/diffusion/stateful_flow_matching/training_lpips.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import torch
|
||||
from typing import Callable
|
||||
from src.diffusion.base.training import *
|
||||
from src.diffusion.base.scheduling import BaseScheduler
|
||||
from src.utils.no_grad import no_grad
|
||||
from torchmetrics.image.lpip import _NoTrainLpips
|
||||
|
||||
def inverse_sigma(alpha, sigma):
|
||||
return 1/sigma**2
|
||||
def snr(alpha, sigma):
|
||||
return alpha/sigma
|
||||
def minsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, min=threshold)
|
||||
def maxsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, max=threshold)
|
||||
def constant(alpha, sigma):
|
||||
return 1
|
||||
|
||||
class LPIPSTrainer(BaseTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: BaseScheduler,
|
||||
loss_weight_fn:Callable=constant,
|
||||
lognorm_t=False,
|
||||
lpips_weight=1.0,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lognorm_t = lognorm_t
|
||||
self.scheduler = scheduler
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
self.lpips_weight = lpips_weight
|
||||
self.lpips = _NoTrainLpips(net="vgg")
|
||||
self.lpips = self.lpips.to(torch.bfloat16)
|
||||
# self.lpips = torch.compile(self.lpips)
|
||||
no_grad(self.lpips)
|
||||
|
||||
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||
batch_size = x.shape[0]
|
||||
if self.lognorm_t:
|
||||
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
|
||||
else:
|
||||
t = torch.rand(batch_size).to(x.device, x.dtype)
|
||||
noise = torch.randn_like(x)
|
||||
alpha = self.scheduler.alpha(t)
|
||||
dalpha = self.scheduler.dalpha(t)
|
||||
sigma = self.scheduler.sigma(t)
|
||||
dsigma = self.scheduler.dsigma(t)
|
||||
w = self.scheduler.w(t)
|
||||
|
||||
x_t = alpha * x + noise * sigma
|
||||
v_t = dalpha * x + dsigma * noise
|
||||
out, _ = net(x_t, t, y)
|
||||
weight = self.loss_weight_fn(alpha, sigma)
|
||||
loss = weight*(out - v_t)**2
|
||||
|
||||
pred_x0 = (x_t + out*sigma)
|
||||
target_x0 = x
|
||||
# fixbug lpips std
|
||||
lpips = self.lpips(pred_x0*0.5, target_x0*0.5)
|
||||
|
||||
out = dict(
|
||||
lpips_loss=lpips.mean(),
|
||||
fm_loss=loss.mean(),
|
||||
loss=loss.mean() + lpips.mean()*self.lpips_weight,
|
||||
)
|
||||
return out
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
|
||||
return
|
||||
@@ -0,0 +1,74 @@
|
||||
import torch
|
||||
from typing import Callable
|
||||
from src.diffusion.base.training import *
|
||||
from src.diffusion.base.scheduling import BaseScheduler
|
||||
from src.utils.no_grad import no_grad
|
||||
from torchmetrics.image.lpip import _NoTrainLpips
|
||||
|
||||
def inverse_sigma(alpha, sigma):
|
||||
return 1/sigma**2
|
||||
def snr(alpha, sigma):
|
||||
return alpha/sigma
|
||||
def minsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, min=threshold)
|
||||
def maxsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, max=threshold)
|
||||
def constant(alpha, sigma):
|
||||
return 1
|
||||
|
||||
class LPIPSTrainer(BaseTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: BaseScheduler,
|
||||
loss_weight_fn:Callable=constant,
|
||||
lognorm_t=False,
|
||||
lpips_weight=1.0,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lognorm_t = False
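# NOTE: lognorm_t is hard-coded to False here, so the constructor argument above is effectively ignored.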
|
||||
self.scheduler = scheduler
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
self.lpips_weight = lpips_weight
|
||||
self.lpips = _NoTrainLpips(net="vgg")
|
||||
self.lpips = self.lpips.to(torch.bfloat16)
|
||||
# self.lpips = torch.compile(self.lpips)
|
||||
no_grad(self.lpips)
|
||||
|
||||
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||
batch_size = x.shape[0]
|
||||
if self.lognorm_t:
|
||||
t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
|
||||
else:
|
||||
t = torch.rand(batch_size).to(x.device, x.dtype)
|
||||
noise = torch.randn_like(x)
|
||||
alpha = self.scheduler.alpha(t)
|
||||
dalpha = self.scheduler.dalpha(t)
|
||||
sigma = self.scheduler.sigma(t)
|
||||
dsigma = self.scheduler.dsigma(t)
|
||||
w = self.scheduler.w(t)
|
||||
|
||||
x_t = alpha * x + noise * sigma
|
||||
v_t = dalpha * x + dsigma * noise
|
||||
out, _ = net(x_t, t, y)
|
||||
|
||||
fm_weight = t*(1-t)**2/0.25
|
||||
lpips_weight = t
|
||||
|
||||
loss = (out - v_t)**2 * fm_weight[:, None, None, None]
|
||||
|
||||
pred_x0 = (x_t + out*sigma)
|
||||
target_x0 = x
|
||||
# fixbug lpips std
|
||||
lpips = self.lpips(pred_x0*0.5, target_x0*0.5)*lpips_weight[:, None, None, None]
|
||||
|
||||
out = dict(
|
||||
lpips_loss=lpips.mean(),
|
||||
fm_loss=loss.mean(),
|
||||
loss=loss.mean() + lpips.mean()*self.lpips_weight,
|
||||
)
|
||||
return out
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
|
||||
return
|
||||
157
src/diffusion/stateful_flow_matching/training_repa.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import torch
|
||||
import copy
|
||||
import timm
|
||||
from torch.nn import Parameter
|
||||
|
||||
from src.utils.no_grad import no_grad
|
||||
from typing import Callable, Iterator, Tuple
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from torchvision.transforms import Normalize
|
||||
from src.diffusion.base.training import *
|
||||
from src.diffusion.base.scheduling import BaseScheduler
|
||||
|
||||
def inverse_sigma(alpha, sigma):
|
||||
return 1/sigma**2
|
||||
def snr(alpha, sigma):
|
||||
return alpha/sigma
|
||||
def minsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, min=threshold)
|
||||
def maxsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, max=threshold)
|
||||
def constant(alpha, sigma):
|
||||
return 1
|
||||
|
||||
|
||||
class DINOv2(nn.Module):
|
||||
def __init__(self, weight_path:str):
|
||||
super(DINOv2, self).__init__()
|
||||
self.encoder = torch.hub.load(
|
||||
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
|
||||
weight_path,
|
||||
source="local",
|
||||
skip_validation=True
|
||||
)
|
||||
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
|
||||
self.encoder.head = torch.nn.Identity()
|
||||
self.patch_size = self.encoder.patch_embed.patch_size
|
||||
self.precomputed_pos_embed = dict()
|
||||
|
||||
def fetch_pos(self, h, w):
|
||||
key = (h, w)
|
||||
if key in self.precomputed_pos_embed:
|
||||
return self.precomputed_pos_embed[key]
|
||||
value = timm.layers.pos_embed.resample_abs_pos_embed(
|
||||
self.pos_embed.data, [h, w],
|
||||
)
|
||||
self.precomputed_pos_embed[key] = value
|
||||
return value
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
||||
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
|
||||
b, c, h, w = x.shape
|
||||
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
|
||||
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
|
||||
self.encoder.pos_embed.data = pos_embed_data
|
||||
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
|
||||
return feature
|
||||
|
||||
|
||||
class REPATrainer(BaseTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: BaseScheduler,
|
||||
loss_weight_fn:Callable=constant,
|
||||
feat_loss_weight: float=0.5,
|
||||
lognorm_t=False,
|
||||
encoder_weight_path=None,
|
||||
align_layer=8,
|
||||
proj_denoiser_dim=256,
|
||||
proj_hidden_dim=256,
|
||||
proj_encoder_dim=256,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lognorm_t = lognorm_t
|
||||
self.scheduler = scheduler
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
self.feat_loss_weight = feat_loss_weight
|
||||
self.align_layer = align_layer
|
||||
self.encoder = DINOv2(encoder_weight_path)
|
||||
self.proj_encoder_dim = proj_encoder_dim
|
||||
no_grad(self.encoder)
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
nn.Sequential(
|
||||
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(proj_hidden_dim, proj_hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(proj_hidden_dim, proj_encoder_dim),
|
||||
)
|
||||
)
|
||||
|
||||
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||
batch_size, c, height, width = x.shape
|
||||
if self.lognorm_t:
|
||||
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
|
||||
else:
|
||||
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
|
||||
t = base_t
|
||||
|
||||
noise = torch.randn_like(x)
|
||||
alpha = self.scheduler.alpha(t)
|
||||
dalpha = self.scheduler.dalpha(t)
|
||||
sigma = self.scheduler.sigma(t)
|
||||
dsigma = self.scheduler.dsigma(t)
|
||||
|
||||
x_t = alpha * x + noise * sigma
|
||||
v_t = dalpha * x + dsigma * noise
|
||||
src_feature = []
|
||||
def forward_hook(net, input, output):
|
||||
src_feature.append(output)
|
||||
|
||||
if getattr(net, "blocks", None) is not None:
|
||||
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
|
||||
else:
|
||||
handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
|
||||
|
||||
out, _ = net(x_t, t, y)
|
||||
src_feature = self.proj(src_feature[0])
|
||||
handle.remove()
|
||||
|
||||
with torch.no_grad():
|
||||
dst_feature = self.encoder(raw_images)
|
||||
|
||||
if dst_feature.shape[1] != src_feature.shape[1]:
|
||||
dst_length = dst_feature.shape[1]
|
||||
rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
|
||||
dst_height = (dst_length)**0.5 * (height/width)**0.5
|
||||
dst_width = (dst_length)**0.5 * (width/height)**0.5
|
||||
dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
|
||||
dst_feature = dst_feature.permute(0, 3, 1, 2)
|
||||
dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
|
||||
dst_feature = dst_feature.permute(0, 2, 3, 1)
|
||||
dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
|
||||
|
||||
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
|
||||
cos_loss = 1 - cos_sim
|
||||
|
||||
weight = self.loss_weight_fn(alpha, sigma)
|
||||
fm_loss = weight*(out - v_t)**2
|
||||
|
||||
out = dict(
|
||||
fm_loss=fm_loss.mean(),
|
||||
cos_loss=cos_loss.mean(),
|
||||
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
|
||||
)
|
||||
return out
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
|
||||
return self.proj.state_dict(
|
||||
destination=destination,
|
||||
prefix=prefix + "proj.",
|
||||
keep_vars=keep_vars)
|
||||
|
||||
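The REPA trainers above grab an intermediate denoiser feature with a forward hook, project it, and align it with the frozen DINOv2 feature via cosine similarity. A minimal hedged sketch of that hook pattern (toy modules only, not repository code):

import torch
import torch.nn as nn

blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])  # stand-in for net.blocks

feats = []
handle = blocks[1].register_forward_hook(lambda m, inp, out: feats.append(out))

h = torch.randn(2, 8)
for blk in blocks:
    h = blk(h)
handle.remove()

# align the hooked feature with a (here random) target feature
cos_loss = 1 - nn.functional.cosine_similarity(feats[0], torch.randn(2, 8), dim=-1)
print(cos_loss.mean().item())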
170
src/diffusion/stateful_flow_matching/training_repa_lpips.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import torch
|
||||
import copy
|
||||
import timm
|
||||
from torch.nn import Parameter
|
||||
|
||||
from src.utils.no_grad import no_grad
|
||||
from typing import Callable, Iterator, Tuple
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from torchvision.transforms import Normalize
|
||||
from src.diffusion.base.training import *
|
||||
from src.diffusion.base.scheduling import BaseScheduler
|
||||
from torchmetrics.image.lpip import _NoTrainLpips
|
||||
|
||||
def inverse_sigma(alpha, sigma):
|
||||
return 1/sigma**2
|
||||
def snr(alpha, sigma):
|
||||
return alpha/sigma
|
||||
def minsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, min=threshold)
|
||||
def maxsnr(alpha, sigma, threshold=5):
|
||||
return torch.clip(alpha/sigma, max=threshold)
|
||||
def constant(alpha, sigma):
|
||||
return 1
|
||||
|
||||
|
||||
class DINOv2(nn.Module):
|
||||
def __init__(self, weight_path:str):
|
||||
super(DINOv2, self).__init__()
|
||||
self.encoder = torch.hub.load(
|
||||
'/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
|
||||
weight_path,
|
||||
source="local",
|
||||
skip_validation=True
|
||||
)
|
||||
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
|
||||
self.encoder.head = torch.nn.Identity()
|
||||
self.patch_size = self.encoder.patch_embed.patch_size
|
||||
self.precomputed_pos_embed = dict()
|
||||
|
||||
def fetch_pos(self, h, w):
|
||||
key = (h, w)
|
||||
if key in self.precomputed_pos_embed:
|
||||
return self.precomputed_pos_embed[key]
|
||||
value = timm.layers.pos_embed.resample_abs_pos_embed(
|
||||
self.pos_embed.data, [h, w],
|
||||
)
|
||||
self.precomputed_pos_embed[key] = value
|
||||
return value
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
||||
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
|
||||
b, c, h, w = x.shape
|
||||
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
|
||||
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
|
||||
self.encoder.pos_embed.data = pos_embed_data
|
||||
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
|
||||
return feature
|
||||
|
||||
|
||||
class REPALPIPSTrainer(BaseTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: BaseScheduler,
|
||||
loss_weight_fn:Callable=constant,
|
||||
feat_loss_weight: float=0.5,
|
||||
lognorm_t=False,
|
||||
lpips_weight=1.0,
|
||||
encoder_weight_path=None,
|
||||
align_layer=8,
|
||||
proj_denoiser_dim=256,
|
||||
proj_hidden_dim=256,
|
||||
proj_encoder_dim=256,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lognorm_t = lognorm_t
|
||||
self.scheduler = scheduler
|
||||
self.loss_weight_fn = loss_weight_fn
|
||||
self.feat_loss_weight = feat_loss_weight
|
||||
self.align_layer = align_layer
|
||||
self.encoder = DINOv2(encoder_weight_path)
|
||||
self.proj_encoder_dim = proj_encoder_dim
|
||||
no_grad(self.encoder)
|
||||
|
||||
self.lpips_weight = lpips_weight
|
||||
self.lpips = _NoTrainLpips(net="vgg")
|
||||
self.lpips = self.lpips.to(torch.bfloat16)
|
||||
no_grad(self.lpips)
|
||||
|
||||
self.proj = nn.Sequential(
|
||||
nn.Sequential(
|
||||
nn.Linear(proj_denoiser_dim, proj_hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(proj_hidden_dim, proj_hidden_dim),
|
||||
nn.SiLU(),
|
||||
nn.Linear(proj_hidden_dim, proj_encoder_dim),
|
||||
)
|
||||
)
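# proj is a small MLP head that maps denoiser hidden states into the encoder feature space for the
# REPA cosine-alignment loss; it is the only trainable part of this trainer module.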
|
||||
|
||||
def _impl_trainstep(self, net, ema_net, raw_images, x, y):
|
||||
batch_size, c, height, width = x.shape
|
||||
if self.lognorm_t:
|
||||
base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
|
||||
else:
|
||||
base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
|
||||
t = base_t
|
||||
|
||||
noise = torch.randn_like(x)
|
||||
alpha = self.scheduler.alpha(t)
|
||||
dalpha = self.scheduler.dalpha(t)
|
||||
sigma = self.scheduler.sigma(t)
|
||||
dsigma = self.scheduler.dsigma(t)
|
||||
|
||||
x_t = alpha * x + noise * sigma
|
||||
v_t = dalpha * x + dsigma * noise
|
||||
|
||||
src_feature = []
|
||||
def forward_hook(net, input, output):
|
||||
src_feature.append(output)
|
||||
if getattr(net, "blocks", None) is not None:
|
||||
handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
|
||||
else:
|
||||
handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
|
||||
|
||||
out, _ = net(x_t, t, y)
|
||||
src_feature = self.proj(src_feature[0])
|
||||
handle.remove()
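# The temporary forward hook above captures the hidden states of block align_layer during the
# denoiser forward pass and is removed immediately afterwards.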
|
||||
|
||||
with torch.no_grad():
|
||||
dst_feature = self.encoder(raw_images)
|
||||
|
||||
if dst_feature.shape[1] != src_feature.shape[1]:
|
||||
dst_length = dst_feature.shape[1]
|
||||
rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
|
||||
dst_height = (dst_length)**0.5 * (height/width)**0.5
|
||||
dst_width = (dst_length)**0.5 * (width/height)**0.5
|
||||
dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
|
||||
dst_feature = dst_feature.permute(0, 3, 1, 2)
|
||||
dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
|
||||
dst_feature = dst_feature.permute(0, 2, 3, 1)
|
||||
dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
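# When the encoder and denoiser token grids differ, the DINOv2 features are reshaped to 2D,
# bilinearly resampled to the denoiser resolution, and flattened back to a token sequence.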
|
||||
|
||||
cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
|
||||
cos_loss = 1 - cos_sim
|
||||
|
||||
weight = self.loss_weight_fn(alpha, sigma)
|
||||
fm_loss = weight*(out - v_t)**2
|
||||
|
||||
pred_x0 = (x_t + out * sigma)
|
||||
target_x0 = x
|
||||
# fixbug lpips std
|
||||
lpips = self.lpips(pred_x0 * 0.5, target_x0 * 0.5)
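# Scaling both tensors by 0.5 presumably brings the latents closer to the value range the
# pretrained VGG-based LPIPS metric expects (see the "fixbug lpips std" note above).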
|
||||
|
||||
out = dict(
|
||||
lpips_loss=lpips.mean(),
|
||||
fm_loss=fm_loss.mean(),
|
||||
cos_loss=cos_loss.mean(),
|
||||
loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + self.lpips_weight*lpips.mean(),
|
||||
)
|
||||
return out
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
# export only the projection head's parameters
if destination is None:
destination = {}
self.proj.state_dict(
destination=destination,
prefix=prefix + "proj.",
keep_vars=keep_vars)
return destination
162
src/lightning_data.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from typing import Any
|
||||
import torch
|
||||
import copy
|
||||
import lightning.pytorch as pl
|
||||
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
|
||||
from torch.utils.data import DataLoader
|
||||
from src.data.dataset.randn import RandomNDataset
|
||||
from src.data.var_training import VARTransformEngine
|
||||
|
||||
def collate_fn(batch):
|
||||
new_batch = copy.deepcopy(batch)
|
||||
new_batch = list(zip(*new_batch))
|
||||
for i in range(len(new_batch)):
|
||||
if isinstance(new_batch[i][0], torch.Tensor):
|
||||
try:
|
||||
new_batch[i] = torch.stack(new_batch[i], dim=0)
|
||||
except Exception:
print("Warning: could not stack tensors; leaving this field as a tuple")
|
||||
return new_batch
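# collate_fn stacks tensor fields along a new batch dimension and leaves non-tensor fields
# (and tensors that cannot be stacked) as tuples.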
|
||||
|
||||
class DataModule(pl.LightningDataModule):
|
||||
def __init__(self,
|
||||
train_root,
|
||||
test_nature_root,
|
||||
test_gen_root,
|
||||
train_image_size=64,
|
||||
train_batch_size=64,
|
||||
train_num_workers=8,
|
||||
var_transform_engine: VARTransformEngine = None,
|
||||
train_prefetch_factor=2,
|
||||
train_dataset: str = None,
|
||||
eval_batch_size=32,
|
||||
eval_num_workers=4,
|
||||
eval_max_num_instances=50000,
|
||||
pred_batch_size=32,
|
||||
pred_num_workers=4,
|
||||
pred_seeds:str=None,
|
||||
pred_selected_classes=None,
|
||||
num_classes=1000,
|
||||
latent_shape=(4,64,64),
|
||||
):
|
||||
super().__init__()
|
||||
pred_seeds = list(map(lambda x: int(x), pred_seeds.strip().split(","))) if pred_seeds is not None else None
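# e.g. pred_seeds="0,1,2" parses to [0, 1, 2]; when None it is passed through to RandomNDataset unchanged.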
|
||||
|
||||
self.train_root = train_root
|
||||
self.train_image_size = train_image_size
|
||||
self.train_dataset = train_dataset
|
||||
# stupid data_convert override, just to make nebular happy
|
||||
self.train_batch_size = train_batch_size
|
||||
self.train_num_workers = train_num_workers
|
||||
self.train_prefetch_factor = train_prefetch_factor
|
||||
|
||||
self.test_nature_root = test_nature_root
|
||||
self.test_gen_root = test_gen_root
|
||||
self.eval_max_num_instances = eval_max_num_instances
|
||||
self.pred_seeds = pred_seeds
|
||||
self.num_classes = num_classes
|
||||
self.latent_shape = latent_shape
|
||||
|
||||
self.eval_batch_size = eval_batch_size
|
||||
self.pred_batch_size = pred_batch_size
|
||||
|
||||
self.pred_num_workers = pred_num_workers
|
||||
self.eval_num_workers = eval_num_workers
|
||||
|
||||
self.pred_selected_classes = pred_selected_classes
|
||||
|
||||
self._train_dataloader = None
|
||||
self.var_transform_engine = var_transform_engine
|
||||
|
||||
def setup(self, stage: str) -> None:
|
||||
if stage == "fit":
|
||||
assert self.train_dataset is not None
|
||||
if self.train_dataset == "pix_imagenet64":
|
||||
from src.data.dataset.imagenet import PixImageNet64
|
||||
self.train_dataset = PixImageNet64(
|
||||
root=self.train_root,
|
||||
)
|
||||
elif self.train_dataset == "pix_imagenet128":
|
||||
from src.data.dataset.imagenet import PixImageNet128
|
||||
self.train_dataset = PixImageNet128(
|
||||
root=self.train_root,
|
||||
)
|
||||
elif self.train_dataset == "imagenet256":
|
||||
from src.data.dataset.imagenet import ImageNet256
|
||||
self.train_dataset = ImageNet256(
|
||||
root=self.train_root,
|
||||
)
|
||||
elif self.train_dataset == "pix_imagenet256":
|
||||
from src.data.dataset.imagenet import PixImageNet256
|
||||
self.train_dataset = PixImageNet256(
|
||||
root=self.train_root,
|
||||
)
|
||||
elif self.train_dataset == "imagenet512":
|
||||
from src.data.dataset.imagenet import ImageNet512
|
||||
self.train_dataset = ImageNet512(
|
||||
root=self.train_root,
|
||||
)
|
||||
elif self.train_dataset == "pix_imagenet512":
|
||||
from src.data.dataset.imagenet import PixImageNet512
|
||||
self.train_dataset = PixImageNet512(
|
||||
root=self.train_root,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("no such dataset")
|
||||
|
||||
def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
|
||||
if self.var_transform_engine and self.trainer.training:
|
||||
batch = self.var_transform_engine(batch)
|
||||
return batch
|
||||
|
||||
def train_dataloader(self) -> TRAIN_DATALOADERS:
|
||||
global_rank = self.trainer.global_rank
|
||||
world_size = self.trainer.world_size
|
||||
from torch.utils.data import DistributedSampler
|
||||
sampler = DistributedSampler(self.train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
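# The sampler is built manually here, presumably because Lightning's automatic distributed
# sampler injection is disabled in the trainer config.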
|
||||
self._train_dataloader = DataLoader(
|
||||
self.train_dataset,
|
||||
self.train_batch_size,
|
||||
timeout=6000,
|
||||
num_workers=self.train_num_workers,
|
||||
prefetch_factor=self.train_prefetch_factor,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
return self._train_dataloader
|
||||
|
||||
def val_dataloader(self) -> EVAL_DATALOADERS:
|
||||
global_rank = self.trainer.global_rank
|
||||
world_size = self.trainer.world_size
|
||||
self.eval_dataset = RandomNDataset(
|
||||
latent_shape=self.latent_shape,
|
||||
num_classes=self.num_classes,
|
||||
max_num_instances=self.eval_max_num_instances,
|
||||
)
|
||||
from torch.utils.data import DistributedSampler
|
||||
sampler = DistributedSampler(self.eval_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
|
||||
return DataLoader(self.eval_dataset, self.eval_batch_size,
|
||||
num_workers=self.eval_num_workers,
|
||||
prefetch_factor=2,
|
||||
collate_fn=collate_fn,
|
||||
sampler=sampler
|
||||
)
|
||||
|
||||
def predict_dataloader(self) -> EVAL_DATALOADERS:
|
||||
global_rank = self.trainer.global_rank
|
||||
world_size = self.trainer.world_size
|
||||
self.pred_dataset = RandomNDataset(
|
||||
seeds= self.pred_seeds,
|
||||
max_num_instances=50000,
|
||||
num_classes=self.num_classes,
|
||||
selected_classes=self.pred_selected_classes,
|
||||
latent_shape=self.latent_shape,
|
||||
)
|
||||
from torch.utils.data import DistributedSampler
|
||||
sampler = DistributedSampler(self.pred_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
|
||||
return DataLoader(self.pred_dataset, batch_size=self.pred_batch_size,
|
||||
num_workers=self.pred_num_workers,
|
||||
prefetch_factor=4,
|
||||
collate_fn=collate_fn,
|
||||
sampler=sampler
|
||||
)
|
||||
123
src/lightning_model.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from typing import Callable, Iterable, Any, Optional, Union, Sequence, Mapping, Dict
|
||||
import os.path
|
||||
import copy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import lightning.pytorch as pl
|
||||
from lightning.pytorch.utilities.types import OptimizerLRScheduler, STEP_OUTPUT
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.optim import Optimizer
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
|
||||
|
||||
from src.models.vae import BaseVAE, fp2uint8
|
||||
from src.models.conditioner import BaseConditioner
|
||||
from src.utils.model_loader import ModelLoader
|
||||
from src.callbacks.simple_ema import SimpleEMA
|
||||
from src.diffusion.base.sampling import BaseSampler
|
||||
from src.diffusion.base.training import BaseTrainer
|
||||
from src.utils.no_grad import no_grad, filter_nograd_tensors
|
||||
from src.utils.copy import copy_params
|
||||
|
||||
EMACallable = Callable[[nn.Module, nn.Module], SimpleEMA]
|
||||
OptimizerCallable = Callable[[Iterable], Optimizer]
|
||||
LRSchedulerCallable = Callable[[Optimizer], LRScheduler]
|
||||
|
||||
|
||||
class LightningModel(pl.LightningModule):
|
||||
def __init__(self,
|
||||
vae: BaseVAE,
|
||||
conditioner: BaseConditioner,
|
||||
denoiser: nn.Module,
|
||||
diffusion_trainer: BaseTrainer,
|
||||
diffusion_sampler: BaseSampler,
|
||||
ema_tracker: Optional[EMACallable] = None,
|
||||
optimizer: OptimizerCallable = None,
|
||||
lr_scheduler: LRSchedulerCallable = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.vae = vae
|
||||
self.conditioner = conditioner
|
||||
self.denoiser = denoiser
|
||||
self.ema_denoiser = copy.deepcopy(self.denoiser)
|
||||
self.diffusion_sampler = diffusion_sampler
|
||||
self.diffusion_trainer = diffusion_trainer
|
||||
self.ema_tracker = ema_tracker
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
|
||||
# self.model_loader = ModelLoader()
|
||||
|
||||
self._strict_loading = False
|
||||
|
||||
def configure_model(self) -> None:
|
||||
self.trainer.strategy.barrier()
|
||||
# self.denoiser = self.model_loader.load(self.denoiser)
|
||||
copy_params(src_model=self.denoiser, dst_model=self.ema_denoiser)
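# The EMA copy starts from the exact denoiser weights; the ema_tracker callback returned by
# configure_callbacks keeps it updated during training.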
|
||||
|
||||
# self.denoiser = torch.compile(self.denoiser)
|
||||
# disable grad for conditioner and vae
|
||||
no_grad(self.conditioner)
|
||||
no_grad(self.vae)
|
||||
no_grad(self.diffusion_sampler)
|
||||
no_grad(self.ema_denoiser)
|
||||
|
||||
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
|
||||
ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
|
||||
return [ema_tracker]
|
||||
|
||||
def configure_optimizers(self) -> OptimizerLRScheduler:
|
||||
params_denoiser = filter_nograd_tensors(self.denoiser.parameters())
|
||||
params_trainer = filter_nograd_tensors(self.diffusion_trainer.parameters())
|
||||
optimizer: torch.optim.Optimizer = self.optimizer([*params_trainer, *params_denoiser])
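# Only tensors that still require grad are optimized: the denoiser plus the trainer's own
# trainable parameters (e.g. the REPA projection head); frozen modules are filtered out.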
|
||||
if self.lr_scheduler is None:
|
||||
return dict(
|
||||
optimizer=optimizer
|
||||
)
|
||||
else:
|
||||
lr_scheduler = self.lr_scheduler(optimizer)
|
||||
return dict(
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler
|
||||
)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
raw_images, x, y = batch
|
||||
with torch.no_grad():
|
||||
x = self.vae.encode(x)
|
||||
condition, uncondition = self.conditioner(y)
|
||||
loss = self.diffusion_trainer(self.denoiser, self.ema_denoiser, raw_images, x, condition, uncondition)
|
||||
self.log_dict(loss, prog_bar=True, on_step=True, sync_dist=False)
|
||||
return loss["loss"]
|
||||
|
||||
def predict_step(self, batch, batch_idx):
|
||||
xT, y, metadata = batch
|
||||
with torch.no_grad():
|
||||
condition, uncondition = self.conditioner(y)
|
||||
# Sample images:
|
||||
samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition)
|
||||
samples = self.vae.decode(samples)
|
||||
# fp32 -1,1 -> uint8 0,255
|
||||
samples = fp2uint8(samples)
|
||||
return samples
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
samples = self.predict_step(batch, batch_idx)
|
||||
return samples
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
|
||||
if destination is None:
|
||||
destination = {}
|
||||
self._save_to_state_dict(destination, prefix, keep_vars)
|
||||
self.denoiser.state_dict(
|
||||
destination=destination,
|
||||
prefix=prefix+"denoiser.",
|
||||
keep_vars=keep_vars)
|
||||
self.ema_denoiser.state_dict(
|
||||
destination=destination,
|
||||
prefix=prefix+"ema_denoiser.",
|
||||
keep_vars=keep_vars)
|
||||
self.diffusion_trainer.state_dict(
|
||||
destination=destination,
|
||||
prefix=prefix+"diffusion_trainer.",
|
||||
keep_vars=keep_vars)
|
||||
return destination
|
||||
0
src/models/__init__.py
Normal file
26
src/models/conditioner.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class BaseConditioner(nn.Module):
|
||||
def __init__(self):
|
||||
super(BaseConditioner, self).__init__()
|
||||
|
||||
def _impl_condition(self, y):
|
||||
...
|
||||
def _impl_uncondition(self, y):
|
||||
...
|
||||
def __call__(self, y):
|
||||
condition = self._impl_condition(y)
|
||||
uncondition = self._impl_uncondition(y)
|
||||
return condition, uncondition
|
||||
|
||||
class LabelConditioner(BaseConditioner):
|
||||
def __init__(self, null_class):
|
||||
super().__init__()
|
||||
self.null_condition = null_class
|
||||
|
||||
def _impl_condition(self, y):
|
||||
return torch.tensor(y).long().cuda()
|
||||
|
||||
def _impl_uncondition(self, y):
|
||||
return torch.full((len(y),), self.null_condition, dtype=torch.long).cuda()
|
||||
0
src/models/denoiser/__init__.py
Normal file
@@ -0,0 +1,383 @@
|
||||
import functools
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torch.nn.init import zeros_
|
||||
from torch.nn.modules.module import T
|
||||
|
||||
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
from src.utils.model_loader import ModelLoader
|
||||
from src.utils.no_grad import no_grad
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class Embed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10):
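# Sinusoidal timestep embedding in the DiT style; note max_period=10 rather than the usual
# 10000, which appears to be tuned for continuous t in [0, 1].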
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[..., None].float() * freqs[None, ...]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size):
|
||||
super().__init__()
|
||||
self.embedding_table = nn.Embedding(num_classes, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, labels,):
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 2*hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
def forward(self, x):
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
return x
|
||||
|
||||
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
|
||||
# assert H * H == end
|
||||
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
|
||||
x_pos = torch.linspace(0, scale, width)
|
||||
y_pos = torch.linspace(0, scale, height)
|
||||
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
|
||||
y_pos = y_pos.reshape(-1)
|
||||
x_pos = x_pos.reshape(-1)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
|
||||
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
|
||||
freqs_cis = freqs_cis.reshape(height*width, -1)
|
||||
return freqs_cis
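# precompute_freqs_cis_2d builds axial 2D rotary embeddings: rotation frequencies are computed
# separately for the x and y coordinates and interleaved along the head dimension.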
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
freqs_cis = freqs_cis[None, :, None, :]
|
||||
# xq : B N H Hc
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class RAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: nn.Module = RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
|
||||
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
|
||||
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
|
||||
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
||||
|
||||
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FlattenDiTBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
|
||||
super().__init__()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c, pos, mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
class FlattenDiTEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
num_groups=12,
|
||||
hidden_size=1152,
|
||||
num_blocks=18,
|
||||
patch_size=2,
|
||||
num_classes=1000,
|
||||
learn_sigma=True,
|
||||
weight_path=None,
|
||||
load_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.num_blocks = num_blocks
|
||||
self.patch_size = patch_size
|
||||
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
|
||||
self.weight_path = weight_path
|
||||
|
||||
self.load_ema = load_ema
|
||||
self.blocks = nn.ModuleList([
|
||||
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
|
||||
])
|
||||
self.initialize_weights()
|
||||
self.precompute_pos = dict()
|
||||
|
||||
def fetch_pos(self, height, width, device):
|
||||
if (height, width) in self.precompute_pos:
|
||||
return self.precompute_pos[(height, width)]
|
||||
else:
|
||||
pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
|
||||
pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
|
||||
self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
|
||||
return (pos_rope, pos_ape)
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.s_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.s_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize label embedding table:
|
||||
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
def forward(self, x, t, y, mask=None):
|
||||
B, _, H, W = x.shape
|
||||
pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
|
||||
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
|
||||
y = self.y_embedder(y).view(B, 1, self.hidden_size)
|
||||
c = nn.functional.silu(t + y)
|
||||
s = self.s_embedder(x)
|
||||
# s = s + pos_ape.to(s.dtype)
|
||||
for i in range(self.num_blocks):
|
||||
s = self.blocks[i](s, c, pos_rope, mask)
|
||||
return None, s
|
||||
|
||||
|
||||
class FlattenDiTDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
num_groups=12,
|
||||
hidden_size=1152,
|
||||
num_blocks=18,
|
||||
patch_size=2,
|
||||
num_classes=1000,
|
||||
learn_sigma=True,
|
||||
weight_path=None,
|
||||
load_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.num_blocks = num_blocks
|
||||
self.patch_size = patch_size
|
||||
self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
|
||||
self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
|
||||
|
||||
self.weight_path = weight_path
|
||||
|
||||
self.load_ema = load_ema
|
||||
self.blocks = nn.ModuleList([
|
||||
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
|
||||
])
|
||||
self.initialize_weights()
|
||||
self.precompute_pos = dict()
|
||||
|
||||
def fetch_pos(self, height, width, device):
|
||||
if (height, width) in self.precompute_pos:
|
||||
return self.precompute_pos[(height, width)]
|
||||
else:
|
||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
|
||||
self.precompute_pos[(height, width)] = pos
|
||||
return pos
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
def forward(self, x, t, y, s, mask=None):
|
||||
B, _, H, W = x.shape
|
||||
pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
|
||||
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
|
||||
y = self.y_embedder(y).view(B, 1, self.hidden_size)
|
||||
s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
|
||||
c = torch.nn.functional.silu(t + y)
|
||||
x = torch.cat([x, s], dim=-1)
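# The decoder is conditioned on the encoder's tokens s by concatenating them channel-wise with
# the patch features before the shared x_embedder projection.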
|
||||
x = self.x_embedder(x)
|
||||
for i in range(self.num_blocks):
|
||||
x = self.blocks[i](x, c, pos, None)
|
||||
x = self.final_layer(x, c)
|
||||
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size,
|
||||
stride=self.patch_size)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FlattenDiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder:FlattenDiTEncoder,
|
||||
decoder:FlattenDiTDecoder,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
ModelLoader().load(encoder)
|
||||
self.encoder = self.encoder.to(torch.bfloat16)
|
||||
no_grad(self.encoder)
|
||||
|
||||
def forward(self, x, t, y, s=None):
|
||||
if s is None:
|
||||
with torch.no_grad():
|
||||
_, s = self.encoder(x, t, y)
|
||||
x = self.decoder(x, t, y, s)
|
||||
return x, s
|
||||
|
||||
@@ -0,0 +1,447 @@
|
||||
import functools
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torch.nn.init import zeros_
|
||||
from torch.nn.modules.module import T
|
||||
|
||||
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
from src.utils.model_loader import ModelLoader
|
||||
from src.utils.no_grad import no_grad
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class Embed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[..., None].float() * freqs[None, ...]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size):
|
||||
super().__init__()
|
||||
self.embedding_table = nn.Embedding(num_classes, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, labels,):
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 2*hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
def forward(self, x):
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
return x
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, dim:int, groups:int=8, hidden_dim:int=256):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
|
||||
self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
|
||||
self.norm1 = nn.GroupNorm(groups, dim)
|
||||
self.norm2 = nn.GroupNorm(groups, dim)
|
||||
self.embed_proj = nn.Linear(hidden_dim, dim)
|
||||
|
||||
def forward(self, x, c):
|
||||
c = self.embed_proj(c)[:, :, None, None]
|
||||
residual = x
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = torch.nn.functional.silu(x)
|
||||
x = x * c
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
x = torch.nn.functional.silu(x)
|
||||
return residual + x
|
||||
|
||||
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
|
||||
# assert H * H == end
|
||||
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
|
||||
x_pos = torch.linspace(0, scale, width)
|
||||
y_pos = torch.linspace(0, scale, height)
|
||||
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
|
||||
y_pos = y_pos.reshape(-1)
|
||||
x_pos = x_pos.reshape(-1)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
|
||||
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
|
||||
freqs_cis = freqs_cis.reshape(height*width, -1)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
freqs_cis = freqs_cis[None, :, None, :]
|
||||
# xq : B N H Hc
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class RAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: nn.Module = RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
|
||||
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
|
||||
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
|
||||
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
||||
|
||||
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FlattenDiTBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
|
||||
super().__init__()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c, pos, mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
class FlattenDiTEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
num_groups=12,
|
||||
hidden_size=1152,
|
||||
num_blocks=18,
|
||||
patch_size=2,
|
||||
num_classes=1000,
|
||||
learn_sigma=True,
|
||||
weight_path=None,
|
||||
load_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.num_blocks = num_blocks
|
||||
self.patch_size = patch_size
|
||||
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
|
||||
self.weight_path = weight_path
|
||||
|
||||
self.load_ema = load_ema
|
||||
self.blocks = nn.ModuleList([
|
||||
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
|
||||
])
|
||||
self.initialize_weights()
|
||||
self.precompute_pos = dict()
|
||||
|
||||
def fetch_pos(self, height, width, device):
|
||||
if (height, width) in self.precompute_pos:
|
||||
return self.precompute_pos[(height, width)]
|
||||
else:
|
||||
pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
|
||||
pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
|
||||
self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
|
||||
return (pos_rope, pos_ape)
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.s_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.s_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize label embedding table:
|
||||
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
def forward(self, x, t, y, mask=None, classify_layer=None):
|
||||
B, _, H, W = x.shape
|
||||
pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
|
||||
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
|
||||
y = self.y_embedder(y).view(B, 1, self.hidden_size)
|
||||
c = nn.functional.silu(t + y)
|
||||
s = self.s_embedder(x)
|
||||
# s = s + pos_ape.to(s.dtype)
|
||||
classify_feats = []
|
||||
for i in range(self.num_blocks):
|
||||
s = self.blocks[i](s, c, pos_rope, mask)
|
||||
if classify_layer is not None and i < classify_layer:
|
||||
classify_feats.append(s)
|
||||
if i == classify_layer - 1:
|
||||
return None, classify_feats
|
||||
return None, s
|
||||
|
||||
|
||||
class FlattenDiTDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
num_groups=12,
|
||||
hidden_size=1152,
|
||||
num_mid_blocks=18,
|
||||
num_res_blocks=[1, 1, 1],
|
||||
num_res_channels=[64, 384, 768],
|
||||
num_classes=1000,
|
||||
learn_sigma=True,
|
||||
weight_path=None,
|
||||
load_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.num_mid_blocks = num_mid_blocks
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.num_res_channels = num_res_channels
|
||||
self.patch_size = 2**(len(num_res_blocks))
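# Each down-res stage halves the spatial resolution, so the transformer trunk operates at an
# effective patch size of 2 ** len(num_res_blocks).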
|
||||
|
||||
self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
|
||||
self.down_res_blocks = nn.ModuleList()
|
||||
previous_channel = self.in_channels
|
||||
for num, channels in zip(num_res_blocks, num_res_channels):
|
||||
self.down_res_blocks.append(
|
||||
nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
|
||||
)
|
||||
self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
|
||||
previous_channel = channels
|
||||
self.up_res_blocks = []
|
||||
previous_channel = self.in_channels
|
||||
for num, channels in zip(num_res_blocks, num_res_channels):
|
||||
self.up_res_blocks.append(
|
||||
nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
|
||||
)
|
||||
self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
|
||||
previous_channel = channels
|
||||
self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
|
||||
])
|
||||
|
||||
self.initialize_weights()
|
||||
self.precompute_pos = dict()
|
||||
self.weight_path = weight_path
|
||||
self.load_ema = load_ema
|
||||
|
||||
def fetch_pos(self, height, width, device):
|
||||
if (height, width) in self.precompute_pos:
|
||||
return self.precompute_pos[(height, width)]
|
||||
else:
|
||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
|
||||
self.precompute_pos[(height, width)] = pos
|
||||
return pos
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
def forward(self, x, t, y, s, mask=None):
|
||||
B, _, H, W = x.shape
|
||||
t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
|
||||
y = self.y_embedder(y).view(B, self.hidden_size)
|
||||
s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
|
||||
c = torch.nn.functional.silu(t + y)
|
||||
|
||||
residual = []
|
||||
for i, block in enumerate(self.down_res_blocks):
|
||||
if isinstance(block, nn.Conv2d):
|
||||
residual.append(x)
|
||||
x = block(x)
|
||||
else:
|
||||
x = block(x, c)
|
||||
|
||||
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
|
||||
x = x.view(B, self.hidden_size, -1).transpose(1, 2)
|
||||
mid_c = torch.nn.functional.silu(t[:, None, :] + s)
|
||||
for i in range(self.num_mid_blocks):
|
||||
x = self.blocks[i](x, mid_c, pos, None)
|
||||
x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size)
|
||||
|
||||
residual[0] = 0.0
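# The outermost skip connection is zeroed so the final ConvTranspose2d output is used directly
# instead of being added back to the raw network input.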
|
||||
for i, block in enumerate(self.up_res_blocks):
|
||||
if isinstance(block, nn.ConvTranspose2d):
|
||||
x = block(x) + residual.pop()
|
||||
else:
|
||||
x = block(x, c)
|
||||
return x
|
||||
|
||||
|
||||
class FlattenDiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder:FlattenDiTEncoder,
|
||||
decoder:FlattenDiTDecoder,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
ModelLoader().load(encoder)
|
||||
self.encoder = self.encoder.to(torch.bfloat16)
|
||||
no_grad(self.encoder)
|
||||
|
||||
def forward(self, x, t, y, s=None, classify_layer=None):
|
||||
if s is None:
|
||||
_, s = self.encoder(x, t, y, classify_layer=classify_layer)
|
||||
if classify_layer is not None:
|
||||
return None, s
|
||||
x = self.decoder(x, t, y, s)
|
||||
return x, s
|
||||
|
||||
class FlattenDiT_jointtraining(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder:FlattenDiTEncoder,
|
||||
decoder:FlattenDiTDecoder,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
|
||||
def forward(self, x, t, y, s=None):
|
||||
if s is None:
|
||||
_, s = self.encoder(x, t, y)
|
||||
x = self.decoder(x, t, y, s)
|
||||
return x, s
|
||||
|
||||
@@ -0,0 +1,448 @@
|
||||
import functools
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torch.nn.init import zeros_
|
||||
from torch.nn.modules.module import T
|
||||
|
||||
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
from src.utils.model_loader import ModelLoader
|
||||
from src.utils.no_grad import no_grad
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class Embed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[..., None].float() * freqs[None, ...]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size):
|
||||
super().__init__()
|
||||
self.embedding_table = nn.Embedding(num_classes, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, labels,):
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 2*hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
def forward(self, x):
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
return x
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, dim:int, groups:int=8, hidden_dim:int=256):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
|
||||
self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
|
||||
self.norm1 = nn.GroupNorm(groups, dim)
|
||||
self.norm2 = nn.GroupNorm(groups, dim)
|
||||
self.embed_proj = nn.Linear(hidden_dim, dim)
|
||||
|
||||
def forward(self, x, c):
|
||||
c = self.embed_proj(c)[:, :, None, None]
|
||||
residual = x
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = torch.nn.functional.silu(x)
|
||||
x = x * c
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
x = torch.nn.functional.silu(x)
|
||||
return residual + x
|
||||
|
||||
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
|
||||
# assert H * H == end
|
||||
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
|
||||
x_pos = torch.linspace(0, scale, width)
|
||||
y_pos = torch.linspace(0, scale, height)
|
||||
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
|
||||
y_pos = y_pos.reshape(-1)
|
||||
x_pos = x_pos.reshape(-1)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
|
||||
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
|
||||
freqs_cis = freqs_cis.reshape(height*width, -1)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
freqs_cis = freqs_cis[None, :, None, :]
|
||||
# xq : B N H Hc
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
|
||||
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  # B, N, H, Hc
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = True,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = RMSNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B N H Hc
        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_emb(q, k, freqs_cis=pos)
        q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)  # B, H, N, Hc
        k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()  # B, H, N, Hc
        v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()

        x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class FlattenDiTBlock(nn.Module):
    def __init__(self, hidden_size, groups, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, pos, mask=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class FlattenDiTEncoder(nn.Module):
    def __init__(
        self,
        in_channels=4,
        num_groups=12,
        hidden_size=1152,
        num_blocks=18,
        patch_size=2,
        num_classes=1000,
        learn_sigma=True,
        weight_path=None,
        load_ema=False,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.num_blocks = num_blocks
        self.patch_size = patch_size
        self.s_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
        self.weight_path = weight_path

        self.load_ema = load_ema
        self.blocks = nn.ModuleList([
            FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
        ])
        self.initialize_weights()
        self.precompute_pos = dict()

    def fetch_pos(self, height, width, device):
        if (height, width) in self.precompute_pos:
            return self.precompute_pos[(height, width)]
        else:
            pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
            pos_ape = precompute_freqs_cis_2d(self.hidden_size * 2, height, width).to(device)
            self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
            return (pos_rope, pos_ape)

    def initialize_weights(self):
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.s_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.s_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def forward(self, x, t, y, mask=None, classify_layer=None):
        B, _, H, W = x.shape
        pos_rope, pos_ape = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
        x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
        y = self.y_embedder(y).view(B, 1, self.hidden_size)
        c = nn.functional.silu(t + y)
        s = self.s_embedder(x)
        # s = s + pos_ape.to(s.dtype)
        classify_feats = []
        for i in range(self.num_blocks):
            s = self.blocks[i](s, c, pos_rope, mask)
            if classify_layer is not None and i < classify_layer:
                classify_feats.append(s)
                if i == classify_layer - 1:
                    # early exit with the collected intermediate features
                    return None, classify_feats
        return None, s
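Below is a minimal usage sketch (not part of the committed file) showing how the encoder's classify_layer argument exposes intermediate token features, e.g. for REPA-style alignment; the toy sizes are assumptions:

# Illustrative sketch only, assumed toy sizes.
import torch

encoder = FlattenDiTEncoder(in_channels=4, num_groups=6, hidden_size=384,
                            num_blocks=8, patch_size=2, num_classes=1000)
x = torch.randn(2, 4, 32, 32)        # e.g. an SD-VAE latent batch
t = torch.rand(2)                    # diffusion / flow time
y = torch.randint(0, 1001, (2,))     # labels; index 1000 is the extra (null) slot

_, feats = encoder(x, t, y, classify_layer=4)
print(len(feats), feats[0].shape)    # 4 tensors of shape [2, 256, 384], one per block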


class FlattenDiTDecoder(nn.Module):
    def __init__(
        self,
        in_channels=4,
        num_groups=12,
        hidden_size=1152,
        num_mid_blocks=18,
        num_res_blocks=[1, 1, 1],
        num_res_channels=[64, 384, 768],
        num_classes=1000,
        learn_sigma=True,
        weight_path=None,
        load_ema=False,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.num_mid_blocks = num_mid_blocks
        self.num_res_blocks = num_res_blocks
        self.num_res_channels = num_res_channels
        self.patch_size = 2 ** (len(num_res_blocks))

        self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)

        self.down_res_blocks = nn.ModuleList()
        previous_channel = self.in_channels
        for num, channels in zip(num_res_blocks, num_res_channels):
            self.down_res_blocks.append(
                nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
            )
            self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
            previous_channel = channels
        self.up_res_blocks = []
        previous_channel = self.in_channels
        for num, channels in zip(num_res_blocks, num_res_channels):
            self.up_res_blocks.append(
                nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
            )
            self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
            previous_channel = channels
        self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])

        self.blocks = nn.ModuleList([
            FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
        ])

        self.initialize_weights()
        self.precompute_pos = dict()
        self.weight_path = weight_path
        self.load_ema = load_ema

    def fetch_pos(self, height, width, device):
        if (height, width) in self.precompute_pos:
            return self.precompute_pos[(height, width)]
        else:
            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
            self.precompute_pos[(height, width)] = pos
            return pos

    def initialize_weights(self):
        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def forward(self, x, t, y, s, mask=None):
        B, _, H, W = x.shape
        t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
        y = self.y_embedder(y).view(B, self.hidden_size)
        s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
        c = torch.nn.functional.silu(t + y)

        residual = []
        for i, block in enumerate(self.down_res_blocks):
            if isinstance(block, nn.Conv2d):
                residual.append(x)
                x = block(x)
            else:
                x = block(x, c)

        pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
        x = x.view(B, self.hidden_size, -1).transpose(1, 2)
        mid_c = torch.nn.functional.silu(t[:, None, :] + s)
        for i in range(self.num_mid_blocks):
            x = self.blocks[i](x, mid_c, pos, None)
        x = x.transpose(1, 2).view(B, self.hidden_size, H // self.patch_size, W // self.patch_size)

        residual[0] = 0.0
        for i, block in enumerate(self.up_res_blocks):
            if isinstance(block, nn.ConvTranspose2d):
                x = block(x) + residual.pop()
            else:
                x = block(x, c)
        return x


class FlattenDiT(nn.Module):
    def __init__(
        self,
        encoder: FlattenDiTEncoder,
        decoder: FlattenDiTDecoder,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        ModelLoader().load(encoder, "encoder.")
        ModelLoader().load(decoder, "decoder.")
        self.encoder = self.encoder.to(torch.bfloat16)
        no_grad(self.encoder)

    def forward(self, x, t, y, s=None, classify_layer=None):
        if s is None:
            _, s = self.encoder(x, t, y, classify_layer=classify_layer)
            if classify_layer is not None:
                return None, s
        x = self.decoder(x, t, y, s)
        return x, s


class FlattenDiT_jointtraining(nn.Module):
    def __init__(
        self,
        encoder: FlattenDiTEncoder,
        decoder: FlattenDiTDecoder,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, t, y, s=None):
        if s is None:
            _, s = self.encoder(x, t, y)
        x = self.decoder(x, t, y, s)
        return x, s

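A small shape check (illustrative only, assumed toy sizes) for the 2D rotary-embedding helpers used by RAttention above:

# Illustrative sketch only; relies on precompute_freqs_cis_2d / apply_rotary_emb defined in this file.
import torch

H, W, heads, head_dim = 4, 4, 6, 64
freqs_cis = precompute_freqs_cis_2d(head_dim, H, W)     # [H*W, head_dim//2], complex-valued
q = torch.randn(2, H * W, heads, head_dim)              # B, N, H, Hc layout expected by apply_rotary_emb
k = torch.randn(2, H * W, heads, head_dim)
q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
print(freqs_cis.shape, q_rot.shape)                     # torch.Size([16, 32]) torch.Size([2, 16, 6, 64])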
@@ -0,0 +1,464 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math

from torch.nn.init import zeros_
from torch.nn.modules.module import T

# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention
from src.utils.model_loader import ModelLoader
from src.utils.no_grad import no_grad


def modulate(x, shift, scale):
    return x * (1 + scale) + shift


class Embed(nn.Module):
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer=None,
        bias: bool = True,
    ):
        super().__init__()
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x


class TimestepEmbedder(nn.Module):

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10):
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        )
        args = t[..., None].float() * freqs[None, ...]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class LabelEmbedder(nn.Module):
    def __init__(self, num_classes, hidden_size):
        super().__init__()
        self.embedding_table = nn.Embedding(num_classes, hidden_size)
        self.num_classes = num_classes

    def forward(self, labels):
        embeddings = self.embedding_table(labels)
        return embeddings


class FinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
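A quick numerical check (illustrative only) of the RMSNorm above: with the default unit weight it rescales each feature vector to unit root-mean-square and, unlike LayerNorm, does not subtract the mean:

# Illustrative sketch only.
import torch

norm = RMSNorm(8)
x = torch.randn(4, 8) * 5 + 3
y = norm(x)
print(y.pow(2).mean(-1).sqrt())   # all values close to 1.0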
class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
        return x


class ResBlock(nn.Module):
    def __init__(self, dim: int, groups: int = 8, hidden_dim: int = 256):
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
        self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
        self.norm1 = nn.GroupNorm(groups, dim)
        self.norm2 = nn.GroupNorm(groups, dim)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = torch.nn.functional.silu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = torch.nn.functional.silu(x)
        return residual + x


def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=16.0):
    # assert H * H == end
    # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
    x_pos = torch.linspace(0, scale, width)
    y_pos = torch.linspace(0, scale, height)
    y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
    y_pos = y_pos.reshape(-1)
    x_pos = x_pos.reshape(-1)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))  # Hc/4
    x_freqs = torch.outer(x_pos, freqs).float()  # N Hc/4
    y_freqs = torch.outer(y_pos, freqs).float()  # N Hc/4
    x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
    y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
    freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1)  # N,Hc/4,2
    freqs_cis = freqs_cis.reshape(height * width, -1)
    return freqs_cis


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    freqs_cis = freqs_cis[None, :, None, :]
    # xq : B N H Hc
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # B N H Hc/2
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  # B, N, H, Hc
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = True,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = RMSNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B N H Hc
        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_emb(q, k, freqs_cis=pos)
        q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)  # B, H, N, Hc
        k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()  # B, H, N, Hc
        v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()

        x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class FlattenDiTBlock(nn.Module):
    def __init__(self, hidden_size, groups, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, pos, mask=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class FlattenDiTEncoder(nn.Module):
    def __init__(
        self,
        in_channels=4,
        num_groups=12,
        hidden_size=1152,
        num_blocks=18,
        patch_size=2,
        num_classes=1000,
        learn_sigma=True,
        weight_path=None,
        load_ema=False,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.num_blocks = num_blocks
        self.patch_size = patch_size
        self.s_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
        self.weight_path = weight_path

        self.load_ema = load_ema
        self.blocks = nn.ModuleList([
            FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
        ])
        self.initialize_weights()
        self.precompute_pos = dict()

    def fetch_pos(self, height, width, device):
        if (height, width) in self.precompute_pos:
            return self.precompute_pos[(height, width)]
        else:
            pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
            pos_ape = precompute_freqs_cis_2d(self.hidden_size * 2, height, width).to(device)
            self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
            return (pos_rope, pos_ape)

    def initialize_weights(self):
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.s_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.s_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

    def forward(self, x, t, y, mask=None, classify_layer=None):
        B, _, H, W = x.shape
        pos_rope, pos_ape = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
        x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
        y = self.y_embedder(y).view(B, 1, self.hidden_size)
        c = nn.functional.silu(t + y)
        s = self.s_embedder(x)
        # s = s + pos_ape.to(s.dtype)
        classify_feats = []
        for i in range(self.num_blocks):
            s = self.blocks[i](s, c, pos_rope, mask)
            if classify_layer is not None and i < classify_layer:
                classify_feats.append(s)
                if i == classify_layer - 1:
                    # early exit with the collected intermediate features
                    return None, classify_feats
        return None, s


class FlattenDiTDecoder(nn.Module):
    def __init__(
        self,
        in_channels=4,
        num_groups=12,
        hidden_size=1152,
        num_mid_blocks=18,
        num_res_blocks=[1, 1, 1],
        num_res_channels=[64, 384, 768],
        num_classes=1000,
        learn_sigma=True,
        weight_path=None,
        load_ema=False,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.num_mid_blocks = num_mid_blocks
        self.num_res_blocks = num_res_blocks
        self.num_res_channels = num_res_channels
        self.patch_size = 2 ** (len(num_res_blocks))

        self.t_embedder = TimestepEmbedder(hidden_size)

        self.down_res_blocks = nn.ModuleList()
        previous_channel = self.in_channels
        for num, channels in zip(num_res_blocks, num_res_channels):
            self.down_res_blocks.append(
                nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
            )
            self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
            previous_channel = channels
        self.up_res_blocks = []
        previous_channel = self.in_channels
        for num, channels in zip(num_res_blocks, num_res_channels):
            self.up_res_blocks.append(
                nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
            )
            self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
            previous_channel = channels
        self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])

        self.blocks = nn.ModuleList([
            FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
        ])

        self.initialize_weights()
        self.precompute_pos = dict()
        self.weight_path = weight_path
        self.load_ema = load_ema

    def fetch_pos(self, height, width, device):
        if (height, width) in self.precompute_pos:
            return self.precompute_pos[(height, width)]
        else:
            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
            self.precompute_pos[(height, width)] = pos
            return pos

    def initialize_weights(self):
        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        # Zero-out adaLN modulation layers in SiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        for block in self.down_res_blocks:
            if isinstance(block, ResBlock):
                nn.init.constant_(block.conv1.weight, 0)
                nn.init.constant_(block.conv1.bias, 0)
                nn.init.constant_(block.norm1.weight, 0)
                nn.init.constant_(block.norm2.weight, 0)
                nn.init.constant_(block.conv2.weight, 0)
                nn.init.constant_(block.conv2.bias, 0)

        for block in self.up_res_blocks:
            if isinstance(block, ResBlock):
                nn.init.constant_(block.conv1.weight, 0)
                nn.init.constant_(block.conv1.bias, 0)
                nn.init.constant_(block.norm1.weight, 0)
                nn.init.constant_(block.norm2.weight, 0)
                nn.init.constant_(block.conv2.weight, 0)
                nn.init.constant_(block.conv2.bias, 0)

    def forward(self, x, t, y, s, mask=None):
        B, _, H, W = x.shape
        t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
        s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)

        residual = []
        for i, block in enumerate(self.down_res_blocks):
            if isinstance(block, nn.Conv2d):
                residual.append(x)
                x = block(x)
            else:
                x = block(x)

        pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
        x = x.view(B, self.hidden_size, -1).transpose(1, 2)
        mid_c = torch.nn.functional.silu(t[:, None, :] + s)
        for i in range(self.num_mid_blocks):
            x = self.blocks[i](x, mid_c, pos, None)
        x = x.transpose(1, 2).view(B, self.hidden_size, H // self.patch_size, W // self.patch_size)

        residual[0] = 0.0
        for i, block in enumerate(self.up_res_blocks):
            if isinstance(block, nn.ConvTranspose2d):
                x = block(x) + residual.pop()
            else:
                x = block(x)
        return x


class FlattenDiT(nn.Module):
    def __init__(
        self,
        encoder: FlattenDiTEncoder,
        decoder: FlattenDiTDecoder,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        ModelLoader().load(encoder, "encoder.")
        ModelLoader().load(decoder, "decoder.")
        self.encoder = self.encoder.to(torch.bfloat16)
        no_grad(self.encoder)

    def forward(self, x, t, y, s=None, classify_layer=None):
        if s is None:
            _, s = self.encoder(x, t, y, classify_layer=classify_layer)
            if classify_layer is not None:
                return None, s
        x = self.decoder(x, t, y, s)
        return x, s


class FlattenDiT_jointtraining(nn.Module):
    def __init__(
        self,
        encoder: FlattenDiTEncoder,
        decoder: FlattenDiTDecoder,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, t, y, s=None):
        if s is None:
            _, s = self.encoder(x, t, y)
        x = self.decoder(x, t, y, s)
        return x, s

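A minimal end-to-end sketch (not part of the commit) wiring the encoder and decoder above through FlattenDiT_jointtraining; the reduced widths, the encoder patch_size of 8, and the num_res_channels list ending at hidden_size are assumptions chosen so token counts and channel counts line up:

# Illustrative sketch only, assumed toy configuration.
import torch

encoder = FlattenDiTEncoder(in_channels=4, num_groups=6, hidden_size=384,
                            num_blocks=2, patch_size=8, num_classes=1000)
decoder = FlattenDiTDecoder(in_channels=4, num_groups=6, hidden_size=384,
                            num_mid_blocks=2, num_res_blocks=[1, 1, 1],
                            num_res_channels=[96, 192, 384], num_classes=1000)
model = FlattenDiT_jointtraining(encoder, decoder)

x = torch.randn(2, 4, 32, 32)            # latent batch
t = torch.rand(2)
y = torch.randint(0, 1001, (2,))

pred, s = model(x, t, y)                 # s: [2, 16, 384] semantic tokens from the encoder
print(pred.shape, s.shape)               # pred has the same shape as the latent input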
274
src/models/denoiser/condit_dit.py
Normal file
@@ -0,0 +1,274 @@
import torch
import torch.nn as nn
import math

from numba.cuda.cudadrv.devicearray import lru_cache
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention

from torch.nn.attention import SDPBackend, sdpa_kernel

flex_attention = torch.compile(flex_attention)


def modulate(x, shift, scale):
    return x * (1 + scale) + shift


class Embed(nn.Module):
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer=False,
        bias: bool = True,
    ):
        super().__init__()
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
        self.norm = nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x


class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        )
        args = t[..., None].float() * freqs[None, ...]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class LabelEmbedder(nn.Module):
    def __init__(self, num_classes, hidden_size):
        super().__init__()
        self.embedding_table = nn.Embedding(num_classes, hidden_size)
        self.num_classes = num_classes

    def forward(self, labels):
        embeddings = self.embedding_table(labels)
        return embeddings


class FinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = self.norm(x)
        x = modulate(x, shift, scale)
        x = self.linear(x)
        return x


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.act = nn.GELU(approximate="tanh")

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale: float = 16):
    x_pos = torch.linspace(0, scale, width)
    y_pos = torch.linspace(0, scale, height)
    y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
    x_pos = x_pos.reshape(-1)
    y_pos = y_pos.reshape(-1)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))  # Hc/4
    x_freqs = torch.outer(x_pos, freqs).float()  # N Hc/4
    y_freqs = torch.outer(y_pos, freqs).float()  # N Hc/4
    freqs_cis = torch.cat([x_freqs.sin(), x_freqs.cos(), y_freqs.sin(), y_freqs.cos()], dim=1)
    return freqs_cis


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
        # import pdb; pdb.set_trace()
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B N H Hc
        q = self.q_norm(q).to(q.dtype)
        k = self.k_norm(k).to(k.dtype)
        q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)  # B, H, N, Hc
        k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()  # B, H, N, Hc
        v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
        # x = flex_attention(q, k, v, block_mask=mask)
        # with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        x = scaled_dot_product_attention(q, k, v, mask)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class DiTBlock(nn.Module):
    def __init__(self, hidden_size, groups, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=groups, qkv_bias=True, qk_norm=False)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, pos, mask=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask)
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class ConDiT(nn.Module):
    def __init__(
        self,
        in_channels=4,
        out_channels=4,
        num_groups=12,
        hidden_size=1152,
        num_blocks=18,
        num_cond_blocks=4,
        patch_size=2,
        num_classes=1000,
        learn_sigma=True,
        deep_supervision=0,
        weight_path=None,
        load_ema=False,
    ):
        super().__init__()
        self.deep_supervision = deep_supervision
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.num_blocks = num_blocks
        self.patch_size = patch_size
        self.x_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True)
        self.s_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
        self.final_layer = FinalLayer(hidden_size, out_channels * patch_size ** 2)
        self.num_cond_blocks = num_cond_blocks

        self.weight_path = weight_path
        self.load_ema = load_ema
        self.blocks = nn.ModuleList([
            DiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
        ])
        self.initialize_weights()

    @lru_cache
    def fetch_pos(self, height, width, device):
        pos = precompute_freqs_cis_2d(self.hidden_size, height // self.patch_size, width // self.patch_size).to(device)[None, ...]
        return pos

    def initialize_weights(self):
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.s_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.s_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, y, s=None):
        B, _, H, W = x.shape
        pos = self.fetch_pos(H, W, x.device)
        x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)

        t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
        y = self.y_embedder(y).view(B, 1, self.hidden_size)

        if s is None:
            # semantic encoder
            s = self.s_embedder(x) + pos
            c = nn.functional.silu(t + y)
            for i in range(self.num_cond_blocks):
                s = self.blocks[i](s, c, pos)
            s = nn.functional.silu(t + s)

        x = self.x_embedder(x)
        for i in range(self.num_cond_blocks, self.num_blocks):
            x = self.blocks[i](x, s, pos)
        x = self.final_layer(x, s)
        x = torch.nn.functional.fold(x.transpose(1, 2), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
        return x, s
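The unfold/fold pair in ConDiT.forward is a lossless patchify round trip; a short illustrative check (toy sizes assumed):

# Illustrative sketch only.
import torch
import torch.nn.functional as F

p = 2
x = torch.randn(2, 4, 32, 32)
tokens = F.unfold(x, kernel_size=p, stride=p).transpose(1, 2)   # [2, 256, 16]: one 2x2x4 patch per token
x_rec = F.fold(tokens.transpose(1, 2), (32, 32), kernel_size=p, stride=p)
print(tokens.shape, torch.allclose(x, x_rec))                   # torch.Size([2, 256, 16]) True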
314
src/models/denoiser/flatten_condit_catdit_fixt.py
Normal file
@@ -0,0 +1,314 @@
import functools
from typing import Tuple
import torch
import torch.nn as nn
import math

from torch.nn.init import zeros_
from torch.nn.modules.module import T

# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from torch.nn.functional import scaled_dot_product_attention


def modulate(x, shift, scale):
    return x * (1 + scale) + shift


class Embed(nn.Module):
    def __init__(
        self,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer=None,
        bias: bool = True,
    ):
        super().__init__()
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x


class TimestepEmbedder(nn.Module):

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10):
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        )
        args = t[..., None].float() * freqs[None, ...]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class LabelEmbedder(nn.Module):
    def __init__(self, num_classes, hidden_size):
        super().__init__()
        self.embedding_table = nn.Embedding(num_classes, hidden_size)
        self.num_classes = num_classes

    def forward(self, labels):
        embeddings = self.embedding_table(labels)
        return embeddings


class FinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
        return x


def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=16.0):
    # assert H * H == end
    # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
    x_pos = torch.linspace(0, scale, width)
    y_pos = torch.linspace(0, scale, height)
    y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
    y_pos = y_pos.reshape(-1)
    x_pos = x_pos.reshape(-1)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))  # Hc/4
    x_freqs = torch.outer(x_pos, freqs).float()  # N Hc/4
    y_freqs = torch.outer(y_pos, freqs).float()  # N Hc/4
    x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
    y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
    freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1)  # N,Hc/4,2
    freqs_cis = freqs_cis.reshape(height * width, -1)
    return freqs_cis


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    freqs_cis = freqs_cis[None, :, None, :]
    # xq : B N H Hc
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # B N H Hc/2
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)  # B, N, H, Hc
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = True,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = RMSNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B N H Hc
        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_emb(q, k, freqs_cis=pos)
        q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2)  # B, H, N, Hc
        k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()  # B, H, N, Hc
        v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()

        x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class FlattenDiTBlock(nn.Module):
    def __init__(self, hidden_size, groups, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, pos, mask=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class FlattenConDiT(nn.Module):
    def __init__(
        self,
        in_channels=4,
        num_groups=12,
        hidden_size=1152,
        num_blocks=18,
        num_cond_blocks=4,
        patch_size=2,
        num_classes=1000,
        learn_sigma=True,
        deep_supervision=0,
        weight_path=None,
        load_ema=False,
    ):
        super().__init__()
        self.deep_supervision = deep_supervision
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.num_blocks = num_blocks
        self.num_cond_blocks = num_cond_blocks
        self.patch_size = patch_size
        self.x_embedder = Embed(in_channels * patch_size ** 2 + hidden_size, hidden_size, bias=True)
        self.s_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)

        self.final_layer = FinalLayer(hidden_size, in_channels * patch_size ** 2)

        self.weight_path = weight_path

        self.load_ema = load_ema
        self.blocks = nn.ModuleList([
            FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
        ])
        self.initialize_weights()
        self.precompute_pos = dict()

    def fetch_pos(self, height, width, device):
        if (height, width) in self.precompute_pos:
            return self.precompute_pos[(height, width)].to(device)
        else:
            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
            self.precompute_pos[(height, width)] = pos
            return pos

    def initialize_weights(self):
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.s_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.s_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # # Zero-out adaLN modulation layers in SiT blocks:
        # for block in self.blocks:
        #     nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
        #     nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, y, s=None, mask=None):
        B, _, H, W = x.shape
        pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
        x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
        y = self.y_embedder(y).view(B, 1, self.hidden_size)
        c = nn.functional.silu(t + y)
        if s is None:
            s = self.s_embedder(x)
            for i in range(self.num_cond_blocks):
                s = self.blocks[i](s, c, pos, mask)
            # s = nn.functional.silu(t + s)
            s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6)
        x = torch.cat((x, s), dim=-1)
        x = self.x_embedder(x)
        for i in range(self.num_cond_blocks, self.num_blocks):
            x = self.blocks[i](x, c, pos, None)
        x = self.final_layer(x, c)
        x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
        return x, s
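A minimal sketch (not part of the committed file) of running FlattenConDiT once and then reusing the returned semantic tokens s on a second pass, e.g. for the unconditional branch of classifier-free guidance; the sizes are assumed toy values:

# Illustrative sketch only, assumed toy configuration.
import torch

model = FlattenConDiT(in_channels=4, num_groups=6, hidden_size=384,
                      num_blocks=4, num_cond_blocks=2, patch_size=2, num_classes=1000)
x = torch.randn(2, 4, 32, 32)
t = torch.rand(2)
y = torch.randint(0, 1001, (2,))            # index 1000 is the extra (null) class slot
y_null = torch.full((2,), 1000)

pred_cond, s = model(x, t, y)               # the first num_cond_blocks blocks build s from x
pred_uncond, _ = model(x, t, y_null, s=s)   # reuse s and skip the conditioning blocks
print(pred_cond.shape, s.shape)             # [2, 4, 32, 32] [2, 256, 384]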
340
src/models/denoiser/flatten_condit_conv_fixt.py
Normal file
@@ -0,0 +1,340 @@
|
||||
import functools
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torch.nn.init import zeros_
|
||||
from torch.nn.modules.module import T
|
||||
|
||||
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class Embed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[..., None].float() * freqs[None, ...]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size):
|
||||
super().__init__()
|
||||
self.embedding_table = nn.Embedding(num_classes, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, labels,):
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 2*hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
def forward(self, x):
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
return x
|
||||
|
||||
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
|
||||
# assert H * H == end
|
||||
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
|
||||
x_pos = torch.linspace(0, scale, width)
|
||||
y_pos = torch.linspace(0, scale, height)
|
||||
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
|
||||
y_pos = y_pos.reshape(-1)
|
||||
x_pos = x_pos.reshape(-1)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
|
||||
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
|
||||
freqs_cis = freqs_cis.reshape(height*width, -1)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
freqs_cis = freqs_cis[None, :, None, :]
|
||||
# xq : B N H Hc
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class RAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: nn.Module = RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
|
||||
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
|
||||
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
|
||||
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
||||
|
||||
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FlattenDiTBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
|
||||
super().__init__()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c, pos, mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
class FlattenConvBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, kernel_size=3):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = nn.Conv2d(hidden_size, hidden_size, groups=groups, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c, pos, mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
attn_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
||||
attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous()  # assumes a fixed 16x16 token grid (256 tokens), e.g. 32x32 latents with patch_size=2
|
||||
attn_x = self.attn(attn_x)
|
||||
attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2)
|
||||
x = x + gate_msa * attn_x
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
class FlattenConDiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
num_groups=12,
|
||||
hidden_size=1152,
|
||||
num_blocks=18,
|
||||
num_cond_blocks=4,
|
||||
patch_size=2,
|
||||
kernel_size=3,
|
||||
num_classes=1000,
|
||||
learn_sigma=True,
|
||||
deep_supervision=0,
|
||||
weight_path=None,
|
||||
load_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.deep_supervision = deep_supervision
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.num_blocks = num_blocks
|
||||
self.num_cond_blocks = num_cond_blocks
|
||||
self.patch_size = patch_size
|
||||
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
|
||||
|
||||
self.weight_path = weight_path
|
||||
|
||||
self.load_ema = load_ema
|
||||
self.blocks = nn.ModuleList([])
|
||||
for i in range(self.num_cond_blocks):
|
||||
self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
|
||||
for i in range(self.num_blocks-self.num_cond_blocks):
|
||||
self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups, kernel_size=kernel_size))  # pass by keyword: the third positional slot is mlp_ratio, not kernel_size
|
||||
|
||||
self.initialize_weights()
|
||||
self.precompute_pos = dict()
|
||||
|
||||
def fetch_pos(self, height, width, device):
|
||||
if (height, width) in self.precompute_pos:
|
||||
return self.precompute_pos[(height, width)].to(device)
|
||||
else:
|
||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
|
||||
self.precompute_pos[(height, width)] = pos
|
||||
return pos
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.s_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.s_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize label embedding table:
|
||||
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# # Zero-out adaLN modulation layers in SiT blocks:
|
||||
# for block in self.blocks:
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
|
||||
def forward(self, x, t, y, s=None):
|
||||
B, _, H, W = x.shape
|
||||
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
|
||||
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
|
||||
y = self.y_embedder(y).view(B, 1, self.hidden_size)
|
||||
c = nn.functional.silu(t + y)
|
||||
if s is None:
|
||||
s = self.s_embedder(x)
|
||||
for i in range(self.num_cond_blocks):
|
||||
s = self.blocks[i](s, c, pos, None)
|
||||
s = nn.functional.silu(t + s)
|
||||
|
||||
x = self.x_embedder(x)
|
||||
for i in range(self.num_cond_blocks, self.num_blocks):
|
||||
x = self.blocks[i](x, s, pos, None)
|
||||
x = self.final_layer(x, s)
|
||||
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
||||
return x, s
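# Minimal forward-pass sketch (illustrative, not part of the commit): the conv
# blocks above hard-code a 16x16 token grid, so the latent below is 32x32 with
# patch_size=2. All sizes here are assumptions chosen only for the example.
def _condit_forward_sketch():
    model = FlattenConDiT(in_channels=4, num_groups=16, hidden_size=1152,
                          num_blocks=28, num_cond_blocks=6, patch_size=2)
    x = torch.randn(2, 4, 32, 32)        # VAE latents
    t = torch.rand(2)                    # continuous timesteps
    y = torch.randint(0, 1000, (2,))     # class labels
    out, s = model(x, t, y)              # out: (2, 4, 32, 32), s: (2, 256, 1152)
    return out, s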
|
||||
339
src/models/denoiser/flatten_condit_convnext_fixt.py
Normal file
@@ -0,0 +1,339 @@
|
||||
import functools
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torch.nn.init import zeros_
|
||||
from torch.nn.modules.module import T
|
||||
|
||||
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class Embed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[..., None].float() * freqs[None, ...]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size):
|
||||
super().__init__()
|
||||
self.embedding_table = nn.Embedding(num_classes, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, labels,):
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 2*hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
def forward(self, x):
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
return x
|
||||
|
||||
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
|
||||
# assert H * H == end
|
||||
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
|
||||
x_pos = torch.linspace(0, scale, width)
|
||||
y_pos = torch.linspace(0, scale, height)
|
||||
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
|
||||
y_pos = y_pos.reshape(-1)
|
||||
x_pos = x_pos.reshape(-1)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
|
||||
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
|
||||
freqs_cis = freqs_cis.reshape(height*width, -1)
|
||||
return freqs_cis
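# Layout note (illustrative): the returned table interleaves x- and y-rotations,
# giving dim//4 frequencies per axis and a complex tensor of shape
# (height*width, dim//2); e.g. precompute_freqs_cis_2d(72, 16, 16) -> (256, 36).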
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
freqs_cis = freqs_cis[None, :, None, :]
|
||||
# xq : B N H Hc
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class RAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: nn.Module = RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
|
||||
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
|
||||
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
|
||||
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
||||
|
||||
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FlattenDiTBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
|
||||
super().__init__()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c, pos, mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
class FlattenConvBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = nn.Conv2d(hidden_size, hidden_size, groups=hidden_size, kernel_size=7, stride=1, padding=3)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c, pos, mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
attn_x = modulate(self.norm1(x), shift_msa, scale_msa)
|
||||
attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous()  # assumes a fixed 16x16 token grid (256 tokens), e.g. 32x32 latents with patch_size=2
|
||||
attn_x = self.attn(attn_x)
|
||||
attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2)
|
||||
x = x + gate_msa * attn_x
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
class FlattenConDiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
num_groups=12,
|
||||
hidden_size=1152,
|
||||
num_blocks=18,
|
||||
num_cond_blocks=4,
|
||||
patch_size=2,
|
||||
num_classes=1000,
|
||||
learn_sigma=True,
|
||||
deep_supervision=0,
|
||||
weight_path=None,
|
||||
load_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.deep_supervision = deep_supervision
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.num_blocks = num_blocks
|
||||
self.num_cond_blocks = num_cond_blocks
|
||||
self.patch_size = patch_size
|
||||
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
|
||||
|
||||
self.weight_path = weight_path
|
||||
|
||||
self.load_ema = load_ema
|
||||
self.blocks = nn.ModuleList([])
|
||||
for i in range(self.num_cond_blocks):
|
||||
self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
|
||||
for i in range(self.num_blocks-self.num_cond_blocks):
|
||||
self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups))
|
||||
|
||||
self.initialize_weights()
|
||||
self.precompute_pos = dict()
|
||||
|
||||
def fetch_pos(self, height, width, device):
|
||||
if (height, width) in self.precompute_pos:
|
||||
return self.precompute_pos[(height, width)].to(device)
|
||||
else:
|
||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
|
||||
self.precompute_pos[(height, width)] = pos
|
||||
return pos
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.s_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.s_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize label embedding table:
|
||||
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# # Zero-out adaLN modulation layers in SiT blocks:
|
||||
# for block in self.blocks:
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
|
||||
def forward(self, x, t, y, s=None):
|
||||
B, _, H, W = x.shape
|
||||
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
|
||||
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
|
||||
y = self.y_embedder(y).view(B, 1, self.hidden_size)
|
||||
c = nn.functional.silu(t + y)
|
||||
if s is None:
|
||||
s = self.s_embedder(x)
|
||||
for i in range(self.num_cond_blocks):
|
||||
s = self.blocks[i](s, c, pos, None)
|
||||
s = nn.functional.silu(t + s)
|
||||
|
||||
x = self.x_embedder(x)
|
||||
for i in range(self.num_cond_blocks, self.num_blocks):
|
||||
x = self.blocks[i](x, s, pos, None)
|
||||
x = self.final_layer(x, s)
|
||||
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
||||
return x, s
|
||||
313
src/models/denoiser/flatten_condit_dit_fixt.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import functools
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torch.nn.init import zeros_
|
||||
from torch.nn.modules.module import T
|
||||
|
||||
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class Embed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[..., None].float() * freqs[None, ...]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
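# Usage note (an observation, not documented in this file): t is a continuous
# per-sample scalar, and max_period=10 is far smaller than the 10000 used by
# the original DiT embedder, which suggests timesteps on a small range such as
# [0, 1]. For example, TimestepEmbedder.timestep_embedding(torch.rand(8), 256)
# returns a tensor of shape (8, 256).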
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size):
|
||||
super().__init__()
|
||||
self.embedding_table = nn.Embedding(num_classes, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, labels,):
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 2*hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
def forward(self, x):
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
return x
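# Note: this is the gated (SwiGLU-style) MLP used by LLaMA-like models; the
# 2/3 rescaling of hidden_dim keeps the parameter count close to a plain
# 4x MLP despite the extra gating projection w3.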
|
||||
|
||||
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
|
||||
# assert H * H == end
|
||||
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
|
||||
x_pos = torch.linspace(0, scale, width)
|
||||
y_pos = torch.linspace(0, scale, height)
|
||||
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
|
||||
y_pos = y_pos.reshape(-1)
|
||||
x_pos = x_pos.reshape(-1)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
|
||||
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
|
||||
freqs_cis = freqs_cis.reshape(height*width, -1)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
freqs_cis = freqs_cis[None, :, None, :]
|
||||
# xq : B N H Hc
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class RAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: nn.Module = RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
|
||||
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
|
||||
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
|
||||
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
||||
|
||||
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FlattenDiTBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
|
||||
super().__init__()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c, pos, mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
class FlattenConDiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
num_groups=12,
|
||||
hidden_size=1152,
|
||||
num_blocks=18,
|
||||
num_cond_blocks=4,
|
||||
patch_size=2,
|
||||
num_classes=1000,
|
||||
learn_sigma=True,
|
||||
deep_supervision=0,
|
||||
weight_path=None,
|
||||
load_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.deep_supervision = deep_supervision
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.num_blocks = num_blocks
|
||||
self.num_cond_blocks = num_cond_blocks
|
||||
self.patch_size = patch_size
|
||||
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
|
||||
|
||||
self.weight_path = weight_path
|
||||
|
||||
self.load_ema = load_ema
|
||||
self.blocks = nn.ModuleList([
|
||||
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
|
||||
])
|
||||
self.initialize_weights()
|
||||
self.precompute_pos = dict()
|
||||
|
||||
def fetch_pos(self, height, width, device):
|
||||
if (height, width) in self.precompute_pos:
|
||||
return self.precompute_pos[(height, width)].to(device)
|
||||
else:
|
||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
|
||||
self.precompute_pos[(height, width)] = pos
|
||||
return pos
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.s_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.s_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize label embedding table:
|
||||
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# # Zero-out adaLN modulation layers in SiT blocks:
|
||||
# for block in self.blocks:
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
|
||||
def forward(self, x, t, y, s=None, mask=None):
|
||||
B, _, H, W = x.shape
|
||||
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
|
||||
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
|
||||
y = self.y_embedder(y).view(B, 1, self.hidden_size)
|
||||
c = nn.functional.silu(t + y)
|
||||
if s is None:
|
||||
s = self.s_embedder(x)
|
||||
for i in range(self.num_cond_blocks):
|
||||
s = self.blocks[i](s, c, pos, mask)
|
||||
s = nn.functional.silu(t + s)
|
||||
|
||||
x = self.x_embedder(x)
|
||||
for i in range(self.num_cond_blocks, self.num_blocks):
|
||||
x = self.blocks[i](x, s, pos, None)
|
||||
x = self.final_layer(x, s)
|
||||
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
||||
return x, s
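# Conditioning flow (illustrative summary): the first num_cond_blocks turn the
# noisy patch tokens into a condition s (driven by t and y); the remaining
# blocks denoise the tokens with s as their adaLN input, and s is returned so
# it can be reused on later calls, e.g.
#   out, s = model(x, t, y)            # build the condition and denoise
#   out2, _ = model(x, t2, y, s=s)     # reuse the cached condition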
|
||||
314
src/models/denoiser/flatten_condit_dit_norm_fixt.py
Normal file
@@ -0,0 +1,314 @@
|
||||
import functools
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torch.nn.init import zeros_
|
||||
from torch.nn.modules.module import T
|
||||
|
||||
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class Embed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[..., None].float() * freqs[None, ...]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size):
|
||||
super().__init__()
|
||||
self.embedding_table = nn.Embedding(num_classes, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, labels,):
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 2*hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
def forward(self, x):
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
return x
|
||||
|
||||
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
|
||||
# assert H * H == end
|
||||
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
|
||||
x_pos = torch.linspace(0, scale, width)
|
||||
y_pos = torch.linspace(0, scale, height)
|
||||
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
|
||||
y_pos = y_pos.reshape(-1)
|
||||
x_pos = x_pos.reshape(-1)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
|
||||
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
|
||||
freqs_cis = freqs_cis.reshape(height*width, -1)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
freqs_cis = freqs_cis[None, :, None, :]
|
||||
# xq : B N H Hc
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class RAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: nn.Module = RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
|
||||
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
|
||||
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
|
||||
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
||||
|
||||
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FlattenDiTBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
|
||||
super().__init__()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c, pos, mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
class FlattenConDiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
num_groups=12,
|
||||
hidden_size=1152,
|
||||
num_blocks=18,
|
||||
num_cond_blocks=4,
|
||||
patch_size=2,
|
||||
num_classes=1000,
|
||||
learn_sigma=True,
|
||||
deep_supervision=0,
|
||||
weight_path=None,
|
||||
load_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.deep_supervision = deep_supervision
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.num_blocks = num_blocks
|
||||
self.num_cond_blocks = num_cond_blocks
|
||||
self.patch_size = patch_size
|
||||
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
|
||||
|
||||
self.weight_path = weight_path
|
||||
|
||||
self.load_ema = load_ema
|
||||
self.blocks = nn.ModuleList([
|
||||
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
|
||||
])
|
||||
self.initialize_weights()
|
||||
self.precompute_pos = dict()
|
||||
|
||||
def fetch_pos(self, height, width, device):
|
||||
if (height, width) in self.precompute_pos:
|
||||
return self.precompute_pos[(height, width)].to(device)
|
||||
else:
|
||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
|
||||
self.precompute_pos[(height, width)] = pos
|
||||
return pos
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.s_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.s_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize label embedding table:
|
||||
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# # Zero-out adaLN modulation layers in SiT blocks:
|
||||
# for block in self.blocks:
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
|
||||
def forward(self, x, t, y, s=None, mask=None):
|
||||
B, _, H, W = x.shape
|
||||
pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
|
||||
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
|
||||
y = self.y_embedder(y).view(B, 1, self.hidden_size)
|
||||
c = nn.functional.silu(t + y)
|
||||
if s is None:
|
||||
s = self.s_embedder(x)
|
||||
for i in range(self.num_cond_blocks):
|
||||
s = self.blocks[i](s, c, pos, mask)
|
||||
s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6)
|
||||
s = nn.functional.silu(t + s)
|
||||
|
||||
x = self.x_embedder(x)
|
||||
for i in range(self.num_cond_blocks, self.num_blocks):
|
||||
x = self.blocks[i](x, s, pos, None)
|
||||
x = self.final_layer(x, s)
|
||||
x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
||||
return x, s
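# Variant note: this file matches flatten_condit_dit_fixt.py except that the
# condition s is L2-normalized (torch.nn.functional.normalize) before being
# combined with the timestep embedding.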
|
||||
429
src/models/denoiser/flatten_condit_encoder_decoder_fixt.py
Normal file
@@ -0,0 +1,429 @@
|
||||
import functools
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torch.nn.init import zeros_
|
||||
from torch.nn.modules.module import T
|
||||
|
||||
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
from src.utils.model_loader import ModelLoader
|
||||
from src.utils.no_grad import no_grad
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class Embed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[..., None].float() * freqs[None, ...]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size):
|
||||
super().__init__()
|
||||
self.embedding_table = nn.Embedding(num_classes, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, labels,):
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 2*hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
def forward(self, x):
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
return x
|
||||
|
||||
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
|
||||
# assert H * H == end
|
||||
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
|
||||
x_pos = torch.linspace(0, scale, width)
|
||||
y_pos = torch.linspace(0, scale, height)
|
||||
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
|
||||
y_pos = y_pos.reshape(-1)
|
||||
x_pos = x_pos.reshape(-1)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
|
||||
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
|
||||
freqs_cis = freqs_cis.reshape(height*width, -1)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
freqs_cis = freqs_cis[None, :, None, :]
|
||||
# xq : B N H Hc
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class RAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: nn.Module = RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
|
||||
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
|
||||
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
|
||||
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
||||
|
||||
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FlattenDiTBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
|
||||
super().__init__()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c, pos, mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
class FlattenDiTEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
num_groups=12,
|
||||
hidden_size=1152,
|
||||
num_blocks=18,
|
||||
patch_size=2,
|
||||
num_classes=1000,
|
||||
learn_sigma=True,
|
||||
weight_path=None,
|
||||
load_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.num_blocks = num_blocks
|
||||
self.patch_size = patch_size
|
||||
self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
|
||||
self.weight_path = weight_path
|
||||
|
||||
self.load_ema = load_ema
|
||||
self.blocks = nn.ModuleList([
|
||||
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
|
||||
])
|
||||
self.initialize_weights()
|
||||
self.precompute_pos = dict()
|
||||
|
||||
def fetch_pos(self, height, width, device):
|
||||
if (height, width) in self.precompute_pos:
|
||||
return self.precompute_pos[(height, width)]
|
||||
else:
|
||||
pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
|
||||
pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
|
||||
self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
|
||||
return (pos_rope, pos_ape)
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.s_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.s_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize label embedding table:
|
||||
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
def forward(self, x, t, y, mask=None):
|
||||
B, _, H, W = x.shape
|
||||
pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
|
||||
x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||
t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
|
||||
y = self.y_embedder(y).view(B, 1, self.hidden_size)
|
||||
c = nn.functional.silu(t + y)
|
||||
s = self.s_embedder(x)
|
||||
# s = s + pos_ape.to(s.dtype)
|
||||
for i in range(self.num_blocks):
|
||||
s = self.blocks[i](s, c, pos_rope, mask)
|
||||
return None, s
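# Note: pos_ape is precomputed but only consumed by the commented-out additive
# positional embedding above; the active path relies solely on the rotary
# table pos_rope inside each attention block.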
|
||||
|
||||
|
||||
class FlattenDiTDecoder(nn.Module):
    def __init__(
        self,
        in_channels=4,
        num_groups=12,
        hidden_size=1152,
        num_blocks=18,
        patch_size=2,
        num_classes=1000,
        learn_sigma=True,
        weight_path=None,
        load_ema=False,
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.num_blocks = num_blocks
        self.patch_size = patch_size
        self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)

        self.weight_path = weight_path

        self.load_ema = load_ema
        self.blocks = nn.ModuleList([
            FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
        ])
        self.initialize_weights()
        self.precompute_pos = dict()

    def fetch_pos(self, height, width, device):
        if (height, width) in self.precompute_pos:
            return self.precompute_pos[(height, width)]
        else:
            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
            self.precompute_pos[(height, width)] = pos
            return pos

    def initialize_weights(self):
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, y, s, mask=None):
        B, _, H, W = x.shape
        pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
        x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
        # s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
        s = torch.nn.functional.silu(t + s)
        x = self.x_embedder(x)
        for i in range(self.num_blocks):
            x = self.blocks[i](x, s, pos, None)
        x = self.final_layer(x, s)
        x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size,
                                     stride=self.patch_size)
        return x


class FlattenDiT(nn.Module):
    def __init__(
        self,
        encoder: FlattenDiTEncoder,
        decoder: FlattenDiTDecoder,
        joint_training=False,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

        ModelLoader().load(encoder)
        if not joint_training:
            self.encoder = self.encoder.to(torch.bfloat16)
            no_grad(self.encoder)

        self.joint_training = joint_training

    def forward(self, x, t, y, s=None):
        if s is None:
            _, s = self.encoder(x, t, y)
        x = self.decoder(x, t, y, s)
        return x, s

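# --- Editor's illustrative sketch (not part of the original commit) -----------
# How the decoder above consumes the encoder's conditioning tokens: s has shape
# (B, N, hidden_size) with N = (H/patch_size)*(W/patch_size). The tiny config and
# the random stand-in for s below are hypothetical, chosen only to show shapes;
# in the real model s comes from FlattenDiTEncoder via the FlattenDiT wrapper.
def _decoder_shape_demo():
    dec = FlattenDiTDecoder(in_channels=4, num_groups=4, hidden_size=256,
                            num_blocks=2, patch_size=2)
    x = torch.randn(2, 4, 32, 32)        # noisy latent (B, C, H, W)
    t = torch.rand(2)                    # per-sample timesteps
    s = torch.randn(2, 16 * 16, 256)     # stand-in for the encoder's token output
    out = dec(x, t, None, s)             # y is unused by the decoder
    assert out.shape == x.shape
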
class FlattenDiTScalingEncoder(nn.Module):
    def __init__(
        self,
        encoder: FlattenDiTEncoder,
        decoder: FlattenDiTDecoder,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        no_grad(self.decoder)

        if self.encoder.weight_path:
            weight = torch.load(self.encoder.weight_path, map_location=torch.device('cpu'))
            if self.encoder.load_ema:
                prefix = "ema_denoiser."
            else:
                prefix = "denoiser."
            for k, v in self.encoder.state_dict().items():
                try:
                    v.copy_(weight["state_dict"][prefix+k])
                except Exception:
                    print(f"Failed to copy {prefix+k} to denoiser weight")

        if self.decoder.weight_path:
            weight = torch.load(self.decoder.weight_path, map_location=torch.device('cpu'))
            if self.decoder.load_ema:
                prefix = "ema_denoiser."
            else:
                prefix = "denoiser."
            for k, v in self.decoder.state_dict().items():
                if "blocks." in k:
                    # Remap decoder block i onto checkpoint block i+8. Parse the full block
                    # index so multi-digit ids (e.g. "blocks.12") are handled correctly.
                    blockid = int(k.split("blocks.")[-1].split(".")[0])
                    k = k.replace(f"blocks.{blockid}.", f"blocks.{blockid+8}.", 1)
                try:
                    v.copy_(weight["state_dict"][prefix+k])
                except Exception:
                    print(f"Failed to copy {prefix+k} to denoiser weight")
        self.decoder = decoder.to(torch.bfloat16)

    def forward(self, x, t, y, s=None):
        if s is None:
            _, s = self.encoder(x, t, y)
        x = self.decoder(x, t, y, s)
        return x, s

334
src/models/denoiser/flatten_condit_mlp_fixt.py
Normal file
@@ -0,0 +1,334 @@
|
||||
import functools
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torch.nn.init import zeros_
|
||||
from torch.nn.modules.module import T
|
||||
|
||||
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class Embed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[..., None].float() * freqs[None, ...]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size):
|
||||
super().__init__()
|
||||
self.embedding_table = nn.Embedding(num_classes, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, labels,):
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 2*hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
def forward(self, x):
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
return x
|
||||
|
||||
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
|
||||
# assert H * H == end
|
||||
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
|
||||
x_pos = torch.linspace(0, scale, width)
|
||||
y_pos = torch.linspace(0, scale, height)
|
||||
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
|
||||
y_pos = y_pos.reshape(-1)
|
||||
x_pos = x_pos.reshape(-1)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
|
||||
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
|
||||
freqs_cis = freqs_cis.reshape(height*width, -1)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
freqs_cis = freqs_cis[None, :, None, :]
|
||||
# xq : B N H Hc
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
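# --- Editor's illustrative sketch (not part of the original commit) -----------
# Shape conventions for the 2-D rotary embedding above: precompute_freqs_cis_2d is
# called with the per-head dimension and returns a complex table of shape
# (height*width, head_dim//2); apply_rotary_emb then rotates q/k laid out as
# (B, N, heads, head_dim). The numbers below are a hypothetical smoke test only:
def _rope_shape_demo():
    head_dim, h, w = 64, 4, 4
    pos = precompute_freqs_cis_2d(head_dim, h, w)      # (16, 32), complex
    q = torch.randn(2, h * w, 8, head_dim)             # B, N, heads, head_dim
    k = torch.randn(2, h * w, 8, head_dim)
    q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis=pos)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
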
class RAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: nn.Module = RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
|
||||
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
|
||||
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
|
||||
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
||||
|
||||
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FlattenDiTBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
|
||||
super().__init__()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c, pos, mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
class FlattenMLPBlock(nn.Module):
    # Attention-free block: the attention slot is replaced by a second FeedForward.
    # pos and mask are accepted but unused so the interface matches FlattenDiTBlock.
    def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = FeedForward(hidden_size, mlp_hidden_dim)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, pos, mask=None):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class FlattenConDiT(nn.Module):
    def __init__(
        self,
        in_channels=4,
        num_groups=12,
        hidden_size=1152,
        num_blocks=18,
        num_cond_blocks=4,
        patch_size=2,
        num_classes=1000,
        learn_sigma=True,
        deep_supervision=0,
        weight_path=None,
        load_ema=False,
    ):
        super().__init__()
        self.deep_supervision = deep_supervision
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.num_blocks = num_blocks
        self.num_cond_blocks = num_cond_blocks
        self.patch_size = patch_size
        self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
        self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)

        self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)

        self.weight_path = weight_path

        self.load_ema = load_ema
        # The first num_cond_blocks blocks are attention blocks for the condition
        # stream; the remaining blocks are attention-free MLP blocks.
        self.blocks = nn.ModuleList([])
        for i in range(self.num_cond_blocks):
            self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
        for i in range(self.num_blocks-self.num_cond_blocks):
            self.blocks.append(FlattenMLPBlock(self.hidden_size, self.num_groups))

        self.initialize_weights()
        self.precompute_pos = dict()

    def fetch_pos(self, height, width, device):
        if (height, width) in self.precompute_pos:
            return self.precompute_pos[(height, width)].to(device)
        else:
            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
            self.precompute_pos[(height, width)] = pos
            return pos

    def initialize_weights(self):
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.s_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.s_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # # Zero-out adaLN modulation layers in SiT blocks:
        # for block in self.blocks:
        #     nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
        #     nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, y, s=None):
        B, _, H, W = x.shape
        pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
        x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
        y = self.y_embedder(y).view(B, 1, self.hidden_size)
        c = nn.functional.silu(t + y)
        if s is None:
            s = self.s_embedder(x)
            for i in range(self.num_cond_blocks):
                s = self.blocks[i](s, c, pos, None)
            s = nn.functional.silu(t + s)

        x = self.x_embedder(x)
        for i in range(self.num_cond_blocks, self.num_blocks):
            x = self.blocks[i](x, s, pos, None)
        x = self.final_layer(x, s)
        x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
        return x, s

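# --- Editor's illustrative sketch (not part of the original commit) -----------
# In this variant the first num_cond_blocks attention blocks build the conditioning
# tokens s from the patchified input, and the remaining attention-free FlattenMLPBlocks
# denoise the main stream modulated token-wise by s. The small config below is
# hypothetical, sized down only so the example is cheap to run:
def _flatten_condit_mlp_demo():
    model = FlattenConDiT(in_channels=4, num_groups=4, hidden_size=256,
                          num_blocks=4, num_cond_blocks=2, patch_size=2, num_classes=10)
    x = torch.randn(2, 4, 32, 32)
    t = torch.rand(2)
    y = torch.randint(0, 10, (2,))
    out, s = model(x, t, y)           # first call computes s from x
    out2, _ = model(x, t, y, s=s)     # a cached s skips the conditioning blocks
    assert out.shape == x.shape and out2.shape == x.shape
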
321
src/models/denoiser/flatten_condit_sdown2_dit_fixt.py
Normal file
@@ -0,0 +1,321 @@
|
||||
import functools
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torch.nn.init import zeros_
|
||||
from torch.nn.modules.module import T
|
||||
|
||||
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class Embed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[..., None].float() * freqs[None, ...]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size):
|
||||
super().__init__()
|
||||
self.embedding_table = nn.Embedding(num_classes, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, labels,):
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 2*hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
def forward(self, x):
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
return x
|
||||
|
||||
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
|
||||
# assert H * H == end
|
||||
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
|
||||
x_pos = torch.linspace(0, scale, width)
|
||||
y_pos = torch.linspace(0, scale, height)
|
||||
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
|
||||
y_pos = y_pos.reshape(-1)
|
||||
x_pos = x_pos.reshape(-1)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
|
||||
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
|
||||
freqs_cis = freqs_cis.reshape(height*width, -1)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
freqs_cis = freqs_cis[None, :, None, :]
|
||||
# xq : B N H Hc
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class RAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: nn.Module = RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
|
||||
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
|
||||
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
|
||||
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
||||
|
||||
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FlattenDiTBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
|
||||
super().__init__()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c, pos, mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
class FlattenConDiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
num_groups=12,
|
||||
hidden_size=1152,
|
||||
num_blocks=18,
|
||||
num_cond_blocks=4,
|
||||
patch_size=2,
|
||||
num_classes=1000,
|
||||
learn_sigma=True,
|
||||
deep_supervision=0,
|
||||
weight_path=None,
|
||||
load_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.deep_supervision = deep_supervision
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.num_blocks = num_blocks
|
||||
self.num_cond_blocks = num_cond_blocks
|
||||
self.patch_size = patch_size
|
||||
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.s_embedder = Embed(in_channels*patch_size**4, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
|
||||
|
||||
self.weight_path = weight_path
|
||||
|
||||
self.load_ema = load_ema
|
||||
self.blocks = nn.ModuleList([
|
||||
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
|
||||
])
|
||||
self.initialize_weights()
|
||||
self.precompute_pos = dict()
|
||||
|
||||
def fetch_pos(self, height, width, device):
|
||||
if (height, width) in self.precompute_pos:
|
||||
return self.precompute_pos[(height, width)].to(device)
|
||||
else:
|
||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
|
||||
self.precompute_pos[(height, width)] = pos
|
||||
return pos
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.s_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.s_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize label embedding table:
|
||||
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# # Zero-out adaLN modulation layers in SiT blocks:
|
||||
# for block in self.blocks:
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
|
||||
    def forward(self, x, t, y, s=None, mask=None):
        B, _, H, W = x.shape
        pos_x = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)

        t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
        y = self.y_embedder(y).view(B, 1, self.hidden_size)
        c = nn.functional.silu(t + y)
        if s is None:
            # Condition stream on 2x larger patches (a quarter of the tokens), then
            # bilinearly upsampled back to the main-stream token grid.
            pos_s = self.fetch_pos(H//self.patch_size//2, W//self.patch_size//2, x.device)
            s = torch.nn.functional.unfold(x, kernel_size=self.patch_size*2, stride=self.patch_size*2).transpose(1, 2)
            s = self.s_embedder(s)
            for i in range(self.num_cond_blocks):
                s = self.blocks[i](s, c, pos_s, mask)
            s = s.view(B, H//self.patch_size//2, W//self.patch_size//2, self.hidden_size)
            s = torch.permute(s, (0, 3, 1, 2))
            s = torch.nn.functional.interpolate(s, scale_factor=2, mode='bilinear', align_corners=False)
            s = torch.permute(s, (0, 2, 3, 1))
            s = s.view(B, -1, self.hidden_size)
            s = nn.functional.silu(t + s)

        x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        x = self.x_embedder(x)
        for i in range(self.num_cond_blocks, self.num_blocks):
            x = self.blocks[i](x, s, pos_x, None)
        x = self.final_layer(x, s)
        x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
        return x, s

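# --- Editor's note (not part of the original commit) --------------------------
# "sdown2": the conditioning branch tokenizes with patches twice the size of the
# main branch, so for a 32x32 latent with patch_size=2 the s-branch sees
# (32/4)^2 = 64 tokens while the x-branch sees (32/2)^2 = 256 tokens; the 8x8 map
# of s is then bilinearly interpolated back to 16x16 before modulating x.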
306
src/models/denoiser/flatten_dit_fixt.py
Normal file
@@ -0,0 +1,306 @@
|
||||
import functools
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torch.nn.init import zeros_
|
||||
from torch.nn.modules.module import T
|
||||
|
||||
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class Embed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[..., None].float() * freqs[None, ...]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size):
|
||||
super().__init__()
|
||||
self.embedding_table = nn.Embedding(num_classes, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, labels,):
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 2*hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return (self.weight * hidden_states).to(input_dtype)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
def forward(self, x):
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
return x
|
||||
|
||||
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
|
||||
# assert H * H == end
|
||||
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
|
||||
x_pos = torch.linspace(0, scale, width)
|
||||
y_pos = torch.linspace(0, scale, height)
|
||||
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
|
||||
y_pos = y_pos.reshape(-1)
|
||||
x_pos = x_pos.reshape(-1)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
|
||||
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
|
||||
freqs_cis = freqs_cis.reshape(height*width, -1)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
freqs_cis = freqs_cis[None, :, None, :]
|
||||
# xq : B N H Hc
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class RAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: nn.Module = RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
|
||||
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
|
||||
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
|
||||
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
||||
|
||||
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FlattenDiTBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
|
||||
super().__init__()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c, pos, mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
class FlattenDiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
num_groups=12,
|
||||
hidden_size=1152,
|
||||
num_blocks=18,
|
||||
patch_size=2,
|
||||
num_classes=1000,
|
||||
learn_sigma=True,
|
||||
deep_supervision=0,
|
||||
weight_path=None,
|
||||
load_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.deep_supervision = deep_supervision
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.num_blocks = num_blocks
|
||||
self.patch_size = patch_size
|
||||
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
|
||||
|
||||
self.weight_path = weight_path
|
||||
|
||||
self.load_ema = load_ema
|
||||
self.blocks = nn.ModuleList([
|
||||
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
|
||||
])
|
||||
self.initialize_weights()
|
||||
self.precompute_pos = dict()
|
||||
|
||||
def fetch_pos(self, height, width, device, dtype):
|
||||
if (height, width) in self.precompute_pos:
|
||||
return self.precompute_pos[(height, width)].to(device, dtype)
|
||||
else:
|
||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
|
||||
self.precompute_pos[(height, width)] = pos
|
||||
return pos
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize label embedding table:
|
||||
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# # Zero-out adaLN modulation layers in SiT blocks:
|
||||
# for block in self.blocks:
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
    def forward(self, x, t, y, masks=None):
        # Accept no masks, a stacked tensor of per-block attention masks, or a
        # (possibly shorter) sequence that is padded with None for the rest.
        if masks is None:
            masks = [None, ]*self.num_blocks
        if isinstance(masks, torch.Tensor):
            masks = masks.unbind(0)
        if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks:
            masks = list(masks) + [None]*(self.num_blocks-len(masks))

        B, _, H, W = x.shape
        x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        x = self.x_embedder(x)
        pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
        B, L, C = x.shape
        t = self.t_embedder(t.view(-1)).view(B, -1, C)
        y = self.y_embedder(y).view(B, 1, C)
        condition = nn.functional.silu(t + y)
        for i, block in enumerate(self.blocks):
            x = block(x, condition, pos, masks[i])
        x = self.final_layer(x, condition)
        x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
        return x

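# --- Editor's illustrative sketch (not part of the original commit) -----------
# forward() takes per-block attention masks; a shorter list is padded with None for
# the remaining blocks. A hypothetical example with a boolean mask on block 0 only,
# using a deliberately small config:
def _per_block_mask_demo():
    model = FlattenDiT(in_channels=4, num_groups=4, hidden_size=256,
                       num_blocks=2, patch_size=2, num_classes=10)
    x = torch.randn(2, 4, 32, 32)
    t = torch.rand(2)
    y = torch.randint(0, 10, (2,))
    n_tokens = (32 // 2) ** 2
    first_block_mask = torch.ones(2, 1, n_tokens, n_tokens, dtype=torch.bool)
    out = model(x, t, y, masks=[first_block_mask])   # remaining blocks get None
    assert out.shape == x.shape
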
311
src/models/denoiser/flatten_dit_fixt_xvout.py
Normal file
@@ -0,0 +1,311 @@
|
||||
import functools
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torch.nn.init import zeros_
|
||||
from torch.nn.modules.module import T
|
||||
|
||||
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class Embed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[..., None].float() * freqs[None, ...]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size):
|
||||
super().__init__()
|
||||
self.embedding_table = nn.Embedding(num_classes, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, labels,):
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 2*hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return (self.weight * hidden_states).to(input_dtype)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
def forward(self, x):
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
return x
|
||||
|
||||
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
|
||||
# assert H * H == end
|
||||
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
|
||||
x_pos = torch.linspace(0, scale, width)
|
||||
y_pos = torch.linspace(0, scale, height)
|
||||
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
|
||||
y_pos = y_pos.reshape(-1)
|
||||
x_pos = x_pos.reshape(-1)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
|
||||
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
|
||||
freqs_cis = freqs_cis.reshape(height*width, -1)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
freqs_cis = freqs_cis[None, :, None, :]
|
||||
# xq : B N H Hc
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class RAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: nn.Module = RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
|
||||
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
|
||||
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
|
||||
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
||||
|
||||
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FlattenDiTBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
|
||||
super().__init__()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c, pos, mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
class FlattenDiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
num_groups=12,
|
||||
hidden_size=1152,
|
||||
num_blocks=18,
|
||||
patch_size=2,
|
||||
num_classes=1000,
|
||||
learn_sigma=True,
|
||||
deep_supervision=0,
|
||||
weight_path=None,
|
||||
load_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.deep_supervision = deep_supervision
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.num_blocks = num_blocks
|
||||
self.patch_size = patch_size
|
||||
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, 2*in_channels*patch_size**2)
|
||||
|
||||
self.weight_path = weight_path
|
||||
|
||||
self.load_ema = load_ema
|
||||
self.blocks = nn.ModuleList([
|
||||
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
|
||||
])
|
||||
self.initialize_weights()
|
||||
self.precompute_pos = dict()
|
||||
|
||||
def fetch_pos(self, height, width, device, dtype):
|
||||
if (height, width) in self.precompute_pos:
|
||||
return self.precompute_pos[(height, width)].to(device, dtype)
|
||||
else:
|
||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
|
||||
self.precompute_pos[(height, width)] = pos
|
||||
return pos
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize label embedding table:
|
||||
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# # Zero-out adaLN modulation layers in SiT blocks:
|
||||
# for block in self.blocks:
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
    def forward(self, x, t, y, masks=None):
        # Accept no masks, a stacked tensor of per-block attention masks, or a
        # (possibly shorter) sequence that is padded with None for the rest.
        if masks is None:
            masks = [None, ]*self.num_blocks
        if isinstance(masks, torch.Tensor):
            masks = masks.unbind(0)
        if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks:
            masks = list(masks) + [None]*(self.num_blocks-len(masks))

        B, _, H, W = x.shape
        x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        x = self.x_embedder(x)
        pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
        B, L, C = x.shape
        t = self.t_embedder(t.view(-1)).view(B, -1, C)
        y = self.y_embedder(y).view(B, 1, C)
        condition = nn.functional.silu(t + y)
        for i, block in enumerate(self.blocks):
            x = block(x, condition, pos, masks[i])
        x = self.final_layer(x, condition)
        # The widened final layer predicts a clean-sample estimate (x0) and a velocity (v).
        x0, v = x.chunk(2, dim=-1)
        x0 = torch.nn.functional.fold(x0.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
        v = torch.nn.functional.fold(v.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
        if self.training:
            return v, x0
        else:
            return v

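# --- Editor's note (not part of the original commit) --------------------------
# "xvout": FinalLayer is built with 2*in_channels*patch_size**2 outputs, so each
# patch carries both an x0 and a v prediction. forward() returns (v, x0) in
# training mode and only v at inference.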
308
src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py
Normal file
@@ -0,0 +1,308 @@
|
||||
import functools
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torch.nn.init import zeros_
|
||||
from torch.nn.modules.module import T
|
||||
|
||||
# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class Embed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[..., None].float() * freqs[None, ...]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
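# Illustrative sketch (not part of the original file): a quick shape check for the
# sinusoidal table above. Note max_period=10 rather than DiT's usual 10000, which
# matches continuous flow times in [0, 1]. `_demo_timestep_embedding` is a
# hypothetical helper added only for illustration.
def _demo_timestep_embedding():
    t = torch.rand(8)                                        # one flow time per sample
    emb = TimestepEmbedder.timestep_embedding(t, dim=256)    # [cos | sin] halves
    assert emb.shape == (8, 256)
    return emb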
|
||||
class LabelEmbedder(nn.Module):
|
||||
def __init__(self, num_classes, hidden_size):
|
||||
super().__init__()
|
||||
self.embedding_table = nn.Embedding(num_classes, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, labels,):
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 2*hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
def forward(self, x):
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
return x
|
||||
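# Illustrative sketch (not part of the original file): the int(2 * hidden_dim / 3)
# rescaling above makes this three-matrix SwiGLU block match the parameter count of
# a plain two-matrix MLP with a 4x hidden width.
def _demo_feedforward_params(dim=1152, mlp_ratio=4):
    ff = FeedForward(dim, int(dim * mlp_ratio))
    swiglu_params = sum(p.numel() for p in ff.parameters())
    plain_mlp_params = 2 * dim * (dim * mlp_ratio)            # two matrices of a vanilla MLP
    assert swiglu_params == plain_mlp_params
    return swiglu_params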
|
||||
def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
|
||||
# assert H * H == end
|
||||
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
|
||||
x_pos = torch.linspace(0, scale, width)
|
||||
y_pos = torch.linspace(0, scale, height)
|
||||
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
|
||||
y_pos = y_pos.reshape(-1)
|
||||
x_pos = x_pos.reshape(-1)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
|
||||
freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
|
||||
freqs_cis = freqs_cis.reshape(height*width, -1)
|
||||
return freqs_cis
|
||||
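# Illustrative sketch (not part of the original file): the 2D rotary table above
# interleaves one x-axis and one y-axis rotation per channel pair, so a head dim of
# `dim` yields dim // 2 complex entries per token (dim // 4 per axis).
def _demo_freqs_cis_shape(head_dim=72, height=16, width=16):
    freqs = precompute_freqs_cis_2d(head_dim, height, width)
    assert freqs.shape == (height * width, head_dim // 2)
    assert freqs.dtype == torch.complex64
    return freqs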
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
freqs_cis = freqs_cis[None, :, None, :]
|
||||
# xq : B N H Hc
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
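# Illustrative sketch (not part of the original file): rotary embedding only rotates
# q/k channel pairs by unit-magnitude phases, so it leaves their norms unchanged.
def _demo_rotary_preserves_norm(B=2, H=16, W=16, heads=16, head_dim=72):
    q = torch.randn(B, H * W, heads, head_dim)
    k = torch.randn(B, H * W, heads, head_dim)
    freqs = precompute_freqs_cis_2d(head_dim, H, W)
    q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis=freqs)
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-3)
    return q_rot, k_rot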
|
||||
|
||||
class RAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = True,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
norm_layer: nn.Module = RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis=pos)
|
||||
q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
|
||||
k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
|
||||
v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
|
||||
|
||||
x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class FlattenDiTBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
|
||||
super().__init__()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c, pos, mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
|
||||
x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
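# Illustrative sketch (not part of the original file): a single block call. Unlike the
# DiT reference, adaLN_modulation here has no leading SiLU; the condition is expected
# to be pre-activated (see FlattenConDiT.forward), and it broadcasts over tokens, so
# either a (B, 1, C) or a full per-token (B, L, C) condition works.
def _demo_block_call(B=2, H=16, W=16, hidden=1152, groups=16):
    block = FlattenDiTBlock(hidden, groups)
    x = torch.randn(B, H * W, hidden)
    c = torch.randn(B, 1, hidden)        # broadcast condition; (B, H*W, hidden) also works
    pos = precompute_freqs_cis_2d(hidden // groups, H, W)
    out = block(x, c, pos, mask=None)
    assert out.shape == x.shape
    return out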
|
||||
|
||||
class FlattenConDiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4,
|
||||
num_groups=12,
|
||||
hidden_size=1152,
|
||||
num_blocks=18,
|
||||
num_cond_blocks=4,
|
||||
patch_size=2,
|
||||
num_classes=1000,
|
||||
learn_sigma=True,
|
||||
deep_supervision=0,
|
||||
weight_path=None,
|
||||
load_ema=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.deep_supervision = deep_supervision
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.num_blocks = num_blocks
|
||||
self.num_cond_blocks = num_cond_blocks
|
||||
self.patch_size = patch_size
|
||||
self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
|
||||
|
||||
self.weight_path = weight_path
|
||||
|
||||
self.load_ema = load_ema
|
||||
self.blocks = nn.ModuleList([
|
||||
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
|
||||
])
|
||||
self.initialize_weights()
|
||||
self.precompute_pos = dict()
|
||||
|
||||
def fetch_pos(self, height, width, device):
|
||||
if (height, width) in self.precompute_pos:
|
||||
return self.precompute_pos[(height, width)].to(device)
|
||||
else:
|
||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
|
||||
self.precompute_pos[(height, width)] = pos
|
||||
return pos
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
|
||||
# Initialize label embedding table:
|
||||
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# # Zero-out adaLN modulation layers in SiT blocks:
|
||||
# for block in self.blocks:
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
|
||||
    def forward(self, x, t, y, s=None, mask=None):
        B, _, H, W = x.shape
        pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
        # Patchify: (B, C, H, W) -> (B, L, C*p*p) token sequence.
        x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
        y = self.y_embedder(y).view(B, 1, self.hidden_size)
        c = nn.functional.silu(t + y)
        if s is None:
            # Build the shared per-token condition with the first num_cond_blocks blocks.
            s = self.x_embedder(x)
            for i in range(self.num_cond_blocks):
                s = self.blocks[i](s, c, pos, mask)
            s = nn.functional.silu(t + s)

        # Remaining blocks denoise the tokens, conditioned on s.
        x = self.x_embedder(x)
        for i in range(self.num_cond_blocks, self.num_blocks):
            x = self.blocks[i](x, s, pos, None)
        x = self.final_layer(x, s)
        x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
        return x, s
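A minimal usage sketch for FlattenConDiT above (not part of the commit; the block counts and sizes are assumptions roughly matching the xl config): the first num_cond_blocks blocks turn the patch tokens into a per-token condition s, the remaining blocks denoise under that condition, and the forward returns (prediction, s) so s can be reused via the s= argument on later calls.

import torch
from src.models.denoiser.flatten_sharepatch_condit_dit_fixt import FlattenConDiT

model = FlattenConDiT(in_channels=4, num_groups=16, hidden_size=1152,
                      num_blocks=28, num_cond_blocks=4, patch_size=2)
x = torch.randn(2, 4, 32, 32)           # SD-VAE latents of 256x256 images
t = torch.rand(2)                       # one flow time per sample
y = torch.randint(0, 1000, (2,))        # ImageNet class labels
out, s = model(x, t, y)                 # pass s back in as `s=s` to reuse the condition
assert out.shape == x.shape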
160
src/models/denoiser/flowdcn.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torch.nn.init import zeros_
|
||||
from src.models.denoiser.base_model import BaseModel
|
||||
from src.ops.triton_kernels.function import DCNFunction
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale[:, None, None]) + shift[:, None, None]
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer = None,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
def forward(self, x):
|
||||
b, h, w, c = x.shape
|
||||
x = x.view(b, h*w, c)
|
||||
x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
x = x.view(b, h, w, c)
|
||||
return x
|
||||
|
||||
|
||||
class MultiScaleDCN(nn.Module):
|
||||
def __init__(self, in_channels, groups, channels, kernels, deformable_biass=True):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.groups = groups
|
||||
self.channels = channels
|
||||
self.kernels = kernels
|
||||
self.v = nn.Linear(in_channels, groups * channels, bias=True)
|
||||
self.qk_deformables = nn.Linear(in_channels, groups * kernels * 2, bias=True)
|
||||
self.qk_scales = nn.Linear(in_channels, groups * kernels, bias=False)
|
||||
self.qk_weights = nn.Linear(in_channels, groups*kernels, bias=True)
|
||||
self.out = nn.Linear(groups * channels, in_channels)
|
||||
self.deformables_prior = nn.Parameter(torch.randn((1, 1, 1, 1, kernels, 2)), requires_grad=False)
|
||||
self.deformables_scale = nn.Parameter(torch.ones((1, 1, 1, groups, 1, 1)), requires_grad=True)
|
||||
self.max_scale = 6
|
||||
self._init_weights()
|
||||
def _init_weights(self):
|
||||
zeros_(self.qk_deformables.weight.data)
|
||||
zeros_(self.qk_scales.weight.data)
|
||||
zeros_(self.qk_deformables.bias.data)
|
||||
zeros_(self.qk_weights.weight.data)
|
||||
zeros_(self.v.bias.data)
|
||||
zeros_(self.out.bias.data)
|
||||
num_prior = int(self.kernels ** 0.5)
|
||||
dx = torch.linspace(-1, 1, num_prior, device="cuda")
|
||||
dy = torch.linspace(-1, 1, num_prior, device="cuda")
|
||||
dxy = torch.meshgrid([dx, dy], indexing="xy")
|
||||
dxy = torch.stack(dxy, dim=-1)
|
||||
dxy = dxy.view(-1, 2)
|
||||
self.deformables_prior.data[..., :num_prior*num_prior, :] = dxy
|
||||
for i in range(self.groups):
|
||||
scale = (i+1)/self.groups - 0.0001
|
||||
inv_scale = math.log((scale)/(1-scale))
|
||||
self.deformables_scale.data[..., i, :, :] = inv_scale
|
||||
def forward(self, x):
|
||||
B, H, W, _ = x.shape
|
||||
v = self.v(x).view(B, H, W, self.groups, self.channels)
|
||||
deformables = self.qk_deformables(x).view(B, H, W, self.groups, self.kernels, 2)
|
||||
scale = self.qk_scales(x).view(B, H, W, self.groups, self.kernels, 1) + self.deformables_scale
|
||||
deformables = (deformables + self.deformables_prior ) * scale.sigmoid()*self.max_scale
|
||||
weights = self.qk_weights(x).view(B, H, W, self.groups, self.kernels)
|
||||
out = DCNFunction.apply(v, deformables, weights)
|
||||
out = out.view(B, H, W, -1)
|
||||
out = self.out(out)
|
||||
return out
|
||||
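# Hypothetical pure-PyTorch reference (not part of the original file) for what the
# Triton DCNFunction above is assumed to compute: for each output location, gather v
# bilinearly at (w, h) + offset for every group/kernel point and mix the samples with
# the predicted weights, dropping out-of-bounds corners. Written only as a readable
# sketch of the operator, not as the optimized implementation.
def _dcn_reference(v, deformables, weights):
    # v: (B, H, W, G, C)   deformables: (B, H, W, G, K, 2) as (x, y) pixel offsets
    # weights: (B, H, W, G, K)   returns: (B, H, W, G, C)
    B, H, W, G, C = v.shape
    ys, xs = torch.meshgrid(torch.arange(H, device=v.device),
                            torch.arange(W, device=v.device), indexing="ij")
    x = deformables[..., 0] + xs.view(1, H, W, 1, 1)
    y = deformables[..., 1] + ys.view(1, H, W, 1, 1)
    x0, y0 = x.floor().long(), y.floor().long()
    b_idx = torch.arange(B, device=v.device).view(B, 1, 1, 1, 1)
    g_idx = torch.arange(G, device=v.device).view(1, 1, 1, G, 1)
    out = torch.zeros(B, H, W, G, C, dtype=v.dtype, device=v.device)
    for dx, dy in ((0, 0), (1, 0), (0, 1), (1, 1)):
        xi, yi = x0 + dx, y0 + dy
        bilinear = (1 - (x - xi).abs()) * (1 - (y - yi).abs())
        valid = (xi >= 0) & (xi < W) & (yi >= 0) & (yi < H)
        gathered = v[b_idx, yi.clamp(0, H - 1), xi.clamp(0, W - 1), g_idx]  # (B,H,W,G,K,C)
        out += ((weights * bilinear * valid).unsqueeze(-1) * gathered).sum(dim=-2)
    return out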
|
||||
class FlowDCNBlock(nn.Module):
|
||||
def __init__(self, hidden_size, groups, kernels=9, mlp_ratio=4.0, deformable_biass=True):
|
||||
super().__init__()
|
||||
self.norm1 = RMSNorm(hidden_size, eps=1e-6)
|
||||
self.attn = MultiScaleDCN(hidden_size, groups=groups, channels=hidden_size//groups, kernels=kernels, deformable_biass=deformable_biass)
|
||||
self.norm2 = RMSNorm(hidden_size, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
|
||||
x = x + gate_msa[:, None, None] * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
x = x + gate_mlp[:, None, None] * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
|
||||
|
||||
class FlowDCN(BaseModel):
|
||||
def __init__(self, deformable_biass=True, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.blocks = nn.ModuleList([
|
||||
FlowDCNBlock(self.hidden_size, self.num_groups, kernels=9, deformable_biass=deformable_biass) for _ in range(self.num_blocks)
|
||||
])
|
||||
self.x_embedder = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, bias=True)
|
||||
self.initialize_weights()
|
||||
|
||||
def forward(self, x, t, y):
|
||||
batch_size, _, height, width = x.shape
|
||||
x = self.x_embedder(x) # (N, D, h, w)
|
||||
x = x.permute(0, 2, 3, 1).reshape(batch_size, height*width//self.patch_size**2, -1)
|
||||
t = self.t_embedder(t) # (N, D)
|
||||
y = self.y_embedder(y, self.training) # (N, D)
|
||||
c = t + y # (N, D)
|
||||
B, L, C = x.shape
|
||||
x = x.view(B, height//self.patch_size, width//self.patch_size, C)
|
||||
for block in self.blocks:
|
||||
x = block(x, c) # (N, T, D)
|
||||
x = x.view(B, L, C)
|
||||
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
|
||||
x = torch.nn.functional.fold(x.transpose(1, 2), (height, width), kernel_size=self.patch_size, stride=self.patch_size)
|
||||
if self.learn_sigma:
|
||||
x, _ = torch.split(x, self.out_channels // 2, dim=1)
|
||||
return x
|
||||
132
src/models/encoder.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import torch
|
||||
import copy
|
||||
import os
|
||||
import timm
|
||||
import transformers
|
||||
import torch.nn as nn
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from torchvision.transforms import Normalize
|
||||
|
||||
class RandViT(nn.Module):
|
||||
def __init__(self, model_id, weight_path:str=None):
|
||||
super(RandViT, self).__init__()
|
||||
self.encoder = timm.create_model(
|
||||
model_id,
|
||||
num_classes=0,
|
||||
)
|
||||
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
|
||||
self.encoder.head = torch.nn.Identity()
|
||||
self.patch_size = self.encoder.patch_embed.patch_size
|
||||
self.shifts = nn.Parameter(torch.tensor([0.0
|
||||
]), requires_grad=False)
|
||||
self.scales = nn.Parameter(torch.tensor([1.0
|
||||
]), requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
||||
x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
|
||||
b, c, h, w = x.shape
|
||||
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
|
||||
feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
|
||||
feature = feature.transpose(1, 2)
|
||||
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
|
||||
feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
|
||||
return feature
|
||||
|
||||
class MAE(nn.Module):
|
||||
def __init__(self, model_id, weight_path:str):
|
||||
super(MAE, self).__init__()
|
||||
if os.path.isdir(weight_path):
|
||||
weight_path = os.path.join(weight_path, "pytorch_model.bin")
|
||||
self.encoder = timm.create_model(
|
||||
model_id,
|
||||
checkpoint_path=weight_path,
|
||||
num_classes=0,
|
||||
)
|
||||
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
|
||||
self.encoder.head = torch.nn.Identity()
|
||||
self.patch_size = self.encoder.patch_embed.patch_size
|
||||
self.shifts = nn.Parameter(torch.tensor([0.0
|
||||
]), requires_grad=False)
|
||||
self.scales = nn.Parameter(torch.tensor([1.0
|
||||
]), requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
||||
x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
|
||||
b, c, h, w = x.shape
|
||||
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
|
||||
feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
|
||||
feature = feature.transpose(1, 2)
|
||||
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
|
||||
feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
|
||||
return feature
|
||||
|
||||
class DINO(nn.Module):
|
||||
def __init__(self, model_id, weight_path:str):
|
||||
super(DINO, self).__init__()
|
||||
if os.path.isdir(weight_path):
|
||||
weight_path = os.path.join(weight_path, "pytorch_model.bin")
|
||||
self.encoder = timm.create_model(
|
||||
model_id,
|
||||
checkpoint_path=weight_path,
|
||||
num_classes=0,
|
||||
)
|
||||
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
|
||||
self.encoder.head = torch.nn.Identity()
|
||||
self.patch_size = self.encoder.patch_embed.patch_size
|
||||
self.shifts = nn.Parameter(torch.tensor([ 0.0,
|
||||
]), requires_grad=False)
|
||||
self.scales = nn.Parameter(torch.tensor([ 1.0,
|
||||
]), requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
||||
x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
|
||||
b, c, h, w = x.shape
|
||||
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
|
||||
feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
|
||||
feature = feature.transpose(1, 2)
|
||||
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
|
||||
feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
|
||||
return feature
|
||||
|
||||
class CLIP(nn.Module):
|
||||
def __init__(self, model_id, weight_path:str):
|
||||
super(CLIP, self).__init__()
|
||||
self.encoder = transformers.CLIPVisionModel.from_pretrained(weight_path)
|
||||
self.patch_size = self.encoder.vision_model.embeddings.patch_embedding.kernel_size
|
||||
self.shifts = nn.Parameter(torch.tensor([0.0,
|
||||
]), requires_grad=False)
|
||||
self.scales = nn.Parameter(torch.tensor([1.0,
|
||||
]), requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)(x)
|
||||
x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
|
||||
b, c, h, w = x.shape
|
||||
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
|
||||
feature = self.encoder(x)['last_hidden_state'][:, 1:]
|
||||
feature = feature.transpose(1, 2)
|
||||
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
|
||||
feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
|
||||
return feature
|
||||
|
||||
|
||||
|
||||
class DINOv2(nn.Module):
|
||||
def __init__(self, model_id, weight_path:str):
|
||||
super(DINOv2, self).__init__()
|
||||
self.encoder = transformers.Dinov2Model.from_pretrained(weight_path)
|
||||
self.patch_size = self.encoder.embeddings.patch_embeddings.projection.kernel_size
|
||||
|
||||
def forward(self, x):
|
||||
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
||||
x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
|
||||
b, c, h, w = x.shape
|
||||
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
|
||||
feature = self.encoder.forward(x)['last_hidden_state'][:, 1:]
|
||||
feature = feature.transpose(1, 2)
|
||||
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
|
||||
return feature
|
||||
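A minimal usage sketch (not part of the commit): each wrapper above normalizes and resizes the input to 224x224, runs the frozen backbone, and returns a (B, D, h, w) spatial feature map. The model id and weight path below are hypothetical placeholders (the DINOv2 wrapper ignores model_id and loads from weight_path).

import torch
from src.models.encoder import DINOv2

encoder = DINOv2("dinov2-base", weight_path="/path/to/dinov2-base")  # hypothetical path
images = torch.rand(2, 3, 256, 256)       # pixel values in [0, 1]
features = encoder(images)                # (2, D, 224 // patch, 224 // patch)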
81
src/models/vae.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import torch
|
||||
import subprocess
|
||||
import lightning.pytorch as pl
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
def class_fn_from_str(class_str):
|
||||
class_module, from_class = class_str.rsplit(".", 1)
|
||||
class_module = __import__(class_module, fromlist=[from_class])
|
||||
return getattr(class_module, from_class)
|
||||
|
||||
|
||||
class BaseVAE(torch.nn.Module):
|
||||
def __init__(self, scale=1.0, shift=0.0):
|
||||
super().__init__()
|
||||
self.model = torch.nn.Identity()
|
||||
self.scale = scale
|
||||
self.shift = shift
|
||||
|
||||
def encode(self, x):
|
||||
return x/self.scale+self.shift
|
||||
|
||||
def decode(self, x):
|
||||
return (x-self.shift)*self.scale
|
||||
|
||||
|
||||
# very bad bugs with nearest sampling
|
||||
class DownSampleVAE(BaseVAE):
|
||||
def __init__(self, down_ratio, scale=1.0, shift=0.0):
|
||||
super().__init__()
|
||||
self.model = torch.nn.Identity()
|
||||
self.scale = scale
|
||||
self.shift = shift
|
||||
self.down_ratio = down_ratio
|
||||
|
||||
def encode(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=1/self.down_ratio, mode='bicubic', align_corners=False)
|
||||
return x/self.scale+self.shift
|
||||
|
||||
def decode(self, x):
|
||||
x = (x-self.shift)*self.scale
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=self.down_ratio, mode='bicubic', align_corners=False)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class LatentVAE(BaseVAE):
|
||||
def __init__(self, precompute=False, weight_path:str=None):
|
||||
super().__init__()
|
||||
self.precompute = precompute
|
||||
self.model = None
|
||||
self.weight_path = weight_path
|
||||
|
||||
from diffusers.models import AutoencoderKL
|
||||
setattr(self, "model", AutoencoderKL.from_pretrained(self.weight_path))
|
||||
self.scaling_factor = self.model.config.scaling_factor
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, x):
|
||||
assert self.model is not None
|
||||
if self.precompute:
|
||||
return x.mul_(self.scaling_factor)
|
||||
return self.model.encode(x).latent_dist.sample().mul_(self.scaling_factor)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, x):
|
||||
assert self.model is not None
|
||||
return self.model.decode(x.div_(self.scaling_factor)).sample
|
||||
|
||||
|
||||
def uint82fp(x):
|
||||
x = x.to(torch.float32)
|
||||
x = (x - 127.5) / 127.5
|
||||
return x
|
||||
|
||||
def fp2uint8(x):
|
||||
x = torch.clip_((x + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
|
||||
return x
|
||||
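# Round-trip sketch (not part of the original file): uint8 images map to [-1, 1) floats;
# the inverse below mirrors fp2uint8, whose +0.5 implements round-to-nearest before the
# uint8 cast, so the round trip is exact.
def _demo_uint8_roundtrip():
    img = torch.randint(0, 256, (1, 3, 8, 8), dtype=torch.uint8)
    x = uint82fp(img)                                             # float32 in [-1, 1)
    restored = ((x + 1) * 127.5 + 0.5).clamp(0, 255).to(torch.uint8)
    assert torch.equal(restored, img)
    return x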
|
||||
346
src/ops/cuda_kernels/backward.cu
Normal file
@@ -0,0 +1,346 @@
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <vector_types.h>
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.hpp>
|
||||
#include <cuda_bf16.hpp>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <cooperative_groups/memcpy_async.h>
|
||||
#include <cuda/pipeline>
|
||||
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
template<typename scalar_t>
|
||||
__device__ __always_inline int toInt(scalar_t val);
|
||||
|
||||
template<>
|
||||
__device__ __always_inline int toInt(float val){
|
||||
return static_cast<int>(val);
|
||||
}
|
||||
template<>
|
||||
__device__ __always_inline int toInt(half val){
|
||||
return __half2int_rz(val);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
__device__ __always_inline scalar_t fromInt(int val);
|
||||
|
||||
template<>
|
||||
__device__ __always_inline float fromInt(int val){
|
||||
return static_cast<float>(val);
|
||||
}
|
||||
|
||||
template<>
|
||||
__device__ __always_inline half fromInt(int val){
|
||||
return __int2half_rz(val);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
__device__ __always_inline scalar_t constVal(float val);
|
||||
|
||||
template<>
|
||||
__device__ __always_inline float constVal<float>(float val) {
|
||||
return (float)val;
|
||||
}
|
||||
|
||||
template<>
|
||||
__device__ __always_inline half constVal<half>(float val) {
|
||||
return __float2half(val); // Using float to half conversion
|
||||
}
|
||||
template<>
|
||||
__device__ __always_inline nv_bfloat16 constVal<nv_bfloat16>(float val){
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
// B, H, W, C, BLOCK_DIM must be multiple of C
|
||||
template <typename scalar_t, typename vec2_t, int pipeline_stages, int TILE_C, int TILE_THREADS>
|
||||
__global__ void dcn_backward_pipeline_kernel(
|
||||
const int H,
|
||||
const int W,
|
||||
const int G,
|
||||
const int K,
|
||||
const int C,
|
||||
scalar_t* ptr_values,
|
||||
scalar_t* ptr_deformables,
|
||||
scalar_t* ptr_weights,
|
||||
scalar_t* ptr_grad_out,
|
||||
scalar_t* ptr_grad_values,
|
||||
scalar_t* ptr_grad_deformables,
|
||||
scalar_t* ptr_grad_weights
|
||||
) {
|
||||
auto block = cg::this_thread_block();
|
||||
auto self_thread = cg::this_thread();
|
||||
auto tile_threads = cg::tiled_partition<TILE_THREADS>(block);
|
||||
int local_thread_id = block.thread_rank();
|
||||
int local_tile_id = tile_threads.meta_group_rank();
|
||||
int num_local_tiles = tile_threads.meta_group_size();
|
||||
int global_tile_id = block.group_index().x*num_local_tiles + local_tile_id;
|
||||
|
||||
extern __shared__ int shm[];
|
||||
auto GradBuffer = reinterpret_cast<scalar_t*>(shm);
|
||||
scalar_t* Buffer = reinterpret_cast<scalar_t*>(shm) + num_local_tiles*C;
|
||||
if(global_tile_id >= H*W*G) return;
|
||||
|
||||
int bid = block.group_index().y;
|
||||
int gid = global_tile_id % G;
|
||||
int wid = global_tile_id / G % W;
|
||||
int hid = global_tile_id / G / W;
|
||||
int globale_offset = bid*H*W*G*C + global_tile_id*C;
|
||||
cg::memcpy_async(tile_threads, GradBuffer+local_tile_id*C, ptr_grad_out+globale_offset, sizeof(scalar_t)*C);
|
||||
|
||||
int shared_offset[pipeline_stages];
|
||||
for (int s = 0; s < pipeline_stages; ++s) {
|
||||
shared_offset[s] = (s+pipeline_stages*local_thread_id)*(TILE_C*4);
|
||||
}
|
||||
|
||||
auto pipeline = cuda::make_pipeline();
|
||||
const int num_tiles_per_thread = C/TILE_C/TILE_THREADS;
|
||||
|
||||
for(int k=0; k<K; k++) {
|
||||
int offset = bid * K * H * W * G + hid * W * K * G + wid * K * G + gid * K + k;
|
||||
scalar_t x, y, weight;
|
||||
if (tile_threads.thread_rank() == 0) {
|
||||
x = ptr_deformables[offset*2] + fromInt<scalar_t>(wid);
|
||||
y = ptr_deformables[offset*2 + 1] + fromInt<scalar_t>(hid);
|
||||
// x = fromInt<scalar_t>(wid);
|
||||
// y = fromInt<scalar_t>(hid);
|
||||
weight = ptr_weights[offset];
|
||||
}
|
||||
tile_threads.sync();
|
||||
x = tile_threads.shfl(x, 0);
|
||||
y = tile_threads.shfl(y, 0);
|
||||
weight = tile_threads.shfl(weight, 0);
|
||||
|
||||
int floor_x = toInt<scalar_t>(x);
|
||||
int floor_y = toInt<scalar_t>(y);
|
||||
int ceil_x = floor_x + 1;
|
||||
int ceil_y = floor_y + 1;
|
||||
|
||||
|
||||
scalar_t dodx = constVal<scalar_t>(0.0f);
|
||||
scalar_t dody = constVal<scalar_t>(0.0f);
|
||||
scalar_t dodw = constVal<scalar_t>(0.0f);
|
||||
|
||||
int start_c = tile_threads.thread_rank() * (C / TILE_THREADS);
|
||||
|
||||
bool tl_flag = (floor_x >=0) and (floor_x <W) and (floor_y>=0) and (floor_y<H);
|
||||
bool tr_flag = (ceil_x >=0) and (ceil_x <W) and (floor_y>=0) and (floor_y<H);
|
||||
bool bl_flag = (floor_x >=0) and (floor_x <W) and (ceil_y>=0) and (ceil_y<H);
|
||||
bool br_flag = (ceil_x >=0) and (ceil_x <W) and (ceil_y>=0) and (ceil_y<H);
|
||||
|
||||
int tl_global_base = (bid * H * W * G + floor_y * W * G + floor_x * G + gid)*C + start_c;
|
||||
int tr_global_base = (bid * H * W * G + floor_y * W * G + ceil_x * G + gid)*C + start_c;
|
||||
int bl_global_base = (bid * H * W * G + ceil_y * W * G + floor_x * G + gid)*C +start_c;
|
||||
int br_global_base = (bid * H * W * G + ceil_y * W * G + ceil_x * G + gid)*C +start_c;
|
||||
|
||||
|
||||
auto asmem_load_fn = [&](int shm_offset, int hbm_offset, bool flag){
|
||||
if(flag){
|
||||
cuda::memcpy_async(Buffer + shm_offset, ptr_values + hbm_offset,
|
||||
TILE_C * sizeof(scalar_t), pipeline);
|
||||
}else{
|
||||
memset(Buffer+shm_offset, 0, TILE_C*sizeof(scalar_t));  // zero the tile for out-of-bounds corners
|
||||
}
|
||||
};
|
||||
|
||||
// pipeline-compute&load
|
||||
for (int compute_n = 0, fetch_n=0; compute_n < num_tiles_per_thread; compute_n++) {
|
||||
for (; fetch_n < compute_n + pipeline_stages and fetch_n < num_tiles_per_thread; fetch_n++) {
|
||||
pipeline.producer_acquire();
|
||||
int buffer_offset = shared_offset[fetch_n % pipeline_stages];
|
||||
|
||||
// tl
|
||||
asmem_load_fn(buffer_offset, tl_global_base + fetch_n * TILE_C, tl_flag);
|
||||
// tr
|
||||
asmem_load_fn(buffer_offset+TILE_C, tr_global_base + fetch_n * TILE_C, tr_flag);
|
||||
// bl
|
||||
asmem_load_fn(buffer_offset+TILE_C*2, bl_global_base + fetch_n * TILE_C, bl_flag);
|
||||
// br
|
||||
asmem_load_fn(buffer_offset+TILE_C*3, br_global_base + fetch_n * TILE_C, br_flag);
|
||||
|
||||
pipeline.producer_commit();
|
||||
}
|
||||
pipeline.consumer_wait();
|
||||
int buffer_id = compute_n % pipeline_stages;
|
||||
int ibuffer_offset = shared_offset[buffer_id];
|
||||
int gbuffer_offset = local_tile_id * C + start_c + compute_n * TILE_C;
|
||||
|
||||
for (int j = 0; j < TILE_C; j+=2) {
|
||||
if(tl_flag){
|
||||
// tl
|
||||
dodw = dodw + (fromInt<scalar_t>(ceil_x) - x) * (fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
|
||||
dodx = dodx + -weight*(fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
|
||||
dody = dody + -weight*(fromInt<scalar_t>(ceil_x) - x) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
|
||||
dodw = dodw + (fromInt<scalar_t>(ceil_x) - x) * (fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1];
|
||||
dodx = dodx + -weight*(fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+j+ 1] * GradBuffer[gbuffer_offset + j + 1];
|
||||
dody = dody + -weight*(fromInt<scalar_t>(ceil_x) - x) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1];
|
||||
{
|
||||
vec2_t vtl_di;
|
||||
vtl_di.x = weight* (fromInt<scalar_t>(ceil_x) - x) * (fromInt<scalar_t>(ceil_y) - y) * GradBuffer[gbuffer_offset + j];
|
||||
vtl_di.y = weight* (fromInt<scalar_t>(ceil_x) - x) * (fromInt<scalar_t>(ceil_y) - y) * GradBuffer[gbuffer_offset + j + 1];
|
||||
atomicAdd((vec2_t*)(ptr_grad_values + tl_global_base + compute_n * TILE_C + j), vtl_di);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if(tr_flag){
|
||||
// tr
|
||||
dodw = dodw + (x - fromInt<scalar_t>(floor_x)) * (fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
|
||||
dodx = dodx + weight*(fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
|
||||
dody = dody + -weight*(x - fromInt<scalar_t>(floor_x)) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
|
||||
dodw = dodw + (x - fromInt<scalar_t>(floor_x)) * (fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j+1] * GradBuffer[gbuffer_offset + j+1];
|
||||
dodx = dodx + weight*(fromInt<scalar_t>(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+ 1];
|
||||
dody = dody + -weight*(x - fromInt<scalar_t>(floor_x)) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+1];
|
||||
{
|
||||
vec2_t vtr_di;
|
||||
vtr_di.x = weight* (x - fromInt<scalar_t>(floor_x)) * (fromInt<scalar_t>(ceil_y) - y) * GradBuffer[gbuffer_offset + j];
|
||||
vtr_di.y = weight* (x - fromInt<scalar_t>(floor_x)) * (fromInt<scalar_t>(ceil_y) - y) * GradBuffer[gbuffer_offset + j+1];
|
||||
atomicAdd((vec2_t*)(ptr_grad_values + tr_global_base + compute_n * TILE_C + j), vtr_di);
|
||||
}
|
||||
}
|
||||
|
||||
if(bl_flag){
|
||||
// bl
|
||||
dodw = dodw + (fromInt<scalar_t>(ceil_x) - x) * (y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
|
||||
dodx = dodx + -weight*(y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
|
||||
dody = dody + weight*(fromInt<scalar_t>(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
|
||||
dodw = dodw + (fromInt<scalar_t>(ceil_x) - x) * (y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
|
||||
dodx = dodx + -weight*(y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
|
||||
dody = dody + weight*(fromInt<scalar_t>(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
|
||||
{
|
||||
vec2_t vbl_di;
|
||||
vbl_di.x = weight* (fromInt<scalar_t>(ceil_x) - x) * (y - fromInt<scalar_t>(floor_y)) * GradBuffer[gbuffer_offset + j];
|
||||
vbl_di.y = weight* (fromInt<scalar_t>(ceil_x) - x) * (y - fromInt<scalar_t>(floor_y)) * GradBuffer[gbuffer_offset + j+1];
|
||||
atomicAdd((vec2_t*)(ptr_grad_values + bl_global_base + compute_n * TILE_C + j), vbl_di);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if(br_flag){
|
||||
// tr
|
||||
dodw = dodw + (x - fromInt<scalar_t>(floor_x)) * (y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
|
||||
dodx = dodx + weight*(y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
|
||||
dody = dody + weight*(x - fromInt<scalar_t>(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
|
||||
dodw = dodw + (x - fromInt<scalar_t>(floor_x)) * (y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
|
||||
dodx = dodx + weight*(y - fromInt<scalar_t>(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
|
||||
dody = dody + weight*(x - fromInt<scalar_t>(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
|
||||
{
|
||||
vec2_t vbr_di;
|
||||
vbr_di.x = weight* (x - fromInt<scalar_t>(floor_x)) * (y - fromInt<scalar_t>(floor_y)) * GradBuffer[gbuffer_offset + j];
|
||||
vbr_di.y = weight* (x - fromInt<scalar_t>(floor_x)) * (y - fromInt<scalar_t>(floor_y)) * GradBuffer[gbuffer_offset + j+1];
|
||||
atomicAdd((vec2_t*)(ptr_grad_values + br_global_base + compute_n * TILE_C + j), vbr_di);
|
||||
}
|
||||
}
|
||||
}
|
||||
pipeline.consumer_release();
|
||||
}
|
||||
for (int i = TILE_THREADS>>1; i > 0; i/=2) {
|
||||
dodx = dodx + tile_threads.shfl_down(dodx, i);
|
||||
dody = dody + tile_threads.shfl_down(dody, i);
|
||||
dodw = dodw + tile_threads.shfl_down(dodw, i);
|
||||
}
|
||||
if (tile_threads.thread_rank() == 0) {
|
||||
cuda::memcpy_async(ptr_grad_deformables + offset * 2, &dodx, sizeof(scalar_t), pipeline);
|
||||
cuda::memcpy_async(ptr_grad_deformables + offset * 2 + 1, &dody, sizeof(scalar_t), pipeline);
|
||||
cuda::memcpy_async(ptr_grad_weights + offset, &dodw, sizeof(scalar_t), pipeline);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
using namespace torch;
|
||||
template<int pipeline_stages, int TILE_C, int TILE_THREADS, int THREADS>
|
||||
void backward(const int B,
|
||||
const int H,
|
||||
const int W,
|
||||
const int G,
|
||||
const int K,
|
||||
const int C,
|
||||
torch::Tensor values,
|
||||
torch::Tensor deformables,
|
||||
torch::Tensor weights,
|
||||
torch::Tensor grad_out,
|
||||
torch::Tensor grad_values,
|
||||
torch::Tensor grad_deformables,
|
||||
torch::Tensor grad_weights
|
||||
) {
|
||||
int num_local_tiles =(THREADS/TILE_THREADS);
|
||||
int num_global_tiles = (H*W*G+num_local_tiles-1)/num_local_tiles;
|
||||
dim3 launch_threads_per_block(THREADS);
|
||||
dim3 launch_blocks(num_global_tiles, B);
|
||||
|
||||
int deformable_shm_size = 0;
|
||||
int grad_out_shm_size = num_local_tiles*C;
|
||||
int pipeline_shm_size = (pipeline_stages*TILE_C*4*THREADS);
|
||||
|
||||
int shm_size = deformable_shm_size+grad_out_shm_size+pipeline_shm_size;
|
||||
// printf("shm_size: %d\n", shm_size/512);
|
||||
// printf("pipeline_size: %d\n", pipeline_shm_size/512);
|
||||
// printf("grad_out_size: %d\n", grad_out_shm_size/512);
|
||||
|
||||
|
||||
switch (values.type().scalarType()) {
|
||||
case at::ScalarType::Half:
|
||||
return dcn_backward_pipeline_kernel<half, half2, pipeline_stages, TILE_C, TILE_THREADS><<<launch_blocks, launch_threads_per_block, shm_size*sizeof(half)>>>(
|
||||
H, W, G, K, C,
|
||||
reinterpret_cast<half*>(values.data_ptr<at::Half>()),
|
||||
reinterpret_cast<half*>(deformables.data_ptr<at::Half>()),
|
||||
reinterpret_cast<half*>(weights.data_ptr<at::Half>()),
|
||||
reinterpret_cast<half*>(grad_out.data_ptr<at::Half>()),
|
||||
reinterpret_cast<half*>(grad_values.data_ptr<at::Half>()),
|
||||
reinterpret_cast<half*>(grad_deformables.data_ptr<at::Half>()),
|
||||
reinterpret_cast<half*>(grad_weights.data_ptr<at::Half>())
|
||||
);
|
||||
// case at::ScalarType::BFloat16:
|
||||
// return dcn_backward_pipeline_kernel<nv_bfloat16, nv_bfloat162, pipeline_stages, TILE_C, TILE_THREADS><<<launch_blocks, launch_threads_per_block, shm_size*sizeof(nv_bfloat16)>>>(
|
||||
// H, W, G, K, C,
|
||||
// reinterpret_cast<nv_bfloat16*>(values.data_ptr<at::BFloat16>()),
|
||||
// reinterpret_cast<nv_bfloat16*>(deformables.data_ptr<at::BFloat16>()),
|
||||
// reinterpret_cast<nv_bfloat16*>(weights.data_ptr<at::BFloat16>()),
|
||||
// reinterpret_cast<nv_bfloat16*>(grad_out.data_ptr<at::BFloat16>()),
|
||||
// reinterpret_cast<nv_bfloat16*>(grad_values.data_ptr<at::BFloat16>()),
|
||||
// reinterpret_cast<nv_bfloat16*>(grad_deformables.data_ptr<at::BFloat16>()),
|
||||
// reinterpret_cast<nv_bfloat16*>(grad_weights.data_ptr<at::BFloat16>())
|
||||
// );
|
||||
default:
|
||||
printf("running error");
|
||||
}
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("backward_p1_c2_tile16_thread128", &backward<1, 2, 16, 128>, "");
|
||||
m.def("backward_p2_c2_tile16_thread128", &backward<2, 2, 16, 128>, "");
|
||||
m.def("backward_p1_c4_tile16_thread128", &backward<1, 4, 16, 128>, "");
|
||||
m.def("backward_p1_c2_tile16_thread256", &backward<1, 2, 16, 256>, "");
|
||||
m.def("backward_p2_c2_tile16_thread256", &backward<2, 2, 16, 256>, "");
|
||||
m.def("backward_p1_c4_tile16_thread256", &backward<1, 4, 16, 256>, "");
|
||||
m.def("backward_p1_c2_tile16_thread384", &backward<1, 2, 16, 384>, "");
|
||||
m.def("backward_p2_c2_tile16_thread384", &backward<2, 2, 16, 384>, "");
|
||||
m.def("backward_p1_c4_tile16_thread384", &backward<1, 4, 16, 384>, "");
|
||||
m.def("backward_p1_c2_tile16_thread512", &backward<1, 2, 16, 512>, "");
|
||||
m.def("backward_p2_c2_tile16_thread512", &backward<2, 2, 16, 512>, "");
|
||||
m.def("backward_p1_c4_tile16_thread512", &backward<1, 4, 16, 512>, "");
|
||||
m.def("backward_p1_c2_tile16_thread768", &backward<1, 2, 16, 768>, "");
|
||||
m.def("backward_p2_c2_tile16_thread768", &backward<2, 2, 16, 768>, "");
|
||||
m.def("backward_p1_c4_tile16_thread768", &backward<1, 4, 16, 768>, "");
|
||||
// m.def("backward_p1_c2_tile16_thread1024", &backward<1, 2, 16, 1024>, "");
|
||||
// m.def("backward_p2_c2_tile16_thread1024", &backward<2, 2, 16, 1024>, "");
|
||||
// m.def("backward_p1_c4_tile16_thread1024", &backward<1, 4, 16, 1024>, "");
|
||||
|
||||
m.def("backward_p1_c2_tile32_thread128", &backward<1, 2, 32, 128>, "");
|
||||
m.def("backward_p1_c2_tile32_thread256", &backward<1, 2, 32, 256>, "");
|
||||
m.def("backward_p1_c2_tile32_thread384", &backward<1, 2, 32, 384>, "");
|
||||
m.def("backward_p1_c2_tile32_thread512", &backward<1, 2, 32, 512>, "");
|
||||
}
|
||||
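A hedged sketch (not part of the commit) of how a binding from the file above could be built and called from Python. The JIT loader and the tensor layouts are assumptions inferred from the kernel's indexing (values and gradient tensors laid out as B, H, W, G, C; deformables as B, H, W, G, K, 2), and the dispatch above only handles fp16.

import torch
from torch.utils.cpp_extension import load

dcn_backward = load(name="dcn_backward",
                    sources=["src/ops/cuda_kernels/backward.cu"],
                    verbose=True)                 # assumes a matching CUDA toolchain

B, H, W, G, K, C = 2, 16, 16, 16, 9, 72
dev, dt = "cuda", torch.half
values      = torch.randn(B, H, W, G, C, device=dev, dtype=dt)
deformables = torch.randn(B, H, W, G, K, 2, device=dev, dtype=dt)
weights     = torch.randn(B, H, W, G, K, device=dev, dtype=dt)
grad_out    = torch.randn(B, H, W, G, C, device=dev, dtype=dt)
grad_values, grad_deformables, grad_weights = (
    torch.zeros_like(values), torch.zeros_like(deformables), torch.zeros_like(weights))

# Binding name encodes (pipeline stages, TILE_C, tile threads, block threads).
dcn_backward.backward_p1_c2_tile16_thread256(
    B, H, W, G, K, C,
    values, deformables, weights, grad_out,
    grad_values, grad_deformables, grad_weights)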
289
src/ops/cuda_kernels/bak_forward.cu
Normal file
@@ -0,0 +1,289 @@
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <vector_types.h>
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_pipeline_primitives.h>
|
||||
#include <cuda_fp16.hpp>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/memcpy_async.h>
|
||||
|
||||
|
||||
template <typename TA, typename TB>
|
||||
__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
|
||||
#pragma unroll
|
||||
for(int i=0; i<n; i++){
|
||||
*ptr_a = (TA)(*ptr_a + (*ptr_b) * weight);
|
||||
ptr_a += stride_a;
|
||||
ptr_b += stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TA, typename TB>
|
||||
__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
|
||||
#pragma unroll
|
||||
for(int i=0; i<n; i++){
|
||||
*ptr_a = (TA)((*ptr_b) * weight);
|
||||
ptr_a += stride_a;
|
||||
ptr_b += stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TA, typename TB>
|
||||
__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){
|
||||
#pragma unroll
|
||||
for(int i=0; i<n; i++){
|
||||
*ptr_a = (TA)((*ptr_b));
|
||||
ptr_a += stride_a;
|
||||
ptr_b += stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TA>
|
||||
__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){
|
||||
#pragma unroll
|
||||
for(int i=0; i<n; i++){
|
||||
*ptr_a = 0;
|
||||
ptr_a += stride;
|
||||
}
|
||||
}
|
||||
|
||||
// B, H, W, C, BLOCK_DIM must be multiple of C
|
||||
template <typename math_t, typename scalar_t, int transfer_length, int K, int L, int BLOCK_DIM>
|
||||
__global__ void dcn_forward_kernel(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
|
||||
int work_id = threadIdx.x;
|
||||
int bid = blockIdx.z;
|
||||
int gid = blockIdx.y;
|
||||
int G = gridDim.y;
|
||||
int c_blockid = blockIdx.x;
|
||||
int work_load = (H*W/blockDim.x);
|
||||
|
||||
__shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM
|
||||
// __shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM
|
||||
math_t register_bufferA[BLOCK_DIM] = {0};
|
||||
int base_c = c_blockid*BLOCK_DIM;
|
||||
|
||||
int num_transfers = BLOCK_DIM;
|
||||
#pragma unroll
|
||||
for(int i=0; i<work_load; i++){
|
||||
int job_id = work_load*work_id + i;
|
||||
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
|
||||
for(int j=0; j<num_transfers; j++){
|
||||
if((base_c+j) < C){
|
||||
// __pipeline_memcpy_async((long*)(&math_buffer[job_id]) + j, (long *)ptr_value + offset2 + j, sizeof(long));
|
||||
math_buffer[job_id][j] = (math_t)*(ptr_value + offset2 +j);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
work_load = (H*W)/blockDim.x;
|
||||
int offset = 0;
|
||||
for(int i=0; i<work_load; i++){
|
||||
int job_id = (work_id*work_load+i);
|
||||
int hid = job_id/W;
|
||||
int wid = job_id%W;
|
||||
loop_reset<math_t>(register_bufferA, 1, BLOCK_DIM);
|
||||
// loop_reset<scalar_t>((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM);
|
||||
#pragma unroll
|
||||
for(int k=0; k<K; k++){
|
||||
// read weights to register
|
||||
offset = bid*K*H*W*G + hid*W*K*G + wid*K*G + gid*K +k;
|
||||
math_t weight = *(ptr_weights + offset);
|
||||
// read deformables to register
|
||||
offset = offset*2;
|
||||
math_t x = *(ptr_deformables + offset) + wid;
|
||||
math_t y = *(ptr_deformables + offset + 1) + hid;
|
||||
int floor_x = x;
|
||||
int floor_y = y;
|
||||
int ceil_x = floor_x + 1;
|
||||
int ceil_y = floor_y + 1;
|
||||
|
||||
// reset A buffer and top left
|
||||
math_t tl_weight = (ceil_x - x)*(ceil_y - y)*weight;
|
||||
if( (0<= floor_y) and (floor_y < H) and (0<= floor_x) and (floor_x < W)){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
// load top right
|
||||
math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
|
||||
if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
|
||||
// load bottom left
|
||||
math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
|
||||
if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
// load bottom right
|
||||
math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
|
||||
if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
|
||||
}
|
||||
// loop_load<scalar_t, math_t>((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM);
|
||||
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
|
||||
#pragma unroll
|
||||
for(int j=0; j<BLOCK_DIM; j++){
|
||||
if((base_c+j) < C){
|
||||
*(ptr_out + offset2 + j) = (scalar_t)register_bufferA[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
// B, H, W, C, BLOCK_DIM must be multiple of C
|
||||
template <typename math_t, typename scalar_t, int transfer_length, int K, int L, int BLOCK_DIM>
|
||||
__global__ void dcn_forward_kernel_16(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
|
||||
int work_id = threadIdx.x;
|
||||
int bid = blockIdx.z;
|
||||
int gid = blockIdx.y;
|
||||
int G = gridDim.y;
|
||||
int c_blockid = blockIdx.x;
|
||||
int work_load = (H*W/blockDim.x);
|
||||
|
||||
__shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM
|
||||
__shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM
|
||||
math_t register_bufferA[BLOCK_DIM] = {0};
|
||||
int base_c = c_blockid*BLOCK_DIM;
|
||||
|
||||
int num_transfers = BLOCK_DIM/transfer_length;
|
||||
#pragma unroll
|
||||
for(int i=0; i<work_load; i++){
|
||||
int job_id = work_load*work_id + i;
|
||||
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c)/transfer_length;
|
||||
for(int j=0; j<num_transfers; j++){
|
||||
if((base_c+j*transfer_length) < C){
|
||||
__pipeline_memcpy_async((long*)(&math_buffer[job_id]) + j, (long *)ptr_value + offset2 + j, sizeof(long));
|
||||
}
|
||||
}
|
||||
}
|
||||
__pipeline_commit();
|
||||
__pipeline_wait_prior(0);
|
||||
__syncthreads();
|
||||
|
||||
work_load = (H*W)/blockDim.x;
|
||||
int offset = 0;
|
||||
for(int i=0; i<work_load; i++){
|
||||
int job_id = (work_id*work_load+i);
|
||||
int hid = job_id/W;
|
||||
int wid = job_id%W;
|
||||
loop_reset<math_t>(register_bufferA, 1, BLOCK_DIM);
|
||||
loop_reset<scalar_t>((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM);
|
||||
#pragma unroll
|
||||
for(int k=0; k<K; k++){
|
||||
// read weights to register
|
||||
offset = bid*K*H*W*G + hid*W*K*G + wid*K*G + gid*K +k;
|
||||
math_t weight = *(ptr_weights + offset);
|
||||
// read deformables to register
|
||||
offset = offset*2;
|
||||
math_t x = *(ptr_deformables + offset) + wid;
|
||||
math_t y = *(ptr_deformables + offset + 1) + hid;
|
||||
int floor_x = x;
|
||||
int floor_y = y;
|
||||
int ceil_x = floor_x + 1;
|
||||
int ceil_y = floor_y + 1;
|
||||
|
||||
// reset A buffer and top left
|
||||
math_t tl_weight = (ceil_x - x)*(ceil_y - y)*weight;
|
||||
if( (0<= floor_y) and (floor_y < H) and (0<= floor_x) and (floor_x < W)){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
// load top right
|
||||
math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
|
||||
if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
|
||||
// load bottom left
|
||||
math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
|
||||
if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
// load bottom right
|
||||
math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
|
||||
if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
|
||||
}
|
||||
loop_load<scalar_t, math_t>((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM);
|
||||
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for(int i=0; i<work_load; i++){
|
||||
int job_id = work_load*work_id + i;
|
||||
// int offset1 = job_id*num_transfers;
|
||||
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c)/transfer_length;
|
||||
#pragma unroll
|
||||
for(int j=0; j<num_transfers; j++){
|
||||
if((base_c+j*transfer_length) < C){
|
||||
*((long *)ptr_out + offset2 + j) = *((long *)(&io_buffer[job_id]) +j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<int L, int C_BLOCK_DIM, int THREADS>
|
||||
void dcn_forward(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
|
||||
|
||||
int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
|
||||
dim3 launch_threads_per_block(THREADS);
|
||||
dim3 launch_blocks(NUM_C_BLOCK, G, B);
|
||||
|
||||
switch (value.type().scalarType()) {
|
||||
case at::ScalarType::Half:
|
||||
return dcn_forward_kernel_16<at::Half, at::Half, 4, 9, L, (C_BLOCK_DIM)><<<launch_blocks, launch_threads_per_block>>>(
|
||||
H, W, C,
|
||||
value.data_ptr<at::Half>(),
|
||||
deformables.data_ptr<at::Half>(),
|
||||
weights.data_ptr<at::Half>(),
|
||||
out.data_ptr<at::Half>());
|
||||
case at::ScalarType::BFloat16:
|
||||
return dcn_forward_kernel_16<at::BFloat16, at::BFloat16, 4, 9, L, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block>>>(
|
||||
H, W, C,
|
||||
value.data_ptr<at::BFloat16>(),
|
||||
deformables.data_ptr<at::BFloat16>(),
|
||||
weights.data_ptr<at::BFloat16>(),
|
||||
out.data_ptr<at::BFloat16>());
|
||||
case at::ScalarType::Float:
|
||||
return dcn_forward_kernel<at::Half, float, 2, 9, L, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block>>>(
|
||||
H, W, C,
|
||||
value.data_ptr<float>(),
|
||||
deformables.data_ptr<float>(),
|
||||
weights.data_ptr<float>(),
|
||||
out.data_ptr<float>());
|
||||
default:
|
||||
printf("running error");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// PyBind11 bindings
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward");
|
||||
//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward");
|
||||
m.def("dcn_forward_l256_c4", &dcn_forward<256, 4, 256>, "CUDA dcn forward");
|
||||
m.def("dcn_forward_l256_c8", &dcn_forward<256, 8, 256>, "CUDA dcn forward");
|
||||
m.def("dcn_forward_l256_c16", &dcn_forward<256, 16, 256>, "CUDA dcn forward");
|
||||
// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward");
|
||||
m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward");
|
||||
m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward");
|
||||
m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward");
|
||||
// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward");
|
||||
}
|
||||
309
src/ops/cuda_kernels/forward.cu
Normal file
@@ -0,0 +1,309 @@
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <vector_types.h>
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/OpMathType.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_pipeline_primitives.h>
|
||||
#include <cuda_fp16.hpp>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/memcpy_async.h>
|
||||
|
||||
|
||||
template <typename TA, typename TB>
|
||||
__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
|
||||
#pragma unroll
|
||||
for(int i=0; i<n; i++){
|
||||
*ptr_a = (TA)(*ptr_a + (*ptr_b) * weight);
|
||||
ptr_a += stride_a;
|
||||
ptr_b += stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TA, typename TB>
|
||||
__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
|
||||
#pragma unroll
|
||||
for(int i=0; i<n; i++){
|
||||
*ptr_a = (TA)((*ptr_b) * weight);
|
||||
ptr_a += stride_a;
|
||||
ptr_b += stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TA, typename TB>
|
||||
__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){
|
||||
#pragma unroll
|
||||
for(int i=0; i<n; i++){
|
||||
*ptr_a = (TA)((*ptr_b));
|
||||
ptr_a += stride_a;
|
||||
ptr_b += stride_b;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TA>
|
||||
__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){
|
||||
#pragma unroll
|
||||
for(int i=0; i<n; i++){
|
||||
*ptr_a = 0;
|
||||
ptr_a += stride;
|
||||
}
|
||||
}
|
||||
|
||||
// value layout: [B, H, W, G, C]; channels are processed in BLOCK_DIM-wide tiles (out-of-range channels are masked)
|
||||
template <typename math_t, typename scalar_t, int K, int BLOCK_DIM>
|
||||
__global__ void dcn_forward_kernel_register(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
|
||||
int work_id = threadIdx.x;
|
||||
int bid = blockIdx.z;
|
||||
int gid = blockIdx.y;
|
||||
int G = gridDim.y;
|
||||
int c_blockid = blockIdx.x;
|
||||
int work_load = (H*W/blockDim.x);
|
||||
|
||||
extern __shared__ int shm[];
|
||||
math_t* math_buffer = reinterpret_cast<math_t*>(shm);
|
||||
|
||||
math_t register_bufferA[BLOCK_DIM] = {0};
|
||||
int base_c = c_blockid*BLOCK_DIM;
|
||||
|
||||
#pragma unroll
|
||||
for(int i=0; i<work_load; i++){
|
||||
int job_id = work_load*work_id + i;
|
||||
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
|
||||
for(int j=0; j<BLOCK_DIM; j++){
|
||||
if((base_c+j) < C){
|
||||
math_buffer[job_id*BLOCK_DIM +j] = (math_t)*(ptr_value + offset2 +j);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
work_load = (H*W)/blockDim.x;
|
||||
int offset = 0;
|
||||
for(int i=0; i<work_load; i++){
|
||||
int job_id = (work_id*work_load+i);
|
||||
int hid = job_id/W;
|
||||
int wid = job_id%W;
|
||||
loop_reset<math_t>(register_bufferA, 1, BLOCK_DIM);
|
||||
#pragma unroll
|
||||
for(int k=0; k<K; k++){
|
||||
// read weights to register
|
||||
offset = bid*K*H*W*G + hid*W*K*G + wid*K*G + gid*K +k;
|
||||
math_t weight = *(ptr_weights + offset);
|
||||
// read deformables to register
|
||||
offset = offset*2;
|
||||
math_t x = *(ptr_deformables + offset) + wid;
|
||||
math_t y = *(ptr_deformables + offset + 1) + hid;
|
||||
int floor_x = x;
|
||||
int floor_y = y;
|
||||
int ceil_x = floor_x + 1;
|
||||
int ceil_y = floor_y + 1;
|
||||
|
||||
// accumulate the top-left sample (register buffer was reset above)
|
||||
math_t tl_weight = (ceil_x - x)*(ceil_y - y)*weight;
|
||||
if( (0<= floor_y) and (floor_y < H) and (0<= floor_x) and (floor_x < W)){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
// load top right
|
||||
math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
|
||||
if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
|
||||
// load bottom left
|
||||
math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
|
||||
if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
// load bottom right
|
||||
math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
|
||||
if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
|
||||
}
|
||||
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
|
||||
#pragma unroll
|
||||
for(int j=0; j<BLOCK_DIM; j++){
|
||||
if((base_c+j) < C){
|
||||
*(ptr_out + offset2 + j) = (scalar_t)register_bufferA[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
// value layout: [B, H, W, G, C]; channels are processed in BLOCK_DIM-wide tiles, copied in chunks of transfer_length elements
|
||||
template <typename math_t, typename scalar_t, int transfer_length, int K, int BLOCK_DIM>
|
||||
__global__ void dcn_forward_kernel_pipeline(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
|
||||
int work_id = threadIdx.x;
|
||||
int bid = blockIdx.z;
|
||||
int gid = blockIdx.y;
|
||||
int G = gridDim.y;
|
||||
int c_blockid = blockIdx.x;
|
||||
int work_load = (H*W/blockDim.x);
|
||||
|
||||
extern __shared__ int shm[];
|
||||
math_t* math_buffer = reinterpret_cast<math_t*>(shm);
|
||||
scalar_t* io_buffer = reinterpret_cast<scalar_t*>(shm) + H*W*BLOCK_DIM*sizeof(math_t)/sizeof(scalar_t);
|
||||
math_t register_bufferA[BLOCK_DIM] = {0};
|
||||
int base_c = c_blockid*BLOCK_DIM;
|
||||
|
||||
int num_transfers = BLOCK_DIM/transfer_length;
|
||||
#pragma unroll
|
||||
for(int i=0; i<work_load; i++){
|
||||
int job_id = work_load*work_id + i;
|
||||
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c)/transfer_length;
|
||||
for(int j=0; j<num_transfers; j++){
|
||||
if((base_c+j*transfer_length) < C){
|
||||
__pipeline_memcpy_async((long*)(&math_buffer[job_id]) + j, (long *)ptr_value + offset2 + j, sizeof(long));
|
||||
}
|
||||
}
|
||||
}
|
||||
__pipeline_commit();
|
||||
__pipeline_wait_prior(0);
|
||||
__syncthreads();
|
||||
|
||||
work_load = (H*W)/blockDim.x;
|
||||
int offset = 0;
|
||||
for(int i=0; i<work_load; i++){
|
||||
int job_id = (work_id*work_load+i);
|
||||
int hid = job_id/W;
|
||||
int wid = job_id%W;
|
||||
loop_reset<math_t>(register_bufferA, 1, BLOCK_DIM);
|
||||
loop_reset<scalar_t>((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM);
|
||||
#pragma unroll
|
||||
for(int k=0; k<K; k++){
|
||||
// read weights to register
|
||||
offset = bid*K*H*W*G + hid*W*K*G + wid*K*G + gid*K +k;
|
||||
math_t weight = *(ptr_weights + offset);
|
||||
// read deformables to register
|
||||
offset = offset*2;
|
||||
math_t x = *(ptr_deformables + offset) + wid;
|
||||
math_t y = *(ptr_deformables + offset + 1) + hid;
|
||||
int floor_x = x;
|
||||
int floor_y = y;
|
||||
int ceil_x = floor_x + 1;
|
||||
int ceil_y = floor_y + 1;
|
||||
|
||||
// accumulate the top-left sample (register buffer was reset above)
|
||||
math_t tl_weight = (ceil_x - x)*(ceil_y - y)*weight;
|
||||
if( (0<= floor_y) and (floor_y < H) and (0<= floor_x) and (floor_x < W)){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
// load top right
|
||||
math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
|
||||
if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
|
||||
// load bottom left
|
||||
math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
|
||||
if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
// load bottom right
|
||||
math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
|
||||
if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
|
||||
loop_mul_add<math_t, math_t>(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
|
||||
}
|
||||
|
||||
}
|
||||
loop_load<scalar_t, math_t>((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM);
|
||||
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for(int i=0; i<work_load; i++){
|
||||
int job_id = work_load*work_id + i;
|
||||
// int offset1 = job_id*num_transfers;
|
||||
int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c)/transfer_length;
|
||||
#pragma unroll
|
||||
for(int j=0; j<num_transfers; j++){
|
||||
if((base_c+j*transfer_length) < C){
|
||||
*((long *)ptr_out + offset2 + j) = *((long *)(&io_buffer[job_id]) +j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<int C_BLOCK_DIM, int THREADS>
|
||||
void dcn_forward(const int B, const int G, const int C, const int H, const int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
|
||||
|
||||
int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
|
||||
dim3 launch_threads_per_block(THREADS);
|
||||
dim3 launch_blocks(NUM_C_BLOCK, G, B);
|
||||
int shm_size = H*W*C_BLOCK_DIM*sizeof(at::Half);
|
||||
switch (value.type().scalarType()) {
|
||||
case at::ScalarType::Half:
|
||||
return dcn_forward_kernel_register<at::Half, at::Half, 9, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block, shm_size>>>(
|
||||
H, W, C,
|
||||
value.data_ptr<at::Half>(),
|
||||
deformables.data_ptr<at::Half>(),
|
||||
weights.data_ptr<at::Half>(),
|
||||
out.data_ptr<at::Half>());
|
||||
case at::ScalarType::Float:
|
||||
return dcn_forward_kernel_register<at::Half, float, 9, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block, shm_size>>>(
|
||||
H, W, C,
|
||||
value.data_ptr<float>(),
|
||||
deformables.data_ptr<float>(),
|
||||
weights.data_ptr<float>(),
|
||||
out.data_ptr<float>());
|
||||
default:
|
||||
printf("running error");
|
||||
}
|
||||
}
|
||||
|
||||
template<int C_BLOCK_DIM, int THREADS>
|
||||
void dcn_forward_pipeline(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
|
||||
|
||||
int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
|
||||
dim3 launch_threads_per_block(THREADS);
|
||||
dim3 launch_blocks(NUM_C_BLOCK, G, B);
|
||||
int shm_size = 2*H*W*C_BLOCK_DIM*sizeof(at::Half);
|
||||
switch (value.type().scalarType()) {
|
||||
case at::ScalarType::Half:
|
||||
return dcn_forward_kernel_pipeline<at::Half, at::Half, 4, 9, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block, shm_size>>>(
|
||||
H, W, C,
|
||||
value.data_ptr<at::Half>(),
|
||||
deformables.data_ptr<at::Half>(),
|
||||
weights.data_ptr<at::Half>(),
|
||||
out.data_ptr<at::Half>());
|
||||
case at::ScalarType::BFloat16:
|
||||
return dcn_forward_kernel_pipeline<at::BFloat16, at::BFloat16, 4, 9, C_BLOCK_DIM><<<launch_blocks, launch_threads_per_block, shm_size>>>(
|
||||
H, W, C,
|
||||
value.data_ptr<at::BFloat16>(),
|
||||
deformables.data_ptr<at::BFloat16>(),
|
||||
weights.data_ptr<at::BFloat16>(),
|
||||
out.data_ptr<at::BFloat16>());
|
||||
default:
|
||||
printf("running error");
|
||||
}
|
||||
}
|
||||
|
||||
// PyBind11 bindings
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward");
|
||||
//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward");
|
||||
m.def("dcn_forward_l256_c4", &dcn_forward<4, 256>, "CUDA dcn forward");
|
||||
m.def("dcn_forward_l256_c8", &dcn_forward<8, 256>, "CUDA dcn forward");
|
||||
m.def("dcn_forward_l256_c16", &dcn_forward<16, 256>, "CUDA dcn forward");
|
||||
m.def("dcn_forward_pipeline_l256_c4", &dcn_forward_pipeline<4, 256>, "CUDA dcn forward");
|
||||
m.def("dcn_forward_pipeline_l256_c8", &dcn_forward_pipeline<8, 256>, "CUDA dcn forward");
|
||||
m.def("dcn_forward_pipeline_l256_c16", &dcn_forward_pipeline<16, 256>, "CUDA dcn forward");
|
||||
// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward");
|
||||
// m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward");
|
||||
// m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward");
|
||||
// m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward");
|
||||
// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward");
|
||||
}
|
||||
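To validate the kernels above, a slow but straightforward PyTorch re-implementation of the same sampling can serve as a reference. This is a sketch only: the [B, H, W, G, C] value layout, the [B, H, W, G, K, 2] (dx, dy) deformable layout and the [B, H, W, G, K] weight layout are assumptions read off the kernel indexing, and the kernels cast coordinates to int (truncation toward zero) while this reference uses floor, so results can differ for negative coordinates.

import torch

def dcn_reference(value, deformables, weights):
    """Naive re-implementation of the forward sampling, for validation only."""
    B, H, W, G, C = value.shape
    K = weights.shape[-1]
    dev = value.device
    h_idx = torch.arange(H, device=dev).view(1, H, 1, 1, 1)
    w_idx = torch.arange(W, device=dev).view(1, 1, W, 1, 1)
    x = deformables[..., 0].float() + w_idx          # absolute sampling positions, [B, H, W, G, K]
    y = deformables[..., 1].float() + h_idx
    x0, y0 = x.floor().long(), y.floor().long()
    x1, y1 = x0 + 1, y0 + 1
    val = value.float()

    def gather(yy, xx):
        # zero-padded gather of value at integer positions (yy, xx) -> [B, H, W, G, K, C]
        mask = (yy >= 0) & (yy < H) & (xx >= 0) & (xx < W)
        b = torch.arange(B, device=dev).view(B, 1, 1, 1, 1).expand_as(yy)
        g = torch.arange(G, device=dev).view(1, 1, 1, G, 1).expand_as(yy)
        v = val[b, yy.clamp(0, H - 1), xx.clamp(0, W - 1), g]
        return v * mask.unsqueeze(-1)

    w_tl, w_tr = (x1 - x) * (y1 - y), (x - x0) * (y1 - y)
    w_bl, w_br = (x1 - x) * (y - y0), (x - x0) * (y - y0)
    sampled = (gather(y0, x0) * w_tl.unsqueeze(-1) + gather(y0, x1) * w_tr.unsqueeze(-1)
               + gather(y1, x0) * w_bl.unsqueeze(-1) + gather(y1, x1) * w_br.unsqueeze(-1))
    out = (sampled * weights.float().unsqueeze(-1)).sum(dim=4)   # reduce over the K sampling points
    return out.to(value.dtype)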
95  src/ops/cuda_kernels/forward.py  Normal file
@@ -0,0 +1,95 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# triton.Config({'BLOCK_SIZE': 32,}, num_stages=1, num_warps=1),
|
||||
triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
|
||||
triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1),
|
||||
],
|
||||
key=['B', 'H', 'W', 'G', 'C', 'K'],
|
||||
)
|
||||
@triton.jit
|
||||
def forward_kernel(
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr, # image_size_h
|
||||
W: tl.constexpr, # image_size_w
|
||||
G: tl.constexpr, # num_groups
|
||||
C: tl.constexpr, # num_channels_per_group
|
||||
K: tl.constexpr, # kernel size
|
||||
input_ptr, # input features [B, H, W, G, C]
|
||||
deformable_ptr, # deformable offsets [B, H, W, G, K, 2]
|
||||
weights_ptr, # weights [B, H, W, G, K]
|
||||
out_ptr, # out [B, H, W, G, C]
|
||||
BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
wid = pid % W
|
||||
hid = pid // W % H
|
||||
gid = pid // (W * H) % G
|
||||
bid = pid // (W * H * G)
|
||||
|
||||
id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
|
||||
common_offset = bid*H*W*G + hid*W*G + wid*G + gid
|
||||
batch_base = bid * H * W * G * C
|
||||
|
||||
for block_base in tl.static_range(0, C, BLOCK_SIZE):
|
||||
buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
|
||||
block_offset = tl.arange(0, BLOCK_SIZE) + block_base
|
||||
block_mask = (block_offset < C) & id_mask
|
||||
for k in tl.static_range(K):
|
||||
deformable_offset = (common_offset * K + k) * 2
|
||||
|
||||
x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
|
||||
y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
|
||||
|
||||
floor_x = x.to(tl.int32)
|
||||
floor_y = y.to(tl.int32)
|
||||
ceil_x = floor_x + 1
|
||||
ceil_y = floor_y + 1
|
||||
|
||||
# load top left
|
||||
tl_weight = (ceil_x - x) * (ceil_y - y)
|
||||
tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
|
||||
tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
|
||||
|
||||
# load top right
|
||||
tr_weight = (x - floor_x) * (ceil_y - y)
|
||||
tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
|
||||
tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
|
||||
# load bottom left
|
||||
bl_weight = (ceil_x - x) * (y - floor_y)
|
||||
bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
|
||||
bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
|
||||
# load bottom right
|
||||
br_weight = (x - floor_x) * (y - floor_y)
|
||||
br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
|
||||
br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
|
||||
|
||||
# load dynamic weight and mask
|
||||
weights_offset = common_offset*K + k
|
||||
weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
|
||||
|
||||
|
||||
|
||||
tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
|
||||
tl_block_input = tl_block_input * tl_weight
|
||||
|
||||
# load top right
|
||||
tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
|
||||
tr_block_input = tr_block_input * tr_weight
|
||||
# load bottom left
|
||||
bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
|
||||
bl_block_input = bl_block_input * bl_weight
|
||||
# load bottom right
|
||||
br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
|
||||
br_block_input = br_block_input * br_weight
|
||||
|
||||
# sampled
|
||||
sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input
|
||||
|
||||
weighted_sampled_input = sampled_input * weight
|
||||
buffer = buffer + weighted_sampled_input
|
||||
# store to out_ptr
|
||||
tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask)
|
||||
|
||||
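A minimal sketch of launching forward_kernel directly, with one Triton program per (batch, row, column, group) position, matching the pid decomposition in the kernel and the grid used in function.py below; the import path and the contiguous channels-last layouts are assumptions.

import torch
from src.ops.cuda_kernels.forward import forward_kernel  # adjust to your package layout

B, H, W, G, C, K = 2, 16, 16, 4, 64, 9
inputs = torch.randn(B, H, W, G, C, device="cuda", dtype=torch.float16)
deformables = torch.zeros(B, H, W, G, K, 2, device="cuda", dtype=torch.float16)
weights = torch.rand(B, H, W, G, K, device="cuda", dtype=torch.float16)
out = torch.empty_like(inputs)

# One program per (b, h, w, g); BLOCK_SIZE is filled in by the autotuner above.
grid = lambda META: (B * H * W * G,)
forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out)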
126  src/ops/cuda_kernels/function.py  Normal file
@@ -0,0 +1,126 @@
|
||||
import time
|
||||
import dcn_cuda_backward
|
||||
import dcn_cuda_forward
|
||||
|
||||
import math
|
||||
import torch
|
||||
from typing import Any
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
from torch.cuda.amp import custom_fwd, custom_bwd
|
||||
from .forward import forward_kernel
|
||||
|
||||
|
||||
class DCNFunction(Function):
|
||||
BP_FUNCS = [
|
||||
dcn_cuda_backward.backward_p1_c2_tile16_thread128,
|
||||
dcn_cuda_backward.backward_p1_c4_tile16_thread128,
|
||||
dcn_cuda_backward.backward_p2_c2_tile16_thread128,
|
||||
dcn_cuda_backward.backward_p1_c2_tile16_thread256,
|
||||
dcn_cuda_backward.backward_p1_c4_tile16_thread256,
|
||||
dcn_cuda_backward.backward_p2_c2_tile16_thread256,
|
||||
dcn_cuda_backward.backward_p1_c2_tile16_thread384,
|
||||
dcn_cuda_backward.backward_p1_c4_tile16_thread384,
|
||||
dcn_cuda_backward.backward_p2_c2_tile16_thread384,
|
||||
dcn_cuda_backward.backward_p1_c2_tile16_thread512,
|
||||
dcn_cuda_backward.backward_p1_c4_tile16_thread512,
|
||||
dcn_cuda_backward.backward_p2_c2_tile16_thread512,
|
||||
dcn_cuda_backward.backward_p1_c2_tile16_thread768,
|
||||
dcn_cuda_backward.backward_p1_c4_tile16_thread768,
|
||||
dcn_cuda_backward.backward_p2_c2_tile16_thread768,
|
||||
dcn_cuda_backward.backward_p1_c2_tile32_thread128,
|
||||
dcn_cuda_backward.backward_p1_c2_tile32_thread256,
|
||||
dcn_cuda_backward.backward_p1_c2_tile32_thread384,
|
||||
dcn_cuda_backward.backward_p1_c2_tile32_thread512,
|
||||
]
|
||||
FW_FUNCS = [
|
||||
dcn_cuda_forward.dcn_forward_l256_c4,
|
||||
dcn_cuda_forward.dcn_forward_l256_c8,
|
||||
dcn_cuda_forward.dcn_forward_l256_c16,
|
||||
]
|
||||
BP_TABLES = dict()
|
||||
FW_TABLES = dict()
|
||||
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx: Any, values, deformables, weights) -> Any:
|
||||
B, H, W, G, C = values.shape
|
||||
func = DCNFunction.find_fw_funcs(values, deformables, weights)
|
||||
out = torch.zeros_like(values)
|
||||
func(B, G, C, H, W, values, deformables, weights, out)
ctx.save_for_backward(values, deformables, weights)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def find_fw_funcs(values, deformables, weights):
|
||||
B, H, W, G, C = values.shape
|
||||
B, H, W, G, K = weights.shape
|
||||
hash_value = 10000 * B + 100 * H + W + 1000 * G
|
||||
if hash_value in DCNFunction.FW_TABLES.keys():
|
||||
return DCNFunction.FW_TABLES[hash_value]
|
||||
print("missing")
|
||||
candicate_func = None
|
||||
min_t = 999.0
|
||||
outs = torch.zeros_like(values)
|
||||
for func in DCNFunction.FW_FUNCS:
|
||||
t = []
|
||||
for i in range(100):
|
||||
torch.cuda.synchronize()
|
||||
start_t = time.time()
|
||||
func(B, G, C, H, W, values, deformables, weights, outs)
|
||||
torch.cuda.synchronize()
|
||||
t.append(time.time() - start_t)
|
||||
t = t[-50:]
|
||||
t = sum(t) / len(t)
|
||||
if t < min_t:
|
||||
min_t = t
|
||||
DCNFunction.FW_TABLES[hash_value] = func
|
||||
candicate_func = func
|
||||
assert candicate_func is not None
|
||||
print(candicate_func)
|
||||
return candicate_func
|
||||
@staticmethod
|
||||
def find_bp_funcs(values, deformables, weights, grad_out):
|
||||
B, H, W, G, C = values.shape
|
||||
B, H, W, G, K = weights.shape
|
||||
hash_value = 10000 * B + 100 * H + W + 1000 * G
|
||||
if hash_value in DCNFunction.BP_TABLES.keys():
|
||||
return DCNFunction.BP_TABLES[hash_value]
|
||||
print("missing")
|
||||
candicate_func = None
|
||||
min_t = 999.0
|
||||
grad_values = torch.zeros_like(values)
|
||||
grad_deformables = torch.zeros_like(deformables)
|
||||
grad_weights = torch.zeros_like(weights)
|
||||
for func in DCNFunction.BP_FUNCS:
|
||||
t = []
|
||||
for i in range(100):
|
||||
torch.cuda.synchronize()
|
||||
start_t = time.time()
|
||||
func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights)
|
||||
torch.cuda.synchronize()
|
||||
t.append(time.time() - start_t)
|
||||
t = t[-50:]
|
||||
t = sum(t) / len(t)
|
||||
if t < min_t:
|
||||
min_t = t
|
||||
DCNFunction.BP_TABLES[hash_value] = func
|
||||
candicate_func = func
|
||||
assert candicate_func is not None
|
||||
print(candicate_func)
|
||||
return candicate_func
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, *grad_outputs: Any) -> Any:
|
||||
grad_out = grad_outputs[0]
|
||||
values, deformables, weights = ctx.saved_tensors
|
||||
B, H, W, G, C = values.shape
|
||||
B, H, W, G, K = weights.shape
|
||||
func = DCNFunction.find_bp_funcs(values, deformables, weights, grad_out)
|
||||
grad_values = torch.zeros_like(values)
|
||||
grad_deformables = torch.zeros_like(deformables)
|
||||
grad_weights = torch.zeros_like(weights)
|
||||
func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights)
|
||||
return grad_values, grad_deformables, grad_weights
|
||||
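A short usage sketch for DCNFunction above. The first call for a new (B, H, W, G) combination takes the cache-miss path and benchmarks every candidate kernel, so it is expected to be slow; later calls reuse the cached choice. The shapes and the import path are assumptions, and the dcn_cuda_forward / dcn_cuda_backward extensions must already be built.

import torch
from src.ops.cuda_kernels.function import DCNFunction  # requires the built CUDA extensions

B, H, W, G, C, K = 2, 32, 32, 4, 16, 9
values = torch.randn(B, H, W, G, C, device="cuda", dtype=torch.float16)
deformables = torch.zeros(B, H, W, G, K, 2, device="cuda", dtype=torch.float16)
weights = torch.rand(B, H, W, G, K, device="cuda", dtype=torch.float16)

out = DCNFunction.apply(values, deformables, weights)  # first call benchmarks and caches a kernel
out = DCNFunction.apply(values, deformables, weights)  # reuses the cached kernel for this shape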
59  src/ops/cuda_kernels/setup.py  Normal file
@@ -0,0 +1,59 @@
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
setup(
|
||||
name='dcn_cuda_forward',
|
||||
ext_modules=[
|
||||
CUDAExtension('dcn_cuda_forward', ['./forward.cu',],
|
||||
extra_compile_args={'cxx': [], 'nvcc': [
|
||||
"-DCUDA_HAS_FP16=1",
|
||||
"-D__CUDA_NO_HALF_OPERATORS__",
|
||||
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-D__CUDA_NO_HALF2_OPERATORS__",
|
||||
"--use_fast_math",
|
||||
"-O3",
|
||||
]}
|
||||
),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
}
|
||||
)
|
||||
|
||||
setup(
|
||||
name='dcn_cuda_backward',
|
||||
ext_modules=[
|
||||
CUDAExtension('dcn_cuda_backward', ['./backward.cu',],
|
||||
extra_compile_args={'cxx': [], 'nvcc': [
|
||||
"-DCUDA_HAS_FP16=1",
|
||||
"-D__CUDA_NO_HALF_OPERATORS__",
|
||||
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-D__CUDA_NO_HALF2_OPERATORS__",
|
||||
"--use_fast_math",
|
||||
"-O3",
|
||||
]}
|
||||
),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext': BuildExtension
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# setup(
|
||||
# name='mycuda',
|
||||
# ext_modules=[
|
||||
# CUDAExtension('mycuda', ['./backward.cu',],
|
||||
# extra_compile_args={'cxx': [], 'nvcc': [
|
||||
# "-O3",
|
||||
# "-DCUDA_HAS_FP16=1",
|
||||
# "-D__CUDA_NO_HALF_OPERATORS__",
|
||||
# "-D__CUDA_NO_HALF_CONVERSIONS__",
|
||||
# "-D__CUDA_NO_HALF2_OPERATORS__",
|
||||
# ]}
|
||||
# ),
|
||||
# ],
|
||||
# cmdclass={
|
||||
# 'build_ext': BuildExtension
|
||||
# }
|
||||
# )
|
||||
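One possible way to build and smoke-test the two extensions defined above; the in-place build and the printed symbol check are illustrative choices rather than anything prescribed by this file, and a CUDA toolchain matching the installed PyTorch is assumed.

# Run from src/ops/cuda_kernels/. Both setup() calls above are processed in sequence.
import subprocess, sys

subprocess.check_call([sys.executable, "setup.py", "build_ext", "--inplace"])

import dcn_cuda_forward, dcn_cuda_backward
print([name for name in dir(dcn_cuda_forward) if name.startswith("dcn_forward")])
print([name for name in dir(dcn_cuda_backward) if name.startswith("backward")])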
0  src/ops/triton_kernels/__init__.py  Normal file
124  src/ops/triton_kernels/backward.py  Normal file
@@ -0,0 +1,124 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=1),
|
||||
triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
|
||||
],
|
||||
key=['B', 'H', 'W', 'G', 'C', 'K'],
|
||||
)
|
||||
@triton.jit
|
||||
def backward_kernel(
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr, # image_size_h
|
||||
W: tl.constexpr, # image_size_w
|
||||
G: tl.constexpr, # num_groups
|
||||
C: tl.constexpr, # num_channels_per_group
|
||||
K: tl.constexpr, # kernel size
|
||||
input_ptr, # input features [B, H, W, G, C]
|
||||
deformable_ptr, # deformable offsets [B, H, W, G, K, 2]
|
||||
weights_ptr, # weights [B, H, W, G, K]
|
||||
grad_ptr, # out [B, H, W, G, C]
|
||||
grad_input_ptr, # input features [B, H, W, G, C]
|
||||
grad_deformable_ptr, # deformable offsets [B, H, W, G, K, 2]
|
||||
grad_weights_ptr, # weights [B, H, W, G, K]
|
||||
BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
|
||||
):
|
||||
|
||||
pid = tl.program_id(0)
|
||||
wid = pid % W
|
||||
hid = pid // W % H
|
||||
gid = pid // (W * H) % G
|
||||
bid = pid // (W * H * G)
|
||||
|
||||
id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
|
||||
|
||||
common_offset = bid*H*W*G + hid*W*G + wid*G + gid
|
||||
batch_base = bid * H * W * G * C
|
||||
for k in tl.static_range(K):
|
||||
# load dynamic weight and mask
|
||||
weights_offset = common_offset*K + k
|
||||
weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
|
||||
dodx = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
|
||||
dody = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
|
||||
dodw = tl.zeros((1,), dtype=grad_weights_ptr.type.element_ty)
|
||||
deformable_offset = (common_offset * K + k)*2
|
||||
x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
|
||||
y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
|
||||
for block_base in tl.static_range(0, C, BLOCK_SIZE):
|
||||
block_offset = tl.arange(0, BLOCK_SIZE) + block_base
|
||||
block_mask = (block_offset < C) & id_mask
|
||||
grad = tl.load(grad_ptr+common_offset*C + block_offset, mask=block_mask, other=0.0)
|
||||
dods = weight*grad
|
||||
|
||||
floor_x = x.to(tl.int32)
|
||||
floor_y = y.to(tl.int32)
|
||||
ceil_x = floor_x + 1
|
||||
ceil_y = floor_y + 1
|
||||
|
||||
# load top left
|
||||
tl_weight = (ceil_x - x) * (ceil_y - y)
|
||||
tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) + block_offset
|
||||
tl_block_mask = ((floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H))
|
||||
tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0)
|
||||
tl_block_input_dot_grad = tl.sum(tl_block_input*grad, axis=0)
|
||||
dodx = dodx + -1 * tl_block_input_dot_grad * (ceil_y - y)
|
||||
dody = dody + -1 * tl_block_input_dot_grad * (ceil_x - x)
|
||||
dodw = dodw + tl_block_input_dot_grad * tl_weight
|
||||
|
||||
dodtl = dods * tl_weight
|
||||
tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dodtl)
|
||||
|
||||
|
||||
# load top right
|
||||
tr_weight = (x - floor_x) * (ceil_y - y)
|
||||
tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
|
||||
tr_block_mask = ((floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0))
|
||||
tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0)
|
||||
tr_block_input_dot_grad = tl.sum(tr_block_input*grad, axis=0)
|
||||
dodx = dodx + 1 * tr_block_input_dot_grad * (ceil_y - y)
|
||||
dody = dody + -1 * tr_block_input_dot_grad * (x - floor_x)
|
||||
dodw = dodw + tr_block_input_dot_grad*tr_weight
|
||||
|
||||
dodtr = dods * tr_weight
|
||||
tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dodtr)
|
||||
|
||||
|
||||
# load bottom left
|
||||
bl_weight = (ceil_x - x) * (y - floor_y)
|
||||
bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) + block_offset
|
||||
bl_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0))
|
||||
bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0)
|
||||
bl_block_input_dot_grad = tl.sum(bl_block_input*grad, axis=0)
|
||||
dodx = dodx + -1 * bl_block_input_dot_grad * (y - floor_y)
|
||||
dody = dody + 1 * bl_block_input_dot_grad * (ceil_x - x)
|
||||
dodw = dodw + bl_block_input_dot_grad*bl_weight
|
||||
|
||||
dodbl = dods * bl_weight
|
||||
tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dodbl)
|
||||
|
||||
|
||||
# load bottom right
|
||||
br_weight = (x - floor_x) * (y - floor_y)
|
||||
br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
|
||||
br_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0))
|
||||
br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0)
|
||||
br_block_input_dot_grad = tl.sum(br_block_input*grad, axis=0)*br_block_mask
|
||||
|
||||
dodx = dodx + 1 * br_block_input_dot_grad * (y - floor_y)
|
||||
dody = dody + 1 * br_block_input_dot_grad * (x - floor_x)
|
||||
dodw = dodw + br_block_input_dot_grad*br_weight
|
||||
|
||||
dodbr = dods * br_weight
|
||||
tl.atomic_add(grad_input_ptr + br_block_offset, mask=br_block_mask & block_mask, val=dodbr)
|
||||
dodx = dodx * weight
|
||||
dody = dody * weight
|
||||
tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask)
|
||||
tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask)
|
||||
tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, mask=id_mask)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
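The dodx, dody and dodw accumulations above come from differentiating the bilinear interpolation weights. A scalar sketch of the same algebra (the function and variable names are illustrative, not from the kernel; the v_* arguments stand for the *_dot_grad reductions):

import math

def bilinear_grads(x, y, v_tl, v_tr, v_bl, v_br, weight):
    # v_tl..v_br: the four corner inputs already reduced against the incoming gradient
    fx, fy = math.floor(x), math.floor(y)
    cx, cy = fx + 1, fy + 1
    # d(out)/d(weight): plain bilinear interpolation of the reduced corner values
    dodw = (v_tl * (cx - x) * (cy - y) + v_tr * (x - fx) * (cy - y)
            + v_bl * (cx - x) * (y - fy) + v_br * (x - fx) * (y - fy))
    # d(out)/dx and d(out)/dy: derivative of each corner weight, scaled by the dynamic weight,
    # matching the final dodx/dody scaling after the k-loop in the kernel
    dodx = weight * (-v_tl * (cy - y) + v_tr * (cy - y) - v_bl * (y - fy) + v_br * (y - fy))
    dody = weight * (-v_tl * (cx - x) - v_tr * (x - fx) + v_bl * (cx - x) + v_br * (x - fx))
    return dodx, dody, dodw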
94  src/ops/triton_kernels/forward.py  Normal file
@@ -0,0 +1,94 @@
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
|
||||
# triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1),
|
||||
],
|
||||
key=['B', 'H', 'W', 'G', 'C', 'K'],
|
||||
)
|
||||
@triton.jit
|
||||
def forward_kernel(
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr, # image_size_h
|
||||
W: tl.constexpr, # image_size_w
|
||||
G: tl.constexpr, # num_groups
|
||||
C: tl.constexpr, # num_channels_per_group
|
||||
K: tl.constexpr, # kernel size
|
||||
input_ptr, # input features [B, H, W, G, C]
|
||||
deformable_ptr, # deformable offsets [B, H, W, G, K, 2]
|
||||
weights_ptr, # weights [B, H, W, G, K]
|
||||
out_ptr, # out [B, H, W, G, C]
|
||||
BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
wid = pid % W
|
||||
hid = pid // W % H
|
||||
gid = pid // (W * H) % G
|
||||
bid = pid // (W * H * G)
|
||||
|
||||
id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
|
||||
common_offset = bid*H*W*G + hid*W*G + wid*G + gid
|
||||
batch_base = bid * H * W * G * C
|
||||
|
||||
for block_base in tl.static_range(0, C, BLOCK_SIZE):
|
||||
buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
|
||||
block_offset = tl.arange(0, BLOCK_SIZE) + block_base
|
||||
block_mask = (block_offset < C) & id_mask
|
||||
for k in tl.static_range(K):
|
||||
deformable_offset = (common_offset * K + k) * 2
|
||||
|
||||
x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
|
||||
y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
|
||||
|
||||
floor_x = x.to(tl.int32)
|
||||
floor_y = y.to(tl.int32)
|
||||
ceil_x = floor_x + 1
|
||||
ceil_y = floor_y + 1
|
||||
|
||||
# load top left
|
||||
tl_weight = (ceil_x - x) * (ceil_y - y)
|
||||
tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
|
||||
tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
|
||||
|
||||
# load top right
|
||||
tr_weight = (x - floor_x) * (ceil_y - y)
|
||||
tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
|
||||
tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
|
||||
# load bottom left
|
||||
bl_weight = (ceil_x - x) * (y - floor_y)
|
||||
bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
|
||||
bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
|
||||
# load bottom right
|
||||
br_weight = (x - floor_x) * (y - floor_y)
|
||||
br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
|
||||
br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
|
||||
|
||||
# load dynamic weight and mask
|
||||
weights_offset = common_offset*K + k
|
||||
weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
|
||||
|
||||
|
||||
|
||||
tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
|
||||
tl_block_input = tl_block_input * tl_weight
|
||||
|
||||
# load top right
|
||||
tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
|
||||
tr_block_input = tr_block_input * tr_weight
|
||||
# load bottom left
|
||||
bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
|
||||
bl_block_input = bl_block_input * bl_weight
|
||||
# load bottom right
|
||||
br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
|
||||
br_block_input = br_block_input * br_weight
|
||||
|
||||
# sampled
|
||||
sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input
|
||||
|
||||
weighted_sampled_input = sampled_input * weight
|
||||
buffer = buffer + weighted_sampled_input
|
||||
# store to out_ptr
|
||||
tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask)
|
||||
|
||||
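This kernel is a near-duplicate of src/ops/cuda_kernels/forward.py with a single autotune config left enabled. A small sketch for checking which configuration the autotuner settled on after a first launch; best_config is exposed by Triton's autotuner in recent versions, while the shapes and import path are assumptions.

import torch
from src.ops.triton_kernels.forward import forward_kernel  # adjust to your package layout

B, H, W, G, C, K = 2, 16, 16, 4, 64, 9
inputs = torch.randn(B, H, W, G, C, device="cuda", dtype=torch.float16)
deformables = torch.zeros(B, H, W, G, K, 2, device="cuda", dtype=torch.float16)
weights = torch.rand(B, H, W, G, K, device="cuda", dtype=torch.float16)
out = torch.empty_like(inputs)

forward_kernel[lambda META: (B * H * W * G,)](B, H, W, G, C, K, inputs, deformables, weights, out)
print(forward_kernel.best_config)  # which (BLOCK_SIZE, num_warps, num_stages) the autotuner picked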
48  src/ops/triton_kernels/function.py  Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
import triton
|
||||
from typing import Any
|
||||
from torch.autograd import Function
|
||||
from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd
|
||||
from .forward import forward_kernel
|
||||
from .backward import backward_kernel
|
||||
|
||||
|
||||
|
||||
class DCNFunction(Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx: Any, inputs, deformables, weights) -> Any:
|
||||
B, H, W, G, C = inputs.shape
|
||||
_, _, _, _, K, _ = deformables.shape
|
||||
out = torch.zeros_like(inputs)
|
||||
grid = lambda META: (B * H * W * G,)
|
||||
|
||||
forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out)
|
||||
ctx.save_for_backward(inputs, deformables, weights)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx: Any, *grad_outputs: Any) -> Any:
|
||||
grad_output = grad_outputs[0].contiguous()
|
||||
|
||||
inputs, deformables, weights = ctx.saved_tensors
|
||||
B, H, W, G, C = inputs.shape
|
||||
_, _, _, _, K, _ = deformables.shape
|
||||
|
||||
grad_inputs = torch.zeros_like(inputs)
|
||||
grad_deformables = torch.zeros_like(deformables)
|
||||
grad_weights = torch.zeros_like(weights)
|
||||
grid = lambda META: (B * H * W * G,)
|
||||
backward_kernel[grid](
|
||||
B, H, W, G, C, K,
|
||||
inputs,
|
||||
deformables,
|
||||
weights,
|
||||
grad_output,
|
||||
grad_inputs,
|
||||
grad_deformables,
|
||||
grad_weights,
|
||||
)
|
||||
return (grad_inputs, grad_deformables, grad_weights)
|
||||
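An end-to-end sketch of the Triton-backed DCNFunction, including the backward pass; shapes and the import path are assumptions, and deformables are expected as [B, H, W, G, K, 2] given how K is unpacked above.

import torch
from src.ops.triton_kernels.function import DCNFunction  # adjust to your package layout

B, H, W, G, C, K = 2, 16, 16, 4, 64, 9
inputs = torch.randn(B, H, W, G, C, device="cuda", requires_grad=True)
deformables = torch.zeros(B, H, W, G, K, 2, device="cuda", requires_grad=True)
weights = torch.rand(B, H, W, G, K, device="cuda", requires_grad=True)

out = DCNFunction.apply(inputs, deformables, weights)   # forward_kernel, output [B, H, W, G, C]
out.sum().backward()                                    # backward_kernel fills the three gradients
print(inputs.grad.shape, deformables.grad.shape, weights.grad.shape)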
0  src/plugins/__init__.py  Normal file
70  src/plugins/bd_env.py  Normal file
@@ -0,0 +1,70 @@
|
||||
import torch
|
||||
import os
|
||||
import socket
|
||||
from typing_extensions import override
|
||||
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
||||
from lightning.fabric.plugins.environments.lightning import LightningEnvironment
|
||||
|
||||
|
||||
class BDEnvironment(LightningEnvironment):
|
||||
pass
|
||||
# def __init__(self) -> None:
|
||||
# super().__init__()
|
||||
# self._global_rank: int = 0
|
||||
# self._world_size: int = 1
|
||||
#
|
||||
# @property
|
||||
# @override
|
||||
# def creates_processes_externally(self) -> bool:
|
||||
# """Returns whether the cluster creates the processes or not.
|
||||
#
|
||||
# If at least :code:`LOCAL_RANK` is available as environment variable, Lightning assumes the user acts as the
|
||||
# process launcher/job scheduler and Lightning will not launch new processes.
|
||||
#
|
||||
# """
|
||||
# return "LOCAL_RANK" in os.environ
|
||||
#
|
||||
# @staticmethod
|
||||
# @override
|
||||
# def detect() -> bool:
|
||||
# assert "ARNOLD_WORKER_0_HOST" in os.environ.keys()
|
||||
# assert "ARNOLD_WORKER_0_PORT" in os.environ.keys()
|
||||
# return True
|
||||
#
|
||||
# @override
|
||||
# def world_size(self) -> int:
|
||||
# return self._world_size
|
||||
#
|
||||
# @override
|
||||
# def set_world_size(self, size: int) -> None:
|
||||
# self._world_size = size
|
||||
#
|
||||
# @override
|
||||
# def global_rank(self) -> int:
|
||||
# return self._global_rank
|
||||
#
|
||||
# @override
|
||||
# def set_global_rank(self, rank: int) -> None:
|
||||
# self._global_rank = rank
|
||||
# rank_zero_only.rank = rank
|
||||
#
|
||||
# @override
|
||||
# def local_rank(self) -> int:
|
||||
# return int(os.environ.get("LOCAL_RANK", 0))
|
||||
#
|
||||
# @override
|
||||
# def node_rank(self) -> int:
|
||||
# return int(os.environ.get("ARNOLD_ID"))
|
||||
#
|
||||
# @override
|
||||
# def teardown(self) -> None:
|
||||
# if "WORLD_SIZE" in os.environ:
|
||||
# del os.environ["WORLD_SIZE"]
|
||||
#
|
||||
# @property
|
||||
# def main_address(self) -> str:
|
||||
# return os.environ.get("ARNOLD_WORKER_0_HOST")
|
||||
#
|
||||
# @property
|
||||
# def main_port(self) -> int:
|
||||
# return int(os.environ.get("ARNOLD_WORKER_0_PORT"))
|
||||
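BDEnvironment currently just inherits LightningEnvironment, with the cluster-specific overrides kept commented out. A minimal sketch of passing it to a Lightning Trainer as a plugin; the Trainer arguments and the fit call are placeholders, not taken from this commit.

from lightning.pytorch import Trainer
from src.plugins.bd_env import BDEnvironment

trainer = Trainer(plugins=[BDEnvironment()], accelerator="auto", devices="auto")
# trainer.fit(model, datamodule=datamodule)  # model/datamodule are placeholders, not defined here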
0  src/utils/__init__.py  Normal file
13  src/utils/copy.py  Normal file
@@ -0,0 +1,13 @@
|
||||
import torch
|
||||
|
||||
@torch.no_grad()
|
||||
def copy_params(src_model, dst_model):
|
||||
for src_param, dst_param in zip(src_model.parameters(), dst_model.parameters()):
|
||||
dst_param.data.copy_(src_param.data)
|
||||
|
||||
@torch.no_grad()
|
||||
def swap_tensors(tensor1, tensor2):
|
||||
tmp = torch.empty_like(tensor1)
|
||||
tmp.copy_(tensor1)
|
||||
tensor1.copy_(tensor2)
|
||||
tensor2.copy_(tmp)
|
||||
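A short usage sketch for the helpers above, e.g. syncing a frozen copy of a model and swapping two buffers in place; the Linear modules are placeholders.

import torch
from torch import nn
from src.utils.copy import copy_params, swap_tensors

src_model, dst_model = nn.Linear(8, 8), nn.Linear(8, 8)
copy_params(src_model, dst_model)   # dst_model now carries src_model's parameter values

a, b = torch.zeros(4), torch.ones(4)
swap_tensors(a, b)                  # a is now all ones, b all zeros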
Some files were not shown because too many files have changed in this diff.