diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/DDT.iml b/.idea/DDT.iml
new file mode 100644
index 0000000..d0876a7
--- /dev/null
+++ b/.idea/DDT.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..b19c6c8
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/configs/repa_flatten_condit22_fixt_xl.yaml b/configs/repa_flatten_condit22_fixt_xl.yaml
new file mode 100644
index 0000000..6f3a6ce
--- /dev/null
+++ b/configs/repa_flatten_condit22_fixt_xl.yaml
@@ -0,0 +1,108 @@
+# lightning.pytorch==2.4.0
+seed_everything: true
+tags:
+ exp: &exp repa_flatten_condit22_dit6_fixt_xl
+torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
+huggingface_cache_dir: null
+trainer:
+ default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
+ accelerator: auto
+ strategy: auto
+ devices: auto
+ num_nodes: 1
+ precision: bf16-mixed
+ logger:
+ class_path: lightning.pytorch.loggers.WandbLogger
+ init_args:
+ project: universal_flow
+ name: *exp
+ num_sanity_val_steps: 0
+ max_steps: 4000000
+ val_check_interval: 4000000
+ check_val_every_n_epoch: null
+ log_every_n_steps: 50
+ deterministic: null
+ inference_mode: true
+ use_distributed_sampler: false
+ callbacks:
+ - class_path: src.callbacks.model_checkpoint.CheckpointHook
+ init_args:
+ every_n_train_steps: 10000
+ save_top_k: -1
+ save_last: true
+ - class_path: src.callbacks.save_images.SaveImagesHook
+ init_args:
+ save_dir: val
+ plugins:
+ - src.plugins.bd_env.BDEnvironment
+model:
+ vae:
+ class_path: src.models.vae.LatentVAE
+ init_args:
+ precompute: true
+ weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
+ denoiser:
+ class_path: src.models.denoiser.flatten_condit_dit_fixt.FlattenConDiT
+ init_args:
+ in_channels: 4
+ patch_size: 2
+ num_groups: 16
+ hidden_size: &hidden_dim 1152
+ num_blocks: 28
+ num_cond_blocks: 22
+ num_classes: 1000
+ conditioner:
+ class_path: src.models.conditioner.LabelConditioner
+ init_args:
+ null_class: 1000
+ diffusion_trainer:
+ class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
+ init_args:
+ lognorm_t: true
+ encoder_weight_path: dinov2_vitb14
+ align_layer: 8
+ proj_denoiser_dim: *hidden_dim
+ proj_hidden_dim: *hidden_dim
+ proj_encoder_dim: 768
+ scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
+ diffusion_sampler:
+ class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
+ init_args:
+ num_steps: 250
+ guidance: 2.0
+ timeshift: 1.5
+ state_refresh_rate: 1
+ guidance_interval_min: 0.3
+ guidance_interval_max: 1.0
+ scheduler: *scheduler
+ w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
+ guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
+ last_step: 0.04
+ step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
+ ema_tracker:
+ class_path: src.callbacks.simple_ema.SimpleEMA
+ init_args:
+ decay: 0.9999
+ optimizer:
+ class_path: torch.optim.AdamW
+ init_args:
+ lr: 1e-4
+ betas:
+ - 0.9
+ - 0.95
+ weight_decay: 0.0
+data:
+ train_dataset: imagenet256
+ train_root: /mnt/bn/wangshuai6/data/ImageNet/train
+ train_image_size: 256
+ train_batch_size: 16
+ eval_max_num_instances: 50000
+ pred_batch_size: 64
+ pred_num_workers: 4
+ pred_seeds: null
+ pred_selected_classes: null
+ num_classes: 1000
+ latent_shape:
+ - 4
+ - 32
+ - 32
\ No newline at end of file
diff --git a/configs/repa_flatten_condit22_fixt_xl512.yaml b/configs/repa_flatten_condit22_fixt_xl512.yaml
new file mode 100644
index 0000000..ce59338
--- /dev/null
+++ b/configs/repa_flatten_condit22_fixt_xl512.yaml
@@ -0,0 +1,108 @@
+# lightning.pytorch==2.4.0
+seed_everything: true
+tags:
+ exp: &exp res512_fromscratch_repa_flatten_condit22_dit6_fixt_xl
+torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
+huggingface_cache_dir: null
+trainer:
+ default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
+ accelerator: auto
+ strategy: auto
+ devices: auto
+ num_nodes: 1
+ precision: bf16-mixed
+ logger:
+ class_path: lightning.pytorch.loggers.WandbLogger
+ init_args:
+ project: universal_flow
+ name: *exp
+ num_sanity_val_steps: 0
+ max_steps: 4000000
+ val_check_interval: 4000000
+ check_val_every_n_epoch: null
+ log_every_n_steps: 50
+ deterministic: null
+ inference_mode: true
+ use_distributed_sampler: false
+ callbacks:
+ - class_path: src.callbacks.model_checkpoint.CheckpointHook
+ init_args:
+ every_n_train_steps: 10000
+ save_top_k: -1
+ save_last: true
+ - class_path: src.callbacks.save_images.SaveImagesHook
+ init_args:
+ save_dir: val
+ plugins:
+ - src.plugins.bd_env.BDEnvironment
+model:
+ vae:
+ class_path: src.models.vae.LatentVAE
+ init_args:
+ precompute: true
+ weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
+ denoiser:
+ class_path: src.models.denoiser.flatten_condit_dit_fixt.FlattenConDiT
+ init_args:
+ in_channels: 4
+ patch_size: 2
+ num_groups: 16
+ hidden_size: &hidden_dim 1152
+ num_blocks: 28
+ num_cond_blocks: 22
+ num_classes: 1000
+ conditioner:
+ class_path: src.models.conditioner.LabelConditioner
+ init_args:
+ null_class: 1000
+ diffusion_trainer:
+ class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
+ init_args:
+ lognorm_t: true
+ encoder_weight_path: dinov2_vitb14
+ align_layer: 8
+ proj_denoiser_dim: *hidden_dim
+ proj_hidden_dim: *hidden_dim
+ proj_encoder_dim: 768
+ scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
+ diffusion_sampler:
+ class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
+ init_args:
+ num_steps: 250
+ guidance: 3.0
+ state_refresh_rate: 1
+ guidance_interval_min: 0.3
+ guidance_interval_max: 1.0
+ timeshift: 1.0
+ last_step: 0.04
+ scheduler: *scheduler
+ w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
+ guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
+ step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
+ ema_tracker:
+ class_path: src.callbacks.simple_ema.SimpleEMA
+ init_args:
+ decay: 0.9999
+ optimizer:
+ class_path: torch.optim.AdamW
+ init_args:
+ lr: 1e-4
+ betas:
+ - 0.9
+ - 0.95
+ weight_decay: 0.0
+data:
+ train_dataset: imagenet512
+ train_root: /mnt/bn/wangshuai6/data/ImageNet/train
+ train_image_size: 512
+ train_batch_size: 16
+ eval_max_num_instances: 50000
+ pred_batch_size: 32
+ pred_num_workers: 4
+ pred_seeds: null
+ pred_selected_classes: null
+ num_classes: 1000
+ latent_shape:
+ - 4
+ - 64
+ - 64
\ No newline at end of file
diff --git a/configs/repa_flatten_dit_fixt_large.yaml b/configs/repa_flatten_dit_fixt_large.yaml
new file mode 100644
index 0000000..7dc9611
--- /dev/null
+++ b/configs/repa_flatten_dit_fixt_large.yaml
@@ -0,0 +1,99 @@
+# lightning.pytorch==2.4.0
+seed_everything: true
+tags:
+ exp: &exp repa_flatten_dit_fixt_large
+torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
+huggingface_cache_dir: null
+trainer:
+ default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
+ accelerator: auto
+ strategy: auto
+ devices: auto
+ num_nodes: 1
+ precision: bf16-mixed
+ logger:
+ class_path: lightning.pytorch.loggers.WandbLogger
+ init_args:
+ project: universal_flow
+ name: *exp
+ num_sanity_val_steps: 0
+ max_steps: 400000
+ val_check_interval: 100000
+ check_val_every_n_epoch: null
+ log_every_n_steps: 50
+ deterministic: null
+ inference_mode: true
+ use_distributed_sampler: false
+ callbacks:
+ - class_path: src.callbacks.model_checkpoint.CheckpointHook
+ init_args:
+ every_n_train_steps: 10000
+ save_top_k: -1
+ save_last: true
+ - class_path: src.callbacks.save_images.SaveImagesHook
+ init_args:
+ save_dir: val
+ plugins:
+ - src.plugins.bd_env.BDEnvironment
+model:
+ vae:
+ class_path: src.models.vae.LatentVAE
+ init_args:
+ precompute: true
+ weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
+ denoiser:
+ class_path: src.models.denoiser.flatten_dit_fixt.FlattenDiT
+ init_args:
+ in_channels: 4
+ patch_size: 2
+ num_groups: 16
+ hidden_size: &hidden_dim 1024
+ num_blocks: 24
+ num_classes: 1000
+ conditioner:
+ class_path: src.models.conditioner.LabelConditioner
+ init_args:
+ null_class: 1000
+ diffusion_trainer:
+ class_path: src.diffusion.flow_matching.training_repa.REPATrainer
+ init_args:
+ lognorm_t: true
+ encoder_weight_path: dinov2_vitb14
+ align_layer: 8
+ proj_denoiser_dim: *hidden_dim
+ proj_hidden_dim: *hidden_dim
+ proj_encoder_dim: 768
+ scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler
+ diffusion_sampler:
+ class_path: src.diffusion.flow_matching.sampling.EulerSampler
+ init_args:
+ num_steps: 250
+ guidance: 1.00
+ scheduler: *scheduler
+ w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler
+ guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
+ step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn
+ ema_tracker:
+ class_path: src.callbacks.simple_ema.SimpleEMA
+ init_args:
+ decay: 0.9999
+ optimizer:
+ class_path: torch.optim.AdamW
+ init_args:
+ lr: 1e-4
+ weight_decay: 0.0
+data:
+ train_dataset: imagenet256
+ train_root: /mnt/bn/wangshuai6/data/ImageNet/train
+ train_image_size: 256
+ train_batch_size: 32
+ eval_max_num_instances: 50000
+ pred_batch_size: 64
+ pred_num_workers: 4
+ pred_seeds: null
+ pred_selected_classes: null
+ num_classes: 1000
+ latent_shape:
+ - 4
+ - 32
+ - 32
\ No newline at end of file
diff --git a/configs/repa_flatten_dit_fixt_xl.yaml b/configs/repa_flatten_dit_fixt_xl.yaml
new file mode 100644
index 0000000..2bc8606
--- /dev/null
+++ b/configs/repa_flatten_dit_fixt_xl.yaml
@@ -0,0 +1,99 @@
+# lightning.pytorch==2.4.0
+seed_everything: true
+tags:
+ exp: &exp repa_flatten_dit_fixt_xl
+torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
+huggingface_cache_dir: null
+trainer:
+ default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
+ accelerator: auto
+ strategy: auto
+ devices: auto
+ num_nodes: 1
+ precision: bf16-mixed
+ logger:
+ class_path: lightning.pytorch.loggers.WandbLogger
+ init_args:
+ project: universal_flow
+ name: *exp
+ num_sanity_val_steps: 0
+ max_steps: 400000
+ val_check_interval: 100000
+ check_val_every_n_epoch: null
+ log_every_n_steps: 50
+ deterministic: null
+ inference_mode: true
+ use_distributed_sampler: false
+ callbacks:
+ - class_path: src.callbacks.model_checkpoint.CheckpointHook
+ init_args:
+ every_n_train_steps: 10000
+ save_top_k: -1
+ save_last: true
+ - class_path: src.callbacks.save_images.SaveImagesHook
+ init_args:
+ save_dir: val
+ plugins:
+ - src.plugins.bd_env.BDEnvironment
+model:
+ vae:
+ class_path: src.models.vae.LatentVAE
+ init_args:
+ precompute: true
+ weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
+ denoiser:
+ class_path: src.models.denoiser.flatten_dit_fixt.FlattenDiT
+ init_args:
+ in_channels: 4
+ patch_size: 2
+ num_groups: 16
+ hidden_size: &hidden_dim 1152
+ num_blocks: 28
+ num_classes: 1000
+ conditioner:
+ class_path: src.models.conditioner.LabelConditioner
+ init_args:
+ null_class: 1000
+ diffusion_trainer:
+ class_path: src.diffusion.flow_matching.training_repa.REPATrainer
+ init_args:
+ lognorm_t: true
+ encoder_weight_path: dinov2_vitb14
+ align_layer: 8
+ proj_denoiser_dim: *hidden_dim
+ proj_hidden_dim: *hidden_dim
+ proj_encoder_dim: 768
+ scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler
+ diffusion_sampler:
+ class_path: src.diffusion.flow_matching.sampling.EulerSampler
+ init_args:
+ num_steps: 250
+ guidance: 1.00
+ scheduler: *scheduler
+ w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler
+ guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
+ step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn
+ ema_tracker:
+ class_path: src.callbacks.simple_ema.SimpleEMA
+ init_args:
+ decay: 0.9999
+ optimizer:
+ class_path: torch.optim.AdamW
+ init_args:
+ lr: 1e-4
+ weight_decay: 0.0
+data:
+ train_dataset: imagenet256
+ train_root: /mnt/bn/wangshuai6/data/ImageNet/train
+ train_image_size: 256
+ train_batch_size: 32
+ eval_max_num_instances: 50000
+ pred_batch_size: 64
+ pred_num_workers: 4
+ pred_seeds: null
+ pred_selected_classes: null
+ num_classes: 1000
+ latent_shape:
+ - 4
+ - 32
+ - 32
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..85bf7cb
--- /dev/null
+++ b/main.py
@@ -0,0 +1,86 @@
+import time
+from typing import Any, Union
+
+import pylab as pl
+
+from src.utils.patch_bugs import *
+
+import os
+import torch
+from lightning import Trainer, LightningModule
+from src.lightning_data import DataModule
+from src.lightning_model import LightningModel
+from lightning.pytorch.cli import LightningCLI, LightningArgumentParser, SaveConfigCallback
+
+import logging
+logger = logging.getLogger("lightning.pytorch")
+# log_path = os.path.join( f"log.txt")
+# logger.addHandler(logging.FileHandler(log_path))
+
+class ReWriteRootSaveConfigCallback(SaveConfigCallback):
+ def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
+ stamp = time.strftime('%y%m%d%H%M')
+ file_path = os.path.join(trainer.default_root_dir, f"config-{stage}-{stamp}.yaml")
+ self.parser.save(
+ self.config, file_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
+ )
+
+
+class ReWriteRootDirCli(LightningCLI):
+ def before_instantiate_classes(self) -> None:
+ super().before_instantiate_classes()
+ config_trainer = self._get(self.config, "trainer", default={})
+
+ # predict path & logger check
+ if self.subcommand == "predict":
+ config_trainer.logger = None
+
+ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+ class TagsClass:
+ def __init__(self, exp:str):
+ ...
+ parser.add_class_arguments(TagsClass, nested_key="tags")
+
+ def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+ super().add_default_arguments_to_parser(parser)
+ parser.add_argument("--torch_hub_dir", type=str, default=None, help=("torch hub dir"),)
+ parser.add_argument("--huggingface_cache_dir", type=str, default=None, help=("huggingface hub dir"),)
+
+ def instantiate_trainer(self, **kwargs: Any) -> Trainer:
+ config_trainer = self._get(self.config_init, "trainer", default={})
+ default_root_dir = config_trainer.get("default_root_dir", None)
+
+ if default_root_dir is None:
+ default_root_dir = os.path.join(os.getcwd(), "workdirs")
+
+ dirname = ""
+ for v, k in self._get(self.config, "tags", default={}).items():
+ dirname += f"{v}_{k}"
+ default_root_dir = os.path.join(default_root_dir, dirname)
+ is_resume = self._get(self.config_init, "ckpt_path", default=None)
+ if os.path.exists(default_root_dir) and "debug" not in default_root_dir:
+ if os.listdir(default_root_dir) and self.subcommand != "predict" and not is_resume:
+ raise FileExistsError(f"{default_root_dir} already exists")
+
+ config_trainer.default_root_dir = default_root_dir
+ trainer = super().instantiate_trainer(**kwargs)
+ if trainer.is_global_zero:
+ os.makedirs(default_root_dir, exist_ok=True)
+ return trainer
+
+ def instantiate_classes(self) -> None:
+ torch_hub_dir = self._get(self.config, "torch_hub_dir")
+ huggingface_cache_dir = self._get(self.config, "huggingface_cache_dir")
+ if huggingface_cache_dir is not None:
+ os.environ["HUGGINGFACE_HUB_CACHE"] = huggingface_cache_dir
+ if torch_hub_dir is not None:
+ os.environ["TORCH_HOME"] = torch_hub_dir
+ torch.hub.set_dir(torch_hub_dir)
+ super().instantiate_classes()
+
+if __name__ == "__main__":
+
+ cli = ReWriteRootDirCli(LightningModel, DataModule,
+ auto_configure_optimizers=False,
+ save_config_callback=ReWriteRootSaveConfigCallback,
+ save_config_kwargs={"overwrite": True})
\ No newline at end of file
diff --git a/main.sh b/main.sh
new file mode 100644
index 0000000..39f803e
--- /dev/null
+++ b/main.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+export NCCL_HOSTID=${MY_POD_NAME}
+export MASTER_ADDR=${ARNOLD_WORKER_0_HOST}
+export MASTER_PORT=${ARNOLD_WORKER_0_PORT}
+export NODE_RANK=${ARNOLD_ID}
+export NUM_NODES=${ARNOLD_WORKER_NUM}
+
+python3 main.py fit -c $1 --trainer.num_nodes $NUM_NODES
+# for pid in $(ps -ef | grep "yaml" | grep -v "grep" | awk '{print $2}'); do kill -9 $pid; done
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..9061a84
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+lightning==2.5.0.post0
+omegaconf==2.3.0
+jsonargparse[signatures]>=4.27.7
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/callbacks/__init__.py b/src/callbacks/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/callbacks/grad.py b/src/callbacks/grad.py
new file mode 100644
index 0000000..f9155b6
--- /dev/null
+++ b/src/callbacks/grad.py
@@ -0,0 +1,22 @@
+import torch
+import lightning.pytorch as pl
+from lightning.pytorch.utilities import grad_norm
+from torch.optim import Optimizer
+
+class GradientMonitor(pl.Callback):
+ """Logs the gradient norm"""
+
+ def __init__(self, norm_type: int = 2):
+ norm_type = float(norm_type)
+ if norm_type <= 0:
+ raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}")
+ self.norm_type = norm_type
+
+ def on_before_optimizer_step(
+ self, trainer: "pl.Trainer",
+ pl_module: "pl.LightningModule",
+ optimizer: Optimizer
+ ) -> None:
+ norms = grad_norm(pl_module, norm_type=self.norm_type)
+ max_grad = torch.tensor([v for k, v in norms.items() if k != f"grad_{self.norm_type}_norm_total"]).max()
+ pl_module.log_dict({'train/grad/max': max_grad, 'train/grad/total': norms[f"grad_{self.norm_type}_norm_total"]})
\ No newline at end of file
diff --git a/src/callbacks/model_checkpoint.py b/src/callbacks/model_checkpoint.py
new file mode 100644
index 0000000..019454e
--- /dev/null
+++ b/src/callbacks/model_checkpoint.py
@@ -0,0 +1,25 @@
+import os.path
+from typing import Optional, Dict, Any
+
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
+from soupsieve.util import lower
+
+
+class CheckpointHook(ModelCheckpoint):
+ """Save checkpoint with only the incremental part of the model"""
+ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
+ self.dirpath = trainer.default_root_dir
+ self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt")
+ pl_module.strict_loading = False
+
+ def on_save_checkpoint(
+ self, trainer: "pl.Trainer",
+ pl_module: "pl.LightningModule",
+ checkpoint: Dict[str, Any]
+ ) -> None:
+ del checkpoint["callbacks"]
+
+ # def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None:
+ # if not "debug" in self.exception_ckpt_path:
+ # trainer.save_checkpoint(self.exception_ckpt_path)
\ No newline at end of file
diff --git a/src/callbacks/save_images.py b/src/callbacks/save_images.py
new file mode 100644
index 0000000..c6cd32b
--- /dev/null
+++ b/src/callbacks/save_images.py
@@ -0,0 +1,105 @@
+import lightning.pytorch as pl
+from lightning.pytorch import Callback
+
+
+import os.path
+import numpy
+from PIL import Image
+from typing import Sequence, Any, Dict
+from concurrent.futures import ThreadPoolExecutor
+
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from lightning_utilities.core.rank_zero import rank_zero_info
+
+def process_fn(image, path):
+ Image.fromarray(image).save(path)
+
+class SaveImagesHook(Callback):
+ def __init__(self, save_dir="val", max_save_num=0, compressed=True):
+ self.save_dir = save_dir
+ self.max_save_num = max_save_num
+ self.compressed = compressed
+
+ def save_start(self, target_dir):
+ self.target_dir = target_dir
+ self.executor_pool = ThreadPoolExecutor(max_workers=8)
+ if not os.path.exists(self.target_dir):
+ os.makedirs(self.target_dir, exist_ok=True)
+ else:
+ if os.listdir(target_dir) and "debug" not in str(target_dir):
+ raise FileExistsError(f'{self.target_dir} already exists and not empty!')
+ self.samples = []
+ self._have_saved_num = 0
+ rank_zero_info(f"Save images to {self.target_dir}")
+
+ def save_image(self, images, filenames):
+ images = images.permute(0, 2, 3, 1).cpu().numpy()
+ for sample, filename in zip(images, filenames):
+ if isinstance(filename, Sequence):
+ filename = filename[0]
+ path = f'{self.target_dir}/{filename}'
+ if self._have_saved_num >= self.max_save_num:
+ break
+ self.executor_pool.submit(process_fn, sample, path)
+ self._have_saved_num += 1
+
+ def process_batch(
+ self,
+ trainer: "pl.Trainer",
+ pl_module: "pl.LightningModule",
+ samples: STEP_OUTPUT,
+ batch: Any,
+ ) -> None:
+ b, c, h, w = samples.shape
+ xT, y, metadata = batch
+ all_samples = pl_module.all_gather(samples).view(-1, c, h, w)
+ self.save_image(samples, metadata)
+ if trainer.is_global_zero:
+ all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy()
+ self.samples.append(all_samples)
+
+ def save_end(self):
+ if self.compressed and len(self.samples) > 0:
+ samples = numpy.concatenate(self.samples)
+ numpy.savez(f'{self.target_dir}/output.npz', arr_0=samples)
+ self.executor_pool.shutdown(wait=True)
+ self.samples = []
+ self.target_dir = None
+ self._have_saved_num = 0
+ self.executor_pool = None
+
+ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+ target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}")
+ self.save_start(target_dir)
+
+ def on_validation_batch_end(
+ self,
+ trainer: "pl.Trainer",
+ pl_module: "pl.LightningModule",
+ outputs: STEP_OUTPUT,
+ batch: Any,
+ batch_idx: int,
+ dataloader_idx: int = 0,
+ ) -> None:
+ return self.process_batch(trainer, pl_module, outputs, batch)
+
+ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+ self.save_end()
+
+ def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+ target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict")
+ self.save_start(target_dir)
+
+ def on_predict_batch_end(
+ self,
+ trainer: "pl.Trainer",
+ pl_module: "pl.LightningModule",
+ samples: Any,
+ batch: Any,
+ batch_idx: int,
+ dataloader_idx: int = 0,
+ ) -> None:
+ return self.process_batch(trainer, pl_module, samples, batch)
+
+ def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+ self.save_end()
\ No newline at end of file
diff --git a/src/callbacks/simple_ema.py b/src/callbacks/simple_ema.py
new file mode 100644
index 0000000..28bf476
--- /dev/null
+++ b/src/callbacks/simple_ema.py
@@ -0,0 +1,79 @@
+from typing import Any, Dict
+
+import torch
+import torch.nn as nn
+import threading
+import lightning.pytorch as pl
+from lightning.pytorch import Callback
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+
+from src.utils.copy import swap_tensors
+
+class SimpleEMA(Callback):
+ def __init__(self, net:nn.Module, ema_net:nn.Module,
+ decay: float = 0.9999,
+ every_n_steps: int = 1,
+ eval_original_model:bool = False
+ ):
+ super().__init__()
+ self.decay = decay
+ self.every_n_steps = every_n_steps
+ self.eval_original_model = eval_original_model
+ self._stream = torch.cuda.Stream()
+
+ self.net_params = list(net.parameters())
+ self.ema_params = list(ema_net.parameters())
+
+ def swap_model(self):
+ for ema_p, p, in zip(self.ema_params, self.net_params):
+ swap_tensors(ema_p, p)
+
+ def ema_step(self):
+ @torch.no_grad()
+ def ema_update(ema_model_tuple, current_model_tuple, decay):
+ torch._foreach_mul_(ema_model_tuple, decay)
+ torch._foreach_add_(
+ ema_model_tuple, current_model_tuple, alpha=(1.0 - decay),
+ )
+
+ if self._stream is not None:
+ self._stream.wait_stream(torch.cuda.current_stream())
+ with torch.cuda.stream(self._stream):
+ ema_update(self.ema_params, self.net_params, self.decay)
+
+
+ def on_train_batch_end(
+ self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
+ ) -> None:
+ if trainer.global_step % self.every_n_steps == 0:
+ self.ema_step()
+
+ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+ if not self.eval_original_model:
+ self.swap_model()
+
+ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+ if not self.eval_original_model:
+ self.swap_model()
+
+ def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+ if not self.eval_original_model:
+ self.swap_model()
+
+ def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+ if not self.eval_original_model:
+ self.swap_model()
+
+
+ def state_dict(self) -> Dict[str, Any]:
+ return {
+ "decay": self.decay,
+ "every_n_steps": self.every_n_steps,
+ "eval_original_model": self.eval_original_model,
+ }
+
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+ self.decay = state_dict["decay"]
+ self.every_n_steps = state_dict["every_n_steps"]
+ self.eval_original_model = state_dict["eval_original_model"]
+
diff --git a/src/data/__init__.py b/src/data/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/src/data/__init__.py
@@ -0,0 +1 @@
+
diff --git a/src/data/dataset/__init__.py b/src/data/dataset/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/data/dataset/celeba.py b/src/data/dataset/celeba.py
new file mode 100644
index 0000000..30f5d3f
--- /dev/null
+++ b/src/data/dataset/celeba.py
@@ -0,0 +1,11 @@
+from typing import Callable
+from torchvision.datasets import CelebA
+
+
+class LocalDataset(CelebA):
+ def __init__(self, root:str, ):
+ super(LocalDataset, self).__init__(root, "train")
+
+ def __getitem__(self, idx):
+ data = super().__getitem__(idx)
+ return data
\ No newline at end of file
diff --git a/src/data/dataset/imagenet.py b/src/data/dataset/imagenet.py
new file mode 100644
index 0000000..59d0547
--- /dev/null
+++ b/src/data/dataset/imagenet.py
@@ -0,0 +1,82 @@
+import torch
+from PIL import Image
+from torchvision.datasets import ImageFolder
+from torchvision.transforms.functional import to_tensor
+from torchvision.transforms import Normalize
+
+from src.data.dataset.metric_dataset import CenterCrop
+
+class LocalCachedDataset(ImageFolder):
+ def __init__(self, root, resolution=256):
+ super().__init__(root)
+ self.transform = CenterCrop(resolution)
+ self.cache_root = None
+
+ def load_latent(self, latent_path):
+ pk_data = torch.load(latent_path)
+ mean = pk_data['mean'].to(torch.float32)
+ logvar = pk_data['logvar'].to(torch.float32)
+ logvar = torch.clamp(logvar, -30.0, 20.0)
+ std = torch.exp(0.5 * logvar)
+ latent = mean + torch.randn_like(mean) * std
+ return latent
+
+ def __getitem__(self, idx: int):
+ image_path, target = self.samples[idx]
+ latent_path = image_path.replace(self.root, self.cache_root) + ".pt"
+
+ raw_image = Image.open(image_path).convert('RGB')
+ raw_image = self.transform(raw_image)
+ raw_image = to_tensor(raw_image)
+ if self.cache_root is not None:
+ latent = self.load_latent(latent_path)
+ else:
+ latent = raw_image
+ return raw_image, latent, target
+
+class ImageNet256(LocalCachedDataset):
+ def __init__(self, root, ):
+ super().__init__(root, 256)
+ self.cache_root = root + "_256_latent"
+
+class ImageNet512(LocalCachedDataset):
+ def __init__(self, root, ):
+ super().__init__(root, 512)
+ self.cache_root = root + "_512_latent"
+
+class PixImageNet(ImageFolder):
+ def __init__(self, root, resolution=256):
+ super().__init__(root)
+ self.transform = CenterCrop(resolution)
+ self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ def __getitem__(self, idx: int):
+ image_path, target = self.samples[idx]
+ raw_image = Image.open(image_path).convert('RGB')
+ raw_image = self.transform(raw_image)
+ raw_image = to_tensor(raw_image)
+
+ normalized_image = self.normalize(raw_image)
+ return raw_image, normalized_image, target
+
+class PixImageNet64(PixImageNet):
+ def __init__(self, root, ):
+ super().__init__(root, 64)
+
+class PixImageNet128(PixImageNet):
+ def __init__(self, root, ):
+ super().__init__(root, 128)
+
+
+class PixImageNet256(PixImageNet):
+ def __init__(self, root, ):
+ super().__init__(root, 256)
+
+class PixImageNet512(PixImageNet):
+ def __init__(self, root, ):
+ super().__init__(root, 512)
+
+
+
+
+
diff --git a/src/data/dataset/metric_dataset.py b/src/data/dataset/metric_dataset.py
new file mode 100644
index 0000000..cbe7d66
--- /dev/null
+++ b/src/data/dataset/metric_dataset.py
@@ -0,0 +1,82 @@
+import pathlib
+
+import torch
+import random
+import numpy as np
+from torchvision.io.image import read_image
+import torchvision.transforms as tvtf
+from torch.utils.data import Dataset
+
+class CenterCrop:
+ def __init__(self, size):
+ self.size = size
+ def __call__(self, image):
+ def center_crop_arr(pil_image, image_size):
+ """
+ Center cropping implementation from ADM.
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+ """
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
+
+ return center_crop_arr(image, self.size)
+
+
+from PIL import Image
+IMG_EXTENSIONS = (
+ "*.png",
+ "*.JPEG",
+ "*.jpeg",
+ "*.jpg"
+)
+
+def test_collate(batch):
+ return torch.stack(batch)
+
+class ImageDataset(Dataset):
+ def __init__(self, root, image_size=(224, 224)):
+ self.root = pathlib.Path(root)
+ images = []
+ for ext in IMG_EXTENSIONS:
+ images.extend(self.root.rglob(ext))
+ random.shuffle(images)
+ self.images = list(map(lambda x: str(x), images))
+ self.transform = tvtf.Compose(
+ [
+ CenterCrop(image_size[0]),
+ tvtf.ToTensor(),
+ tvtf.Lambda(lambda x: (x*255).to(torch.uint8)),
+ tvtf.Lambda(lambda x: x.expand(3, -1, -1))
+ ]
+ )
+ self.size = image_size
+
+ def __getitem__(self, idx):
+ try:
+ image = Image.open(self.images[idx])
+ image = self.transform(image)
+ except Exception as e:
+ print(self.images[idx])
+ image = torch.zeros(3, self.size[0], self.size[1], dtype=torch.uint8)
+
+ # print(image)
+ metadata = dict(
+ path = self.images[idx],
+ root = self.root,
+ )
+ return image #, metadata
+
+ def __len__(self):
+ return len(self.images)
\ No newline at end of file
diff --git a/src/data/dataset/randn.py b/src/data/dataset/randn.py
new file mode 100644
index 0000000..f9ec772
--- /dev/null
+++ b/src/data/dataset/randn.py
@@ -0,0 +1,41 @@
+import os.path
+import random
+
+import torch
+from torch.utils.data import Dataset
+
+
+
+class RandomNDataset(Dataset):
+ def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, selected_classes:list=None, seeds=None, max_num_instances=50000, ):
+ self.selected_classes = selected_classes
+ if selected_classes is not None:
+ num_classes = len(selected_classes)
+ max_num_instances = 10*num_classes
+ self.num_classes = num_classes
+ self.seeds = seeds
+ if seeds is not None:
+ self.max_num_instances = len(seeds)*num_classes
+ self.num_seeds = len(seeds)
+ else:
+ self.num_seeds = (max_num_instances + num_classes - 1) // num_classes
+ self.max_num_instances = self.num_seeds*num_classes
+
+ self.latent_shape = latent_shape
+
+
+ def __getitem__(self, idx):
+ label = idx // self.num_seeds
+ if self.selected_classes:
+ label = self.selected_classes[label]
+ seed = random.randint(0, 1<<31) #idx % self.num_seeds
+ if self.seeds is not None:
+ seed = self.seeds[idx % self.num_seeds]
+
+ # cls_dir = os.path.join(self.root, f"{label}")
+ filename = f"{label}_{seed}.png",
+ generator = torch.Generator().manual_seed(seed)
+ latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)
+ return latent, label, filename
+ def __len__(self):
+ return self.max_num_instances
\ No newline at end of file
diff --git a/src/data/var_training.py b/src/data/var_training.py
new file mode 100644
index 0000000..de7fb74
--- /dev/null
+++ b/src/data/var_training.py
@@ -0,0 +1,145 @@
+import torch
+from typing import Callable
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+import concurrent.futures
+from concurrent.futures import ProcessPoolExecutor
+from typing import List
+from PIL import Image
+import torch
+import random
+import numpy as np
+import copy
+import torchvision.transforms.functional as tvtf
+from src.models.vae import uint82fp
+
+
+def center_crop_arr(pil_image, width, height):
+ """
+ Center cropping implementation from ADM.
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+ """
+ while pil_image.size[0] >= 2 * width and pil_image.size[1] >= 2 * height:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = max(width / pil_image.size[0], height / pil_image.size[1])
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+ arr = np.array(pil_image)
+ crop_y = random.randint(0, (arr.shape[0] - height))
+ crop_x = random.randint(0, (arr.shape[1] - width))
+ return Image.fromarray(arr[crop_y: crop_y + height, crop_x: crop_x + width])
+
+def process_fn(width, height, data, hflip=0.5):
+ image, label = data
+ if random.uniform(0, 1) > hflip: # hflip
+ image = tvtf.hflip(image)
+ image = center_crop_arr(image, width, height) # crop
+ image = np.array(image).transpose(2, 0, 1)
+ return image, label
+
+class VARCandidate:
+ def __init__(self, aspect_ratio, width, height, buffer, max_buffer_size=1024):
+ self.aspect_ratio = aspect_ratio
+ self.width = int(width)
+ self.height = int(height)
+ self.buffer = buffer
+ self.max_buffer_size = max_buffer_size
+
+ def add_sample(self, data):
+ self.buffer.append(data)
+ self.buffer = self.buffer[-self.max_buffer_size:]
+
+ def ready(self, batch_size):
+ return len(self.buffer) >= batch_size
+
+ def get_batch(self, batch_size):
+ batch = self.buffer[:batch_size]
+ self.buffer = self.buffer[batch_size:]
+ batch = [copy.deepcopy(b.result()) for b in batch]
+ x, y = zip(*batch)
+ x = torch.stack([torch.from_numpy(im).cuda() for im in x], dim=0)
+ x = list(map(uint82fp, x))
+ return x, y
+
+class VARTransformEngine:
+ def __init__(self,
+ base_image_size,
+ num_aspect_ratios,
+ min_aspect_ratio,
+ max_aspect_ratio,
+ num_workers = 8,
+ ):
+ self.base_image_size = base_image_size
+ self.num_aspect_ratios = num_aspect_ratios
+ self.min_aspect_ratio = min_aspect_ratio
+ self.max_aspect_ratio = max_aspect_ratio
+ self.aspect_ratios = np.linspace(self.min_aspect_ratio, self.max_aspect_ratio, self.num_aspect_ratios)
+ self.aspect_ratios = self.aspect_ratios.tolist()
+ self.candidates_pool = []
+ for i in range(self.num_aspect_ratios):
+ candidate = VARCandidate(
+ aspect_ratio=self.aspect_ratios[i],
+ width=int(self.base_image_size * self.aspect_ratios[i] ** 0.5 // 16 * 16),
+ height=int(self.base_image_size * self.aspect_ratios[i] ** -0.5 // 16 * 16),
+ buffer=[],
+ max_buffer_size=1024
+ )
+ self.candidates_pool.append(candidate)
+ self.default_candidate = VARCandidate(
+ aspect_ratio=1.0,
+ width=self.base_image_size,
+ height=self.base_image_size,
+ buffer=[],
+ max_buffer_size=1024,
+ )
+ self.executor_pool = ProcessPoolExecutor(max_workers=num_workers)
+ self._prefill_count = 100
+
+ def find_candidate(self, data):
+ image = data[0]
+ aspect_ratio = image.size[0] / image.size[1]
+ min_distance = 1000000
+ min_candidate = None
+ for candidate in self.candidates_pool:
+ dis = abs(aspect_ratio - candidate.aspect_ratio)
+ if dis < min_distance:
+ min_distance = dis
+ min_candidate = candidate
+ return min_candidate
+
+
+ def __call__(self, batch_data):
+ self._prefill_count -= 1
+ if isinstance(batch_data[0], torch.Tensor):
+ batch_data[0] = batch_data[0].unbind(0)
+
+ batch_data = list(zip(*batch_data))
+ for data in batch_data:
+ candidate = self.find_candidate(data)
+ future = self.executor_pool.submit(process_fn, candidate.width, candidate.height, data)
+ candidate.add_sample(future)
+ if self._prefill_count >= 0:
+ future = self.executor_pool.submit(process_fn,
+ self.default_candidate.width,
+ self.default_candidate.height,
+ data)
+ self.default_candidate.add_sample(future)
+
+ batch_size = len(batch_data)
+ random.shuffle(self.candidates_pool)
+ for candidate in self.candidates_pool:
+ if candidate.ready(batch_size=batch_size):
+ return candidate.get_batch(batch_size=batch_size)
+
+ # fallback to default 256
+ for data in batch_data:
+ future = self.executor_pool.submit(process_fn,
+ self.default_candidate.width,
+ self.default_candidate.height,
+ data)
+ self.default_candidate.add_sample(future)
+ return self.default_candidate.get_batch(batch_size=batch_size)
\ No newline at end of file
diff --git a/src/diffusion/__init__.py b/src/diffusion/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/diffusion/base/guidance.py b/src/diffusion/base/guidance.py
new file mode 100644
index 0000000..07b4754
--- /dev/null
+++ b/src/diffusion/base/guidance.py
@@ -0,0 +1,60 @@
+import torch
+
+def simple_guidance_fn(out, cfg):
+ uncondition, condtion = out.chunk(2, dim=0)
+ out = uncondition + cfg * (condtion - uncondition)
+ return out
+
+def c3_guidance_fn(out, cfg):
+ # guidance function in DiT/SiT, seems like a bug not a feature?
+ uncondition, condtion = out.chunk(2, dim=0)
+ out = condtion
+ out[:, :3] = uncondition[:, :3] + cfg * (condtion[:, :3] - uncondition[:, :3])
+ return out
+
+def c4_guidance_fn(out, cfg):
+ # guidance function in DiT/SiT, seems like a bug not a feature?
+ uncondition, condition = out.chunk(2, dim=0)
+ out = condition
+ out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
+ out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:])
+ return out
+
+def c4_p05_guidance_fn(out, cfg):
+ # guidance function in DiT/SiT, seems like a bug not a feature?
+ uncondition, condition = out.chunk(2, dim=0)
+ out = condition
+ out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
+ out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:])
+ return out
+
+def c4_p10_guidance_fn(out, cfg):
+ # guidance function in DiT/SiT, seems like a bug not a feature?
+ uncondition, condition = out.chunk(2, dim=0)
+ out = condition
+ out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
+ out[:, 4:] = uncondition[:, 4:] + 1.10 * (condition[:, 4:] - uncondition[:, 4:])
+ return out
+
+def c4_p15_guidance_fn(out, cfg):
+ # guidance function in DiT/SiT, seems like a bug not a feature?
+ uncondition, condition = out.chunk(2, dim=0)
+ out = condition
+ out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
+ out[:, 4:] = uncondition[:, 4:] + 1.15 * (condition[:, 4:] - uncondition[:, 4:])
+ return out
+
+def c4_p20_guidance_fn(out, cfg):
+ # guidance function in DiT/SiT, seems like a bug not a feature?
+ uncondition, condition = out.chunk(2, dim=0)
+ out = condition
+ out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4])
+ out[:, 4:] = uncondition[:, 4:] + 1.20 * (condition[:, 4:] - uncondition[:, 4:])
+ return out
+
+def p4_guidance_fn(out, cfg):
+ # guidance function in DiT/SiT, seems like a bug not a feature?
+ uncondition, condtion = out.chunk(2, dim=0)
+ out = condtion
+ out[:, 4:] = uncondition[:, 4:] + cfg * (condtion[:, 4:] - uncondition[:, 4:])
+ return out
diff --git a/src/diffusion/base/sampling.py b/src/diffusion/base/sampling.py
new file mode 100644
index 0000000..d8f9776
--- /dev/null
+++ b/src/diffusion/base/sampling.py
@@ -0,0 +1,31 @@
+from typing import Union, List
+
+import torch
+import torch.nn as nn
+from typing import Callable
+from src.diffusion.base.scheduling import BaseScheduler
+
+class BaseSampler(nn.Module):
+ def __init__(self,
+ scheduler: BaseScheduler = None,
+ guidance_fn: Callable = None,
+ num_steps: int = 250,
+ guidance: Union[float, List[float]] = 1.0,
+ *args,
+ **kwargs
+ ):
+ super(BaseSampler, self).__init__()
+ self.num_steps = num_steps
+ self.guidance = guidance
+ self.guidance_fn = guidance_fn
+ self.scheduler = scheduler
+
+
+ def _impl_sampling(self, net, noise, condition, uncondition):
+ raise NotImplementedError
+
+ def __call__(self, net, noise, condition, uncondition):
+ denoised = self._impl_sampling(net, noise, condition, uncondition)
+ return denoised
+
+
diff --git a/src/diffusion/base/scheduling.py b/src/diffusion/base/scheduling.py
new file mode 100644
index 0000000..05c7fb1
--- /dev/null
+++ b/src/diffusion/base/scheduling.py
@@ -0,0 +1,32 @@
+import torch
+from torch import Tensor
+
+class BaseScheduler:
+ def alpha(self, t) -> Tensor:
+ ...
+ def sigma(self, t) -> Tensor:
+ ...
+
+ def dalpha(self, t) -> Tensor:
+ ...
+ def dsigma(self, t) -> Tensor:
+ ...
+
+ def dalpha_over_alpha(self, t) -> Tensor:
+ return self.dalpha(t) / self.alpha(t)
+
+ def dsigma_mul_sigma(self, t) -> Tensor:
+ return self.dsigma(t)*self.sigma(t)
+
+ def drift_coefficient(self, t):
+ alpha, sigma = self.alpha(t), self.sigma(t)
+ dalpha, dsigma = self.dalpha(t), self.dsigma(t)
+ return dalpha/(alpha + 1e-6)
+
+ def diffuse_coefficient(self, t):
+ alpha, sigma = self.alpha(t), self.sigma(t)
+ dalpha, dsigma = self.dalpha(t), self.dsigma(t)
+ return dsigma*sigma - dalpha/(alpha + 1e-6)*sigma**2
+
+ def w(self, t):
+ return self.sigma(t)
diff --git a/src/diffusion/base/training.py b/src/diffusion/base/training.py
new file mode 100644
index 0000000..8f6d0e0
--- /dev/null
+++ b/src/diffusion/base/training.py
@@ -0,0 +1,29 @@
+import time
+
+import torch
+import torch.nn as nn
+
+class BaseTrainer(nn.Module):
+ def __init__(self,
+ null_condition_p=0.1,
+ log_var=False,
+ ):
+ super(BaseTrainer, self).__init__()
+ self.null_condition_p = null_condition_p
+ self.log_var = log_var
+
+ def preproprocess(self, raw_iamges, x, condition, uncondition):
+ bsz = x.shape[0]
+ if self.null_condition_p > 0:
+ mask = torch.rand((bsz), device=condition.device) < self.null_condition_p
+ mask = mask.expand_as(condition)
+ condition[mask] = uncondition[mask]
+ return raw_iamges, x, condition
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ raise NotImplementedError
+
+ def __call__(self, net, ema_net, raw_images, x, condition, uncondition):
+ raw_images, x, condition = self.preproprocess(raw_images, x, condition, uncondition)
+ return self._impl_trainstep(net, ema_net, raw_images, x, condition)
+
diff --git a/src/diffusion/ddpm/ddim_sampling.py b/src/diffusion/ddpm/ddim_sampling.py
new file mode 100644
index 0000000..0db2a1d
--- /dev/null
+++ b/src/diffusion/ddpm/ddim_sampling.py
@@ -0,0 +1,40 @@
+import torch
+from src.diffusion.base.scheduling import *
+from src.diffusion.base.sampling import *
+
+from typing import Callable
+
+import logging
+logger = logging.getLogger(__name__)
+
+class DDIMSampler(BaseSampler):
+ def __init__(
+ self,
+ train_num_steps=1000,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.train_num_steps = train_num_steps
+ assert self.scheduler is not None
+
+ def _impl_sampling(self, net, noise, condition, uncondition):
+ batch_size = noise.shape[0]
+ steps = torch.linspace(0.0, self.train_num_steps-1, self.num_steps, device=noise.device)
+ steps = torch.flip(steps, dims=[0])
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
+ x = x0 = noise
+ for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
+ t_cur = t_cur.repeat(batch_size)
+ t_next = t_next.repeat(batch_size)
+ sigma = self.scheduler.sigma(t_cur)
+ alpha = self.scheduler.alpha(t_cur)
+ sigma_next = self.scheduler.sigma(t_next)
+ alpha_next = self.scheduler.alpha(t_next)
+ cfg_x = torch.cat([x, x], dim=0)
+ t = t_cur.repeat(2)
+ out = net(cfg_x, t, cfg_condition)
+ out = self.guidance_fn(out, self.guidance)
+ x0 = (x - sigma * out) / alpha
+ x = alpha_next * x0 + sigma_next * out
+ return x0
\ No newline at end of file
diff --git a/src/diffusion/ddpm/scheduling.py b/src/diffusion/ddpm/scheduling.py
new file mode 100644
index 0000000..aff1523
--- /dev/null
+++ b/src/diffusion/ddpm/scheduling.py
@@ -0,0 +1,102 @@
+import math
+import torch
+from src.diffusion.base.scheduling import *
+
+
+class DDPMScheduler(BaseScheduler):
+ def __init__(
+ self,
+ beta_min=0.0001,
+ beta_max=0.02,
+ num_steps=1000,
+ ):
+ super().__init__()
+ self.beta_min = beta_min
+ self.beta_max = beta_max
+ self.num_steps = num_steps
+
+ self.betas_table = torch.linspace(self.beta_min, self.beta_max, self.num_steps, device="cuda")
+ self.alphas_table = torch.cumprod(1-self.betas_table, dim=0)
+ self.sigmas_table = 1-self.alphas_table
+
+
+ def beta(self, t) -> Tensor:
+ t = t.to(torch.long)
+ return self.betas_table[t].view(-1, 1, 1, 1)
+
+ def alpha(self, t) -> Tensor:
+ t = t.to(torch.long)
+ return self.alphas_table[t].view(-1, 1, 1, 1)**0.5
+
+ def sigma(self, t) -> Tensor:
+ t = t.to(torch.long)
+ return self.sigmas_table[t].view(-1, 1, 1, 1)**0.5
+
+ def dsigma(self, t) -> Tensor:
+ raise NotImplementedError("wrong usage")
+
+ def dalpha_over_alpha(self, t) ->Tensor:
+ raise NotImplementedError("wrong usage")
+
+ def dsigma_mul_sigma(self, t) ->Tensor:
+ raise NotImplementedError("wrong usage")
+
+ def dalpha(self, t) -> Tensor:
+ raise NotImplementedError("wrong usage")
+
+ def drift_coefficient(self, t):
+ raise NotImplementedError("wrong usage")
+
+ def diffuse_coefficient(self, t):
+ raise NotImplementedError("wrong usage")
+
+ def w(self, t):
+ raise NotImplementedError("wrong usage")
+
+
+class VPScheduler(BaseScheduler):
+ def __init__(
+ self,
+ beta_min=0.1,
+ beta_max=20,
+ ):
+ super().__init__()
+ self.beta_min = beta_min
+ self.beta_d = beta_max - beta_min
+ def beta(self, t) -> Tensor:
+ t = torch.clamp(t, min=1e-3, max=1)
+ return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1)
+
+ def sigma(self, t) -> Tensor:
+ t = torch.clamp(t, min=1e-3, max=1)
+ inter_beta:Tensor = 0.5*self.beta_d*t**2 + self.beta_min* t
+ return (1-torch.exp_(-inter_beta)).sqrt().view(-1, 1, 1, 1)
+
+ def dsigma(self, t) -> Tensor:
+ raise NotImplementedError("wrong usage")
+
+ def dalpha_over_alpha(self, t) ->Tensor:
+ raise NotImplementedError("wrong usage")
+
+ def dsigma_mul_sigma(self, t) ->Tensor:
+ raise NotImplementedError("wrong usage")
+
+ def dalpha(self, t) -> Tensor:
+ raise NotImplementedError("wrong usage")
+
+ def alpha(self, t) -> Tensor:
+ t = torch.clamp(t, min=1e-3, max=1)
+ inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t
+ return torch.exp(-0.5*inter_beta).view(-1, 1, 1, 1)
+
+ def drift_coefficient(self, t):
+ raise NotImplementedError("wrong usage")
+
+ def diffuse_coefficient(self, t):
+ raise NotImplementedError("wrong usage")
+
+ def w(self, t):
+ return self.diffuse_coefficient(t)
+
+
+
diff --git a/src/diffusion/ddpm/training.py b/src/diffusion/ddpm/training.py
new file mode 100644
index 0000000..3e0d0ec
--- /dev/null
+++ b/src/diffusion/ddpm/training.py
@@ -0,0 +1,83 @@
+import torch
+from typing import Callable
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+class VPTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ train_max_t=1000,
+ lognorm_t=False,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.train_max_t = train_max_t
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size = x.shape[0]
+ if self.lognorm_t:
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
+ else:
+ t = torch.rand(batch_size).to(x.device, x.dtype)
+
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ sigma = self.scheduler.sigma(t)
+ x_t = alpha * x + noise * sigma
+ out = net(x_t, t*self.train_max_t, y)
+ weight = self.loss_weight_fn(alpha, sigma)
+ loss = weight*(out - noise)**2
+
+ out = dict(
+ loss=loss.mean(),
+ )
+ return out
+
+
+class DDPMTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn: Callable = constant,
+ train_max_t=1000,
+ lognorm_t=False,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.train_max_t = train_max_t
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size = x.shape[0]
+ t = torch.randint(0, self.train_max_t, (batch_size,))
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ sigma = self.scheduler.sigma(t)
+ x_t = alpha * x + noise * sigma
+ out = net(x_t, t, y)
+ weight = self.loss_weight_fn(alpha, sigma)
+ loss = weight * (out - noise) ** 2
+
+ out = dict(
+ loss=loss.mean(),
+ )
+ return out
\ No newline at end of file
diff --git a/src/diffusion/ddpm/vp_sampling.py b/src/diffusion/ddpm/vp_sampling.py
new file mode 100644
index 0000000..250b32d
--- /dev/null
+++ b/src/diffusion/ddpm/vp_sampling.py
@@ -0,0 +1,59 @@
+import torch
+
+from src.diffusion.base.scheduling import *
+from src.diffusion.base.sampling import *
+from typing import Callable
+
+def ode_step_fn(x, eps, beta, sigma, dt):
+ return x + (-0.5*beta*x + 0.5*eps*beta/sigma)*dt
+
+def sde_step_fn(x, eps, beta, sigma, dt):
+ return x + (-0.5*beta*x + eps*beta/sigma)*dt + torch.sqrt(dt.abs()*beta)*torch.randn_like(x)
+
+import logging
+logger = logging.getLogger(__name__)
+
+class VPEulerSampler(BaseSampler):
+ def __init__(
+ self,
+ train_max_t=1000,
+ guidance_fn: Callable = None,
+ step_fn: Callable = ode_step_fn,
+ last_step=None,
+ last_step_fn: Callable = ode_step_fn,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.guidance_fn = guidance_fn
+ self.step_fn = step_fn
+ self.last_step = last_step
+ self.last_step_fn = last_step_fn
+ self.train_max_t = train_max_t
+
+ if self.last_step is None or self.num_steps == 1:
+ self.last_step = 1.0 / self.num_steps
+ assert self.last_step > 0.0
+ assert self.scheduler is not None
+
+ def _impl_sampling(self, net, noise, condition, uncondition):
+ batch_size = noise.shape[0]
+ steps = torch.linspace(1.0, self.last_step, self.num_steps, device=noise.device)
+ steps = torch.cat([steps, torch.tensor([0.0], device=noise.device)], dim=0)
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
+ x = noise
+ for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
+ dt = t_next - t_cur
+ t_cur = t_cur.repeat(batch_size)
+ sigma = self.scheduler.sigma(t_cur)
+ beta = self.scheduler.beta(t_cur)
+ cfg_x = torch.cat([x, x], dim=0)
+ cfg_t = t_cur.repeat(2)
+ out = net(cfg_x, cfg_t*self.train_max_t, cfg_condition)
+ eps = self.guidance_fn(out, self.guidance)
+ if i < self.num_steps -1 :
+ x0 = self.last_step_fn(x, eps, beta, sigma, -t_cur[0])
+ x = self.step_fn(x, eps, beta, sigma, dt)
+ else:
+ x = x0 = self.last_step_fn(x, eps, beta, sigma, -self.last_step)
+ return x
\ No newline at end of file
diff --git a/src/diffusion/flow_matching/adam_sampling.py b/src/diffusion/flow_matching/adam_sampling.py
new file mode 100644
index 0000000..15d0c78
--- /dev/null
+++ b/src/diffusion/flow_matching/adam_sampling.py
@@ -0,0 +1,107 @@
+import math
+from src.diffusion.base.sampling import *
+from src.diffusion.base.scheduling import *
+from src.diffusion.pre_integral import *
+
+from typing import Callable, List, Tuple
+
+def ode_step_fn(x, v, dt, s, w):
+ return x + v * dt
+
+def t2snr(t):
+ if isinstance(t, torch.Tensor):
+ return (t.clip(min=1e-8)/(1-t + 1e-8))
+ if isinstance(t, List) or isinstance(t, Tuple):
+ return [t2snr(t) for t in t]
+ t = max(t, 1e-8)
+ return (t/(1-t + 1e-8))
+
+def t2logsnr(t):
+ if isinstance(t, torch.Tensor):
+ return torch.log(t.clip(min=1e-3)/(1-t + 1e-3))
+ if isinstance(t, List) or isinstance(t, Tuple):
+ return [t2logsnr(t) for t in t]
+ t = max(t, 1e-3)
+ return math.log(t/(1-t + 1e-3))
+
+def t2isnr(t):
+ return 1/t2snr(t)
+
+def nop(t):
+ return t
+
+def shift_respace_fn(t, shift=3.0):
+ return t / (t + (1 - t) * shift)
+
+import logging
+logger = logging.getLogger(__name__)
+
+class AdamLMSampler(BaseSampler):
+ def __init__(
+ self,
+ order: int = 2,
+ timeshift: float = 1.0,
+ lms_transform_fn: Callable = nop,
+ w_scheduler: BaseScheduler = None,
+ step_fn: Callable = ode_step_fn,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.step_fn = step_fn
+ self.w_scheduler = w_scheduler
+
+ assert self.scheduler is not None
+ assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
+ self.order = order
+ self.lms_transform_fn = lms_transform_fn
+
+ timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
+ timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
+ self.timesteps = shift_respace_fn(timesteps, timeshift)
+ self.timedeltas = timesteps[1:] - self.timesteps[:-1]
+ self._reparameterize_coeffs()
+
+ def _reparameterize_coeffs(self):
+ solver_coeffs = [[] for _ in range(self.num_steps)]
+ for i in range(0, self.num_steps):
+ pre_vs = [1.0, ]*(i+1)
+ pre_ts = self.lms_transform_fn(self.timesteps[:i+1])
+ int_t_start = self.lms_transform_fn(self.timesteps[i])
+ int_t_end = self.lms_transform_fn(self.timesteps[i+1])
+
+ order_annealing = self.order #self.num_steps - i
+ order = min(self.order, i + 1, order_annealing)
+
+ _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end)
+ solver_coeffs[i] = coeffs
+ self.solver_coeffs = solver_coeffs
+
+ def _impl_sampling(self, net, noise, condition, uncondition):
+ """
+ sampling process of Euler sampler
+ -
+ """
+ batch_size = noise.shape[0]
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
+ x = x0 = noise
+ pred_trajectory = []
+ t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype)
+ timedeltas = self.timedeltas
+ solver_coeffs = self.solver_coeffs
+ for i in range(self.num_steps):
+ cfg_x = torch.cat([x, x], dim=0)
+ cfg_t = t_cur.repeat(2)
+ out = net(cfg_x, cfg_t, cfg_condition)
+ out = self.guidance_fn(out, self.guidances[i])
+ pred_trajectory.append(out)
+ out = torch.zeros_like(out)
+ order = len(self.solver_coeffs[i])
+ for j in range(order):
+ out += solver_coeffs[i][j] * pred_trajectory[-order:][j]
+ v = out
+ dt = timedeltas[i]
+ x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0)
+ x = self.step_fn(x, v, dt, s=0, w=0)
+ t_cur += dt
+ return x
\ No newline at end of file
diff --git a/src/diffusion/flow_matching/sampling.py b/src/diffusion/flow_matching/sampling.py
new file mode 100644
index 0000000..62bdd8b
--- /dev/null
+++ b/src/diffusion/flow_matching/sampling.py
@@ -0,0 +1,179 @@
+import torch
+
+from src.diffusion.base.guidance import *
+from src.diffusion.base.scheduling import *
+from src.diffusion.base.sampling import *
+
+from typing import Callable
+
+
+def shift_respace_fn(t, shift=3.0):
+ return t / (t + (1 - t) * shift)
+
+def ode_step_fn(x, v, dt, s, w):
+ return x + v * dt
+
+def sde_mean_step_fn(x, v, dt, s, w):
+ return x + v * dt + s * w * dt
+
+def sde_step_fn(x, v, dt, s, w):
+ return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x)
+
+def sde_preserve_step_fn(x, v, dt, s, w):
+ return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x)
+
+
+import logging
+logger = logging.getLogger(__name__)
+
+class EulerSampler(BaseSampler):
+ def __init__(
+ self,
+ w_scheduler: BaseScheduler = None,
+ timeshift=1.0,
+ step_fn: Callable = ode_step_fn,
+ last_step=None,
+ last_step_fn: Callable = ode_step_fn,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.step_fn = step_fn
+ self.last_step = last_step
+ self.last_step_fn = last_step_fn
+ self.w_scheduler = w_scheduler
+ self.timeshift = timeshift
+
+ if self.last_step is None or self.num_steps == 1:
+ self.last_step = 1.0 / self.num_steps
+
+ timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
+ timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
+ self.timesteps = shift_respace_fn(timesteps, self.timeshift)
+
+ assert self.last_step > 0.0
+ assert self.scheduler is not None
+ assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
+ if self.w_scheduler is not None:
+ if self.step_fn == ode_step_fn:
+ logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
+
+ def _impl_sampling(self, net, noise, condition, uncondition):
+ """
+ sampling process of Euler sampler
+ -
+ """
+ batch_size = noise.shape[0]
+ steps = self.timesteps.to(noise.device)
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
+ x = noise
+ for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
+ dt = t_next - t_cur
+ t_cur = t_cur.repeat(batch_size)
+ sigma = self.scheduler.sigma(t_cur)
+ dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur)
+ dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
+ if self.w_scheduler:
+ w = self.w_scheduler.w(t_cur)
+ else:
+ w = 0.0
+
+ cfg_x = torch.cat([x, x], dim=0)
+ cfg_t = t_cur.repeat(2)
+ out = net(cfg_x, cfg_t, cfg_condition)
+ out = self.guidance_fn(out, self.guidance)
+ v = out
+ s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma)
+ if i < self.num_steps -1 :
+ x = self.step_fn(x, v, dt, s=s, w=w)
+ else:
+ x = self.last_step_fn(x, v, dt, s=s, w=w)
+ return x
+
+
+class HeunSampler(BaseSampler):
+ def __init__(
+ self,
+ scheduler: BaseScheduler = None,
+ w_scheduler: BaseScheduler = None,
+ exact_henu=False,
+ timeshift=1.0,
+ step_fn: Callable = ode_step_fn,
+ last_step=None,
+ last_step_fn: Callable = ode_step_fn,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.scheduler = scheduler
+ self.exact_henu = exact_henu
+ self.step_fn = step_fn
+ self.last_step = last_step
+ self.last_step_fn = last_step_fn
+ self.w_scheduler = w_scheduler
+ self.timeshift = timeshift
+
+ timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
+ timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
+ self.timesteps = shift_respace_fn(timesteps, self.timeshift)
+
+ if self.last_step is None or self.num_steps == 1:
+ self.last_step = 1.0 / self.num_steps
+ assert self.last_step > 0.0
+ assert self.scheduler is not None
+ assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
+ if self.w_scheduler is not None:
+ if self.step_fn == ode_step_fn:
+ logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
+
+ def _impl_sampling(self, net, noise, condition, uncondition):
+ """
+ sampling process of Henu sampler
+ -
+ """
+ batch_size = noise.shape[0]
+ steps = self.timesteps.to(noise.device)
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
+ x = noise
+ v_hat, s_hat = 0.0, 0.0
+ for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
+ dt = t_next - t_cur
+ t_cur = t_cur.repeat(batch_size)
+ sigma = self.scheduler.sigma(t_cur)
+ alpha_over_dalpha = 1/self.scheduler.dalpha_over_alpha(t_cur)
+ dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
+ t_hat = t_next
+ t_hat = t_hat.repeat(batch_size)
+ sigma_hat = self.scheduler.sigma(t_hat)
+ alpha_over_dalpha_hat = 1 / self.scheduler.dalpha_over_alpha(t_hat)
+ dsigma_mul_sigma_hat = self.scheduler.dsigma_mul_sigma(t_hat)
+
+ if self.w_scheduler:
+ w = self.w_scheduler.w(t_cur)
+ else:
+ w = 0.0
+ if i == 0 or self.exact_henu:
+ cfg_x = torch.cat([x, x], dim=0)
+ cfg_t_cur = t_cur.repeat(2)
+ out = net(cfg_x, cfg_t_cur, cfg_condition)
+ out = self.guidance_fn(out, self.guidance)
+ v = out
+ s = ((alpha_over_dalpha)*v - x)/(sigma**2 - (alpha_over_dalpha)*dsigma_mul_sigma)
+ else:
+ v = v_hat
+ s = s_hat
+ x_hat = self.step_fn(x, v, dt, s=s, w=w)
+ # henu correct
+ if i < self.num_steps -1:
+ cfg_x_hat = torch.cat([x_hat, x_hat], dim=0)
+ cfg_t_hat = t_hat.repeat(2)
+ out = net(cfg_x_hat, cfg_t_hat, cfg_condition)
+ out = self.guidance_fn(out, self.guidance)
+ v_hat = out
+ s_hat = ((alpha_over_dalpha_hat)* v_hat - x_hat) / (sigma_hat ** 2 - (alpha_over_dalpha_hat) * dsigma_mul_sigma_hat)
+ v = (v + v_hat) / 2
+ s = (s + s_hat) / 2
+ x = self.step_fn(x, v, dt, s=s, w=w)
+ else:
+ x = self.last_step_fn(x, v, dt, s=s, w=w)
+ return x
\ No newline at end of file
diff --git a/src/diffusion/flow_matching/scheduling.py b/src/diffusion/flow_matching/scheduling.py
new file mode 100644
index 0000000..a82cd3a
--- /dev/null
+++ b/src/diffusion/flow_matching/scheduling.py
@@ -0,0 +1,39 @@
+import math
+import torch
+from src.diffusion.base.scheduling import *
+
+
+class LinearScheduler(BaseScheduler):
+ def alpha(self, t) -> Tensor:
+ return (t).view(-1, 1, 1, 1)
+ def sigma(self, t) -> Tensor:
+ return (1-t).view(-1, 1, 1, 1)
+ def dalpha(self, t) -> Tensor:
+ return torch.full_like(t, 1.0).view(-1, 1, 1, 1)
+ def dsigma(self, t) -> Tensor:
+ return torch.full_like(t, -1.0).view(-1, 1, 1, 1)
+
+# SoTA for ImageNet!
+class GVPScheduler(BaseScheduler):
+ def alpha(self, t) -> Tensor:
+ return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
+ def sigma(self, t) -> Tensor:
+ return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
+ def dalpha(self, t) -> Tensor:
+ return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
+ def dsigma(self, t) -> Tensor:
+ return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
+ def w(self, t):
+ return torch.sin(t)**2
+
+class ConstScheduler(BaseScheduler):
+ def w(self, t):
+ return torch.ones(1, 1, 1, 1).to(t.device, t.dtype)
+
+from src.diffusion.ddpm.scheduling import VPScheduler
+class VPBetaScheduler(VPScheduler):
+ def w(self, t):
+ return self.beta(t).view(-1, 1, 1, 1)
+
+
+
diff --git a/src/diffusion/flow_matching/training.py b/src/diffusion/flow_matching/training.py
new file mode 100644
index 0000000..55c964d
--- /dev/null
+++ b/src/diffusion/flow_matching/training.py
@@ -0,0 +1,55 @@
+import torch
+from typing import Callable
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+class FlowMatchingTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ lognorm_t=False,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size = x.shape[0]
+ if self.lognorm_t:
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
+ else:
+ t = torch.rand(batch_size).to(x.device, x.dtype)
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+ w = self.scheduler.w(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+ out = net(x_t, t, y)
+
+ weight = self.loss_weight_fn(alpha, sigma)
+
+ loss = weight*(out - v_t)**2
+
+ out = dict(
+ loss=loss.mean(),
+ )
+ return out
\ No newline at end of file
diff --git a/src/diffusion/flow_matching/training_cos.py b/src/diffusion/flow_matching/training_cos.py
new file mode 100644
index 0000000..aff30a7
--- /dev/null
+++ b/src/diffusion/flow_matching/training_cos.py
@@ -0,0 +1,59 @@
+import torch
+from typing import Callable
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+class COSTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ lognorm_t=False,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size = x.shape[0]
+ if self.lognorm_t:
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
+ else:
+ t = torch.rand(batch_size).to(x.device, x.dtype)
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+ w = self.scheduler.w(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+ out = net(x_t, t, y)
+
+ weight = self.loss_weight_fn(alpha, sigma)
+
+ fm_loss = weight*(out - v_t)**2
+ cos_sim = torch.nn.functional.cosine_similarity(out, v_t, dim=1)
+ cos_loss = 1 - cos_sim
+
+ out = dict(
+ fm_loss=fm_loss.mean(),
+ cos_loss=cos_loss.mean(),
+ loss=fm_loss.mean() + cos_loss.mean(),
+ )
+ return out
\ No newline at end of file
diff --git a/src/diffusion/flow_matching/training_pyramid.py b/src/diffusion/flow_matching/training_pyramid.py
new file mode 100644
index 0000000..be2bd94
--- /dev/null
+++ b/src/diffusion/flow_matching/training_pyramid.py
@@ -0,0 +1,68 @@
+import torch
+from typing import Callable
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+class PyramidTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ lognorm_t=False,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size = x.shape[0]
+ if self.lognorm_t:
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
+ else:
+ t = torch.rand(batch_size).to(x.device, x.dtype)
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+ w = self.scheduler.w(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+
+
+ output_pyramid = []
+ def feature_hook(module, input, output):
+ output_pyramid.extend(output)
+ handle = net.decoder.register_forward_hook(feature_hook)
+ net(x_t, t, y)
+ handle.remove()
+
+ loss = 0.0
+ out_dict = dict()
+
+ cur_v_t = v_t
+ for i in range(len(output_pyramid)):
+ cur_out = output_pyramid[i]
+ loss_i = (cur_v_t - cur_out) ** 2
+ loss += loss_i.mean()
+ out_dict["loss_{}".format(i)] = loss_i.mean()
+ cur_v_t = torch.nn.functional.interpolate(cur_v_t, scale_factor=0.5, mode='bilinear', align_corners=False)
+ out_dict["loss"] = loss
+ return out_dict
+
diff --git a/src/diffusion/flow_matching/training_repa.py b/src/diffusion/flow_matching/training_repa.py
new file mode 100644
index 0000000..e9a6788
--- /dev/null
+++ b/src/diffusion/flow_matching/training_repa.py
@@ -0,0 +1,142 @@
+import torch
+import copy
+import timm
+from torch.nn import Parameter
+
+from src.utils.no_grad import no_grad
+from typing import Callable, Iterator, Tuple
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from torchvision.transforms import Normalize
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+
+class DINOv2(nn.Module):
+ def __init__(self, weight_path:str):
+ super(DINOv2, self).__init__()
+ self.encoder = torch.hub.load(
+ '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
+ weight_path,
+ source="local",
+ skip_validation=True
+ )
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.precomputed_pos_embed = dict()
+
+ def fetch_pos(self, h, w):
+ key = (h, w)
+ if key in self.precomputed_pos_embed:
+ return self.precomputed_pos_embed[key]
+ value = timm.layers.pos_embed.resample_abs_pos_embed(
+ self.pos_embed.data, [h, w],
+ )
+ self.precomputed_pos_embed[key] = value
+ return value
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
+ self.encoder.pos_embed.data = pos_embed_data
+ feature = self.encoder.forward_features(x)['x_norm_patchtokens']
+ return feature
+
+
+class REPATrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ feat_loss_weight: float=0.5,
+ lognorm_t=False,
+ encoder_weight_path=None,
+ align_layer=8,
+ proj_denoiser_dim=256,
+ proj_hidden_dim=256,
+ proj_encoder_dim=256,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.feat_loss_weight = feat_loss_weight
+ self.align_layer = align_layer
+ self.encoder = DINOv2(encoder_weight_path)
+ no_grad(self.encoder)
+
+ self.proj = nn.Sequential(
+ nn.Sequential(
+ nn.Linear(proj_denoiser_dim, proj_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(proj_hidden_dim, proj_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(proj_hidden_dim, proj_encoder_dim),
+ )
+ )
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size, c, height, width = x.shape
+ if self.lognorm_t:
+ base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
+ else:
+ base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
+ t = base_t
+
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+
+ src_feature = []
+ def forward_hook(net, input, output):
+ src_feature.append(output)
+ handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
+
+ out = net(x_t, t, y)
+ src_feature = self.proj(src_feature[0])
+ handle.remove()
+
+ with torch.no_grad():
+ dst_feature = self.encoder(raw_images)
+
+ cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
+ cos_loss = 1 - cos_sim
+
+ weight = self.loss_weight_fn(alpha, sigma)
+ fm_loss = weight*(out - v_t)**2
+
+ out = dict(
+ fm_loss=fm_loss.mean(),
+ cos_loss=cos_loss.mean(),
+ loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
+ )
+ return out
+
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+ self.proj.state_dict(
+ destination=destination,
+ prefix=prefix + "proj.",
+ keep_vars=keep_vars)
+
diff --git a/src/diffusion/flow_matching/training_repa_mask.py b/src/diffusion/flow_matching/training_repa_mask.py
new file mode 100644
index 0000000..f8c4edb
--- /dev/null
+++ b/src/diffusion/flow_matching/training_repa_mask.py
@@ -0,0 +1,152 @@
+import torch
+import copy
+import timm
+from torch.nn import Parameter
+
+from src.utils.no_grad import no_grad
+from typing import Callable, Iterator, Tuple
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from torchvision.transforms import Normalize
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+
+class DINOv2(nn.Module):
+ def __init__(self, weight_path:str):
+ super(DINOv2, self).__init__()
+ self.encoder = torch.hub.load(
+ '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
+ weight_path,
+ source="local",
+ skip_validation=True
+ )
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.precomputed_pos_embed = dict()
+
+ def fetch_pos(self, h, w):
+ key = (h, w)
+ if key in self.precomputed_pos_embed:
+ return self.precomputed_pos_embed[key]
+ value = timm.layers.pos_embed.resample_abs_pos_embed(
+ self.pos_embed.data, [h, w],
+ )
+ self.precomputed_pos_embed[key] = value
+ return value
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
+ self.encoder.pos_embed.data = pos_embed_data
+ feature = self.encoder.forward_features(x)['x_norm_patchtokens']
+ return feature
+
+
+class REPATrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ feat_loss_weight: float=0.5,
+ lognorm_t=False,
+ mask_ratio=0.0,
+ mask_patch_size=2,
+ encoder_weight_path=None,
+ align_layer=8,
+ proj_denoiser_dim=256,
+ proj_hidden_dim=256,
+ proj_encoder_dim=256,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.mask_ratio = mask_ratio
+ self.mask_patch_size = mask_patch_size
+ self.feat_loss_weight = feat_loss_weight
+ self.align_layer = align_layer
+ self.encoder = DINOv2(encoder_weight_path)
+ no_grad(self.encoder)
+
+ self.proj = nn.Sequential(
+ nn.Sequential(
+ nn.Linear(proj_denoiser_dim, proj_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(proj_hidden_dim, proj_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(proj_hidden_dim, proj_encoder_dim),
+ )
+ )
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size, c, height, width = x.shape
+ if self.lognorm_t:
+ base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
+ else:
+ base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
+ t = base_t
+
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+
+ patch_mask = torch.rand((batch_size, 1, height//self.mask_patch_size, width//self.mask_patch_size), device=x.device)
+ patch_mask = (patch_mask < self.mask_ratio).float()
+ mask = torch.nn.functional.interpolate(patch_mask, size=(height, width), mode='nearest')
+ masked_x = x*(1-mask)# + torch.randn_like(x)*(mask)
+
+ x_t = alpha*masked_x + sigma*noise
+ v_t = dalpha*x + dsigma*noise
+
+ src_feature = []
+ def forward_hook(net, input, output):
+ src_feature.append(output)
+ handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
+
+ v_t_out, x0_out = net(x_t, t, y)
+ src_feature = self.proj(src_feature[0])
+ handle.remove()
+
+ with torch.no_grad():
+ dst_feature = self.encoder(raw_images)
+
+ cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
+ cos_loss = 1 - cos_sim
+
+ weight = self.loss_weight_fn(alpha, sigma)
+ fm_loss = (1-mask)*weight*(v_t_out - v_t)**2/(1-mask.mean())
+ mask_loss = mask*weight*(x0_out - x)**2/(mask.mean())
+ out = dict(
+ fm_loss=fm_loss.mean(),
+ cos_loss=cos_loss.mean(),
+ mask_loss=mask_loss.mean(),
+ loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + mask_loss.mean(),
+ )
+ return out
+
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+ self.proj.state_dict(
+ destination=destination,
+ prefix=prefix + "proj.",
+ keep_vars=keep_vars)
+
diff --git a/src/diffusion/pre_integral.py b/src/diffusion/pre_integral.py
new file mode 100644
index 0000000..848533a
--- /dev/null
+++ b/src/diffusion/pre_integral.py
@@ -0,0 +1,143 @@
+import torch
+
+# lagrange interpolation
+def lagrange_preint_o1(t1, v1, int_t_start, int_t_end):
+ '''
+ lagrange interpolation of order 1
+ Args:
+ t1: timestepx
+ v1: value field at t1
+ int_t_start: intergation start time
+ int_t_end: intergation end time
+ Returns:
+ integrated value
+ '''
+ int1 = (int_t_end-int_t_start)
+ return int1*v1, (int1/int1, )
+
+def lagrange_preint_o2(t1, t2, v1, v2, int_t_start, int_t_end):
+ '''
+ lagrange interpolation of order 2
+ Args:
+ t1: timestepx
+ t2: timestepy
+ v1: value field at t1
+ v2: value field at t2
+ int_t_start: intergation start time
+ int_t_end: intergation end time
+ Returns:
+ integrated value
+ '''
+ int1 = 0.5/(t1-t2)*((int_t_end-t2)**2 - (int_t_start-t2)**2)
+ int2 = 0.5/(t2-t1)*((int_t_end-t1)**2 - (int_t_start-t1)**2)
+ int_sum = int1+int2
+ return int1*v1 + int2*v2, (int1/int_sum, int2/int_sum)
+
+def lagrange_preint_o3(t1, t2, t3, v1, v2, v3, int_t_start, int_t_end):
+ '''
+ lagrange interpolation of order 3
+ Args:
+ t1: timestepx
+ t2: timestepy
+ t3: timestepz
+ v1: value field at t1
+ v2: value field at t2
+ v3: value field at t3
+ int_t_start: intergation start time
+ int_t_end: intergation end time
+ Returns:
+ integrated value
+ '''
+ int1_denom = (t1-t2)*(t1-t3)
+ int1_end = 1/3*(int_t_end)**3 - 1/2*(t2+t3)*(int_t_end)**2 + (t2*t3)*int_t_end
+ int1_start = 1/3*(int_t_start)**3 - 1/2*(t2+t3)*(int_t_start)**2 + (t2*t3)*int_t_start
+ int1 = (int1_end - int1_start)/int1_denom
+ int2_denom = (t2-t1)*(t2-t3)
+ int2_end = 1/3*(int_t_end)**3 - 1/2*(t1+t3)*(int_t_end)**2 + (t1*t3)*int_t_end
+ int2_start = 1/3*(int_t_start)**3 - 1/2*(t1+t3)*(int_t_start)**2 + (t1*t3)*int_t_start
+ int2 = (int2_end - int2_start)/int2_denom
+ int3_denom = (t3-t1)*(t3-t2)
+ int3_end = 1/3*(int_t_end)**3 - 1/2*(t1+t2)*(int_t_end)**2 + (t1*t2)*int_t_end
+ int3_start = 1/3*(int_t_start)**3 - 1/2*(t1+t2)*(int_t_start)**2 + (t1*t2)*int_t_start
+ int3 = (int3_end - int3_start)/int3_denom
+ int_sum = int1+int2+int3
+ return int1*v1 + int2*v2 + int3*v3, (int1/int_sum, int2/int_sum, int3/int_sum)
+
+def larange_preint_o4(t1, t2, t3, t4, v1, v2, v3, v4, int_t_start, int_t_end):
+ '''
+ lagrange interpolation of order 4
+ Args:
+ t1: timestepx
+ t2: timestepy
+ t3: timestepz
+ t4: timestepw
+ v1: value field at t1
+ v2: value field at t2
+ v3: value field at t3
+ v4: value field at t4
+ int_t_start: intergation start time
+ int_t_end: intergation end time
+ Returns:
+ integrated value
+ '''
+ int1_denom = (t1-t2)*(t1-t3)*(t1-t4)
+ int1_end = 1/4*(int_t_end)**4 - 1/3*(t2+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_end**2 - t2*t3*t4*int_t_end
+ int1_start = 1/4*(int_t_start)**4 - 1/3*(t2+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_start**2 - t2*t3*t4*int_t_start
+ int1 = (int1_end - int1_start)/int1_denom
+ int2_denom = (t2-t1)*(t2-t3)*(t2-t4)
+ int2_end = 1/4*(int_t_end)**4 - 1/3*(t1+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_end**2 - t1*t3*t4*int_t_end
+ int2_start = 1/4*(int_t_start)**4 - 1/3*(t1+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_start**2 - t1*t3*t4*int_t_start
+ int2 = (int2_end - int2_start)/int2_denom
+ int3_denom = (t3-t1)*(t3-t2)*(t3-t4)
+ int3_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t4)*(int_t_end)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_end**2 - t1*t2*t4*int_t_end
+ int3_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t4)*(int_t_start)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_start**2 - t1*t2*t4*int_t_start
+ int3 = (int3_end - int3_start)/int3_denom
+ int4_denom = (t4-t1)*(t4-t2)*(t4-t3)
+ int4_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t3)*(int_t_end)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_end**2 - t1*t2*t3*int_t_end
+ int4_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t3)*(int_t_start)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_start**2 - t1*t2*t3*int_t_start
+ int4 = (int4_end - int4_start)/int4_denom
+ int_sum = int1+int2+int3+int4
+ return int1*v1 + int2*v2 + int3*v3 + int4*v4, (int1/int_sum, int2/int_sum, int3/int_sum, int4/int_sum)
+
+
+def lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end):
+ '''
+ lagrange interpolation
+ Args:
+ order: order of interpolation
+ pre_vs: value field at pre_ts
+ pre_ts: timesteps
+ int_t_start: intergation start time
+ int_t_end: intergation end time
+ Returns:
+ integrated value
+ '''
+ order = min(order, len(pre_vs), len(pre_ts))
+ if order == 1:
+ return lagrange_preint_o1(pre_ts[-1], pre_vs[-1], int_t_start, int_t_end)
+ elif order == 2:
+ return lagrange_preint_o2(pre_ts[-2], pre_ts[-1], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
+ elif order == 3:
+ return lagrange_preint_o3(pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
+ elif order == 4:
+ return larange_preint_o4(pre_ts[-4], pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-4], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end)
+ else:
+ raise ValueError('Invalid order')
+
+
+def polynomial_integral(coeffs, int_t_start, int_t_end):
+ '''
+ polynomial integral
+ Args:
+ coeffs: coefficients of the polynomial
+ int_t_start: intergation start time
+ int_t_end: intergation end time
+ Returns:
+ integrated value
+ '''
+ orders = len(coeffs)
+ int_val = 0
+ for o in range(orders):
+ int_val += coeffs[o]/(o+1)*(int_t_end**(o+1)-int_t_start**(o+1))
+ return int_val
+
diff --git a/src/diffusion/stateful_flow_matching/adam_sampling.py b/src/diffusion/stateful_flow_matching/adam_sampling.py
new file mode 100644
index 0000000..fb2e95b
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/adam_sampling.py
@@ -0,0 +1,112 @@
+import math
+from src.diffusion.base.sampling import *
+from src.diffusion.base.scheduling import *
+from src.diffusion.pre_integral import *
+
+from typing import Callable, List, Tuple
+
+def ode_step_fn(x, v, dt, s, w):
+ return x + v * dt
+
+def t2snr(t):
+ if isinstance(t, torch.Tensor):
+ return (t.clip(min=1e-8)/(1-t + 1e-8))
+ if isinstance(t, List) or isinstance(t, Tuple):
+ return [t2snr(t) for t in t]
+ t = max(t, 1e-8)
+ return (t/(1-t + 1e-8))
+
+def t2logsnr(t):
+ if isinstance(t, torch.Tensor):
+ return torch.log(t.clip(min=1e-3)/(1-t + 1e-3))
+ if isinstance(t, List) or isinstance(t, Tuple):
+ return [t2logsnr(t) for t in t]
+ t = max(t, 1e-3)
+ return math.log(t/(1-t + 1e-3))
+
+def t2isnr(t):
+ return 1/t2snr(t)
+
+def nop(t):
+ return t
+
+def shift_respace_fn(t, shift=3.0):
+ return t / (t + (1 - t) * shift)
+
+import logging
+logger = logging.getLogger(__name__)
+
+class AdamLMSampler(BaseSampler):
+ def __init__(
+ self,
+ order: int = 2,
+ timeshift: float = 1.0,
+ state_refresh_rate: int = 1,
+ lms_transform_fn: Callable = nop,
+ w_scheduler: BaseScheduler = None,
+ step_fn: Callable = ode_step_fn,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.step_fn = step_fn
+ self.w_scheduler = w_scheduler
+ self.state_refresh_rate = state_refresh_rate
+
+ assert self.scheduler is not None
+ assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
+ self.order = order
+ self.lms_transform_fn = lms_transform_fn
+
+ timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
+ timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
+ self.timesteps = shift_respace_fn(timesteps, timeshift)
+ self.timedeltas = timesteps[1:] - self.timesteps[:-1]
+ self._reparameterize_coeffs()
+
+ def _reparameterize_coeffs(self):
+ solver_coeffs = [[] for _ in range(self.num_steps)]
+ for i in range(0, self.num_steps):
+ pre_vs = [1.0, ]*(i+1)
+ pre_ts = self.lms_transform_fn(self.timesteps[:i+1])
+ int_t_start = self.lms_transform_fn(self.timesteps[i])
+ int_t_end = self.lms_transform_fn(self.timesteps[i+1])
+
+ order_annealing = self.order #self.num_steps - i
+ order = min(self.order, i + 1, order_annealing)
+
+ _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end)
+ solver_coeffs[i] = coeffs
+ self.solver_coeffs = solver_coeffs
+
+ def _impl_sampling(self, net, noise, condition, uncondition):
+ """
+ sampling process of Euler sampler
+ -
+ """
+ batch_size = noise.shape[0]
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
+ x = x0 = noise
+ state = None
+ pred_trajectory = []
+ t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype)
+ timedeltas = self.timedeltas
+ solver_coeffs = self.solver_coeffs
+ for i in range(self.num_steps):
+ cfg_x = torch.cat([x, x], dim=0)
+ cfg_t = t_cur.repeat(2)
+ if i % self.state_refresh_rate == 0:
+ state = None
+ out, state = net(cfg_x, cfg_t, cfg_condition, state)
+ out = self.guidance_fn(out, self.guidances[i])
+ pred_trajectory.append(out)
+ out = torch.zeros_like(out)
+ order = len(self.solver_coeffs[i])
+ for j in range(order):
+ out += solver_coeffs[i][j] * pred_trajectory[-order:][j]
+ v = out
+ dt = timedeltas[i]
+ x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0)
+ x = self.step_fn(x, v, dt, s=0, w=0)
+ t_cur += dt
+ return x
\ No newline at end of file
diff --git a/src/diffusion/stateful_flow_matching/bak/training_adv.py b/src/diffusion/stateful_flow_matching/bak/training_adv.py
new file mode 100644
index 0000000..4792950
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/bak/training_adv.py
@@ -0,0 +1,122 @@
+import torch
+import math
+from typing import Callable
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+from src.utils.no_grad import no_grad
+from torchmetrics.image.lpip import _NoTrainLpips
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+class Discriminator(nn.Module):
+ def __init__(self, in_channels, hidden_size):
+ super().__init__()
+ self.head = nn.Sequential(
+ nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
+ nn.GroupNorm(num_groups=32, num_channels=hidden_size),
+ nn.SiLU(),
+ nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
+ nn.GroupNorm(num_groups=32, num_channels=hidden_size),
+ nn.SiLU(),
+ nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4
+ nn.GroupNorm(num_groups=32, num_channels=hidden_size),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
+ )
+
+ def forward(self, feature):
+ B, L, C = feature.shape
+ H = W = int(math.sqrt(L))
+ feature = feature.permute(0, 2, 1)
+ feature = feature.view(B, C, H, W)
+ out = self.head(feature).sigmoid().clamp(0.01, 0.99)
+ return out
+
+class AdvTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ lognorm_t=False,
+ adv_weight=1.0,
+ adv_encoder_layer=4,
+ adv_in_channels=768,
+ adv_hidden_size=256,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.adv_weight = adv_weight
+ self.adv_encoder_layer = adv_encoder_layer
+
+ self.dis_head = Discriminator(
+ in_channels=adv_in_channels,
+ hidden_size=adv_hidden_size,
+ )
+
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size = x.shape[0]
+ if self.lognorm_t:
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
+ else:
+ t = torch.rand(batch_size).to(x.device, x.dtype)
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+ w = self.scheduler.w(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+
+ adv_feature = []
+ def forward_hook(net, input, output):
+ adv_feature.append(output)
+ handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook)
+
+ out, _ = net(x_t, t, y)
+ weight = self.loss_weight_fn(alpha, sigma)
+ loss = weight*(out - v_t)**2
+
+ pred_x0 = (x_t + out * sigma)
+ pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma
+ real_feature = adv_feature.pop()
+ net(pred_xt, t, y, classify_layer=self.adv_encoder_layer)
+ fake_feature = adv_feature.pop()
+ handle.remove()
+
+
+ real_score_gan = self.dis_head(real_feature.detach())
+ fake_score_gan = self.dis_head(fake_feature.detach())
+ fake_score = self.dis_head(fake_feature)
+
+ loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
+ acc_real = (real_score_gan > 0.5).float()
+ acc_fake = (fake_score_gan < 0.5).float()
+ loss_adv = -torch.log(fake_score)
+ loss_adv_hack = torch.log(fake_score_gan)
+
+ out = dict(
+ adv_loss=loss_adv.mean(),
+ gan_loss=loss_gan.mean(),
+ fm_loss=loss.mean(),
+ loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
+ acc_real=acc_real.mean(),
+ acc_fake=acc_fake.mean(),
+ )
+ return out
diff --git a/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py b/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py
new file mode 100644
index 0000000..2843c04
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py
@@ -0,0 +1,127 @@
+import torch
+import math
+from typing import Callable
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+from src.utils.no_grad import no_grad
+from torchmetrics.image.lpip import _NoTrainLpips
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+class Discriminator(nn.Module):
+ def __init__(self, in_channels, hidden_size):
+ super().__init__()
+ self.head = nn.Sequential(
+ nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
+ nn.GroupNorm(num_groups=32, num_channels=hidden_size),
+ nn.SiLU(),
+ nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
+ nn.GroupNorm(num_groups=32, num_channels=hidden_size),
+ nn.SiLU(),
+ nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4
+ nn.GroupNorm(num_groups=32, num_channels=hidden_size),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
+ )
+
+ def forward(self, feature):
+ B, L, C = feature.shape
+ H = W = int(math.sqrt(L))
+ feature = feature.permute(0, 2, 1)
+ feature = feature.view(B, C, H, W)
+ out = self.head(feature).sigmoid().clamp(0.01, 0.99)
+ return out
+
+class AdvTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ lognorm_t=False,
+ adv_weight=1.0,
+ lpips_weight=1.0,
+ adv_encoder_layer=4,
+ adv_in_channels=768,
+ adv_hidden_size=256,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.adv_weight = adv_weight
+ self.lpips_weight = lpips_weight
+ self.adv_encoder_layer = adv_encoder_layer
+
+ self.dis_head = Discriminator(
+ in_channels=adv_in_channels,
+ hidden_size=adv_hidden_size,
+ )
+
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size = x.shape[0]
+ if self.lognorm_t:
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
+ else:
+ t = torch.rand(batch_size).to(x.device, x.dtype)
+ clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype)
+
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+ w = self.scheduler.w(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+
+ out, _ = net(x_t, t, y)
+ pred_x0 = (x_t + out * sigma)
+
+ weight = self.loss_weight_fn(alpha, sigma)
+ loss = weight*(out - v_t)**2
+ with torch.no_grad():
+ _, real_features = net(x, clean_t, y, classify_layer=self.adv_encoder_layer)
+ _, fake_features = net(pred_x0, clean_t, y, classify_layer=self.adv_encoder_layer)
+
+ real_score_gan = self.dis_head(real_features[-1].detach())
+ fake_score_gan = self.dis_head(fake_features[-1].detach())
+ fake_score = self.dis_head(fake_features[-1])
+
+ loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
+ acc_real = (real_score_gan > 0.5).float()
+ acc_fake = (fake_score_gan < 0.5).float()
+ loss_adv = -torch.log(fake_score)
+ loss_adv_hack = torch.log(fake_score_gan)
+
+ lpips_loss = []
+ for r, f in zip(real_features, fake_features):
+ r = torch.nn.functional.normalize(r, dim=-1)
+ f = torch.nn.functional.normalize(f, dim=-1)
+ lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
+ lpips_loss = sum(lpips_loss)
+
+
+ out = dict(
+ adv_loss=loss_adv.mean(),
+ gan_loss=loss_gan.mean(),
+ lpips_loss=lpips_loss.mean(),
+ fm_loss=loss.mean(),
+ loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean() + self.lpips_weight*lpips_loss.mean(),
+ acc_real=acc_real.mean(),
+ acc_fake=acc_fake.mean(),
+ )
+ return out
diff --git a/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py b/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py
new file mode 100644
index 0000000..849ee4b
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py
@@ -0,0 +1,159 @@
+import random
+
+import torch
+import copy
+import timm
+from torch.nn import Parameter
+
+from src.utils.no_grad import no_grad
+from typing import Callable, Iterator, Tuple
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from torchvision.transforms import Normalize
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+
+class DINOv2(nn.Module):
+ def __init__(self, weight_path:str):
+ super(DINOv2, self).__init__()
+ self.encoder = torch.hub.load(
+ '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
+ weight_path,
+ source="local",
+ skip_validation=True
+ )
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.precomputed_pos_embed = dict()
+
+ def fetch_pos(self, h, w):
+ key = (h, w)
+ if key in self.precomputed_pos_embed:
+ return self.precomputed_pos_embed[key]
+ value = timm.layers.pos_embed.resample_abs_pos_embed(
+ self.pos_embed.data, [h, w],
+ )
+ self.precomputed_pos_embed[key] = value
+ return value
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
+ self.encoder.pos_embed.data = pos_embed_data
+ feature = self.encoder.forward_features(x)['x_norm_patchtokens']
+ return feature
+
+
+class MaskREPATrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ feat_loss_weight: float=0.5,
+ lognorm_t=False,
+ encoder_weight_path=None,
+ mask_groups=4,
+ align_layer=8,
+ proj_denoiser_dim=256,
+ proj_hidden_dim=256,
+ proj_encoder_dim=256,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.feat_loss_weight = feat_loss_weight
+ self.align_layer = align_layer
+ self.mask_groups = mask_groups
+ self.encoder = DINOv2(encoder_weight_path)
+ no_grad(self.encoder)
+
+ self.proj = nn.Sequential(
+ nn.Sequential(
+ nn.Linear(proj_denoiser_dim, proj_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(proj_hidden_dim, proj_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(proj_hidden_dim, proj_encoder_dim),
+ )
+ )
+
+ def fetch_mask(self, length=256, groups=4, device=torch.device('cuda')):
+ mask = torch.zeros(1, length, length, device=device, dtype=torch.bool)
+ random_seq = torch.randperm(length, device=device)
+ for i in range(groups):
+ group_start = (length+groups-1)//groups*i
+ group_end = (length+groups-1)//groups*(i+1)
+ group_random_seq = random_seq[group_start:group_end]
+ y, x = torch.meshgrid(group_random_seq, group_random_seq)
+ mask[:, y, x] = True
+ return mask
+
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size, c, height, width = x.shape
+ if self.lognorm_t:
+ base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
+ else:
+ base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
+ t = base_t
+
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+
+ src_feature = []
+ def forward_hook(net, input, output):
+ src_feature.append(output)
+ handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
+ mask_groups = random.randint(1, self.mask_groups)
+ mask = self.fetch_mask(length=256, groups=mask_groups, device=x.device)
+ out, _ = net(x_t, t, y, mask=mask)
+ src_feature = self.proj(src_feature[0])
+ handle.remove()
+
+ with torch.no_grad():
+ dst_feature = self.encoder(raw_images)
+
+ cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
+ cos_loss = 1 - cos_sim
+
+ weight = self.loss_weight_fn(alpha, sigma)
+ fm_loss = weight*(out - v_t)**2
+
+ out = dict(
+ fm_loss=fm_loss.mean(),
+ cos_loss=cos_loss.mean(),
+ loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
+ )
+ return out
+
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+ self.proj.state_dict(
+ destination=destination,
+ prefix=prefix + "proj.",
+ keep_vars=keep_vars)
+
diff --git a/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py b/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py
new file mode 100644
index 0000000..229680c
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py
@@ -0,0 +1,179 @@
+import torch
+import math
+from typing import Callable
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+from src.utils.no_grad import no_grad
+from torchmetrics.image.lpip import _NoTrainLpips
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=t.device)
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t, mul=1000):
+ t_freq = self.timestep_embedding(t * mul, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class BatchNormWithTimeEmbedding(nn.Module):
+ def __init__(self, num_features):
+ super().__init__()
+ # self.bn = nn.BatchNorm2d(num_features, affine=False)
+ self.bn = nn.GroupNorm(16, num_features, affine=False)
+ # self.bn = nn.SyncBatchNorm(num_features, affine=False)
+ self.embedder = TimestepEmbedder(num_features * 2)
+ # nn.init.zeros_(self.embedder.mlp[-1].weight)
+ nn.init.trunc_normal_(self.embedder.mlp[-1].weight, std=0.01)
+ nn.init.zeros_(self.embedder.mlp[-1].bias)
+
+ def forward(self, x, t):
+ embed = self.embedder(t)
+ embed = embed[:, :, None, None]
+ gamma, beta = embed.chunk(2, dim=1)
+ gamma = 1.0 + gamma
+ normed = self.bn(x)
+ out = normed * gamma + beta
+ return out
+
+class DisBlock(nn.Module):
+ def __init__(self, in_channels, hidden_size):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=4, padding=0
+ )
+ self.norm = BatchNormWithTimeEmbedding(hidden_size)
+ self.act = nn.SiLU()
+ def forward(self, x, t):
+ x = self.conv(x)
+ x = self.norm(x, t)
+ x = self.act(x)
+ return x
+
+
+class Discriminator(nn.Module):
+ def __init__(self, num_blocks, in_channels, hidden_size):
+ super().__init__()
+ self.blocks = nn.ModuleList()
+ for i in range(num_blocks):
+ self.blocks.append(
+ DisBlock(
+ in_channels=in_channels,
+ hidden_size=hidden_size,
+ )
+ )
+ in_channels = hidden_size
+ self.classifier = nn.Conv2d(
+ kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=1
+ )
+ def forward(self, feature, t):
+ B, C, H, W = feature.shape
+ for block in self.blocks:
+ feature = block(feature, t)
+ out = self.classifier(feature).view(B, -1)
+ out = out.sigmoid().clamp(0.01, 0.99)
+ return out
+
+class AdvTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ lognorm_t=False,
+ adv_weight=1.0,
+ adv_blocks=3,
+ adv_in_channels=3,
+ adv_hidden_size=256,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.adv_weight = adv_weight
+
+ self.discriminator = Discriminator(
+ num_blocks=adv_blocks,
+ in_channels=adv_in_channels*2,
+ hidden_size=adv_hidden_size,
+ )
+
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size = x.shape[0]
+ if self.lognorm_t:
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
+ else:
+ t = torch.rand(batch_size).to(x.device, x.dtype)
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+ w = self.scheduler.w(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+
+ out, _ = net(x_t, t, y)
+ pred_x0 = x_t + sigma * out
+ weight = self.loss_weight_fn(alpha, sigma)
+ loss = weight*(out - v_t)**2
+
+ real_feature = torch.cat([x_t, x], dim=1)
+ fake_feature = torch.cat([x_t, pred_x0], dim=1)
+
+ real_score_gan = self.discriminator(real_feature.detach(), t)
+ fake_score_gan = self.discriminator(fake_feature.detach(), t)
+ fake_score = self.discriminator(fake_feature, t)
+
+ loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
+ acc_real = (real_score_gan > 0.5).float()
+ acc_fake = (fake_score_gan < 0.5).float()
+ loss_adv = -torch.log(fake_score)
+ loss_adv_hack = torch.log(fake_score_gan)
+
+ out = dict(
+ adv_loss=loss_adv.mean(),
+ gan_loss=loss_gan.mean(),
+ fm_loss=loss.mean(),
+ loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
+ acc_real=acc_real.mean(),
+ acc_fake=acc_fake.mean(),
+ )
+ return out
diff --git a/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py b/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py
new file mode 100644
index 0000000..e84e81f
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py
@@ -0,0 +1,154 @@
+import torch
+import copy
+import timm
+from torch.nn import Parameter
+
+from src.utils.no_grad import no_grad
+from typing import Callable, Iterator, Tuple
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from torchvision.transforms import Normalize
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+
+class DINOv2(nn.Module):
+ def __init__(self, weight_path:str):
+ super(DINOv2, self).__init__()
+ self.encoder = torch.hub.load(
+ '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
+ weight_path,
+ source="local",
+ skip_validation=True
+ )
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.precomputed_pos_embed = dict()
+
+ def fetch_pos(self, h, w):
+ key = (h, w)
+ if key in self.precomputed_pos_embed:
+ return self.precomputed_pos_embed[key]
+ value = timm.layers.pos_embed.resample_abs_pos_embed(
+ self.pos_embed.data, [h, w],
+ )
+ self.precomputed_pos_embed[key] = value
+ return value
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
+ self.encoder.pos_embed.data = pos_embed_data
+ feature = self.encoder.forward_features(x)['x_norm_patchtokens']
+ return feature
+
+
+class REPAJiTTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ feat_loss_weight: float=0.5,
+ lognorm_t=False,
+ jit_deltas=0.01,
+ encoder_weight_path=None,
+ align_layer=8,
+ proj_denoiser_dim=256,
+ proj_hidden_dim=256,
+ proj_encoder_dim=256,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.feat_loss_weight = feat_loss_weight
+ self.align_layer = align_layer
+ self.jit_deltas = jit_deltas
+ self.encoder = DINOv2(encoder_weight_path)
+ no_grad(self.encoder)
+
+ self.proj = nn.Sequential(
+ nn.Sequential(
+ nn.Linear(proj_denoiser_dim, proj_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(proj_hidden_dim, proj_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(proj_hidden_dim, proj_encoder_dim),
+ )
+ )
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size, c, height, width = x.shape
+ if self.lognorm_t:
+ base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
+ else:
+ base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
+ t = base_t
+
+ noise = torch.randn_like(x)
+
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+
+ t2 = base_t + (torch.rand_like(base_t)-0.5) * self.jit_deltas
+ t2 = torch.clip(t2, 0, 1)
+ alpha = self.scheduler.alpha(t2)
+ dalpha = self.scheduler.dalpha(t2)
+ sigma = self.scheduler.sigma(t2)
+ dsigma = self.scheduler.dsigma(t2)
+ x_t2 = alpha * x + noise * sigma
+ v_t2 = dalpha * x + dsigma * noise
+
+ src_feature = []
+ def forward_hook(net, input, output):
+ src_feature.append(output)
+ handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
+
+ _, s = net(x_t, t, y, only_s=True)
+ out, _ = net(x_t2, t2, y, s)
+ src_feature = self.proj(src_feature[0])
+ handle.remove()
+
+ with torch.no_grad():
+ dst_feature = self.encoder(raw_images)
+
+ cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
+ cos_loss = 1 - cos_sim
+
+ weight = self.loss_weight_fn(alpha, sigma)
+ fm_loss = weight*(out - v_t2)**2
+
+ out = dict(
+ fm_loss=fm_loss.mean(),
+ cos_loss=cos_loss.mean(),
+ loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
+ )
+ return out
+
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+ self.proj.state_dict(
+ destination=destination,
+ prefix=prefix + "proj.",
+ keep_vars=keep_vars)
+
diff --git a/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py b/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py
new file mode 100644
index 0000000..d7e741d
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py
@@ -0,0 +1,90 @@
+import torch
+import math
+from typing import Callable
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+from src.utils.no_grad import no_grad
+from torchmetrics.image.lpip import _NoTrainLpips
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+
+class SelfConsistentTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ lognorm_t=False,
+ lpips_weight=1.0,
+ lpips_encoder_layer=4,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.lpips_encoder_layer = lpips_encoder_layer
+ self.lpips_weight = lpips_weight
+
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size = x.shape[0]
+ if self.lognorm_t:
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
+ else:
+ t = torch.rand(batch_size).to(x.device, x.dtype)
+
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+ w = self.scheduler.w(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+
+ real_features = []
+ def forward_hook(net, input, output):
+ real_features.append(output)
+ handles = []
+ for i in range(self.lpips_encoder_layer):
+ handle = net.encoder.blocks[i].register_forward_hook(forward_hook)
+ handles.append(handle)
+
+ out, _ = net(x_t, t, y)
+
+ for handle in handles:
+ handle.remove()
+
+ pred_x0 = (x_t + out * sigma)
+ pred_xt = alpha * pred_x0 + noise * sigma
+ weight = self.loss_weight_fn(alpha, sigma)
+ loss = weight*(out - v_t)**2
+
+ _, fake_features = net(pred_xt, t, y, classify_layer=self.lpips_encoder_layer)
+
+ lpips_loss = []
+ for r, f in zip(real_features, fake_features):
+ r = torch.nn.functional.normalize(r, dim=-1)
+ f = torch.nn.functional.normalize(f, dim=-1)
+ lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
+ lpips_loss = sum(lpips_loss)
+
+
+ out = dict(
+ lpips_loss=lpips_loss.mean(),
+ fm_loss=loss.mean(),
+ loss=loss.mean() + self.lpips_weight*lpips_loss.mean(),
+ )
+ return out
diff --git a/src/diffusion/stateful_flow_matching/bak/training_selflpips.py b/src/diffusion/stateful_flow_matching/bak/training_selflpips.py
new file mode 100644
index 0000000..580775b
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/bak/training_selflpips.py
@@ -0,0 +1,81 @@
+import torch
+import math
+from typing import Callable
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+from src.utils.no_grad import no_grad
+from torchmetrics.image.lpip import _NoTrainLpips
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+
+class SelfLPIPSTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ lognorm_t=False,
+ lpips_weight=1.0,
+ lpips_encoder_layer=4,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.lpips_encoder_layer = lpips_encoder_layer
+ self.lpips_weight = lpips_weight
+
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size = x.shape[0]
+ if self.lognorm_t:
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
+ else:
+ t = torch.rand(batch_size).to(x.device, x.dtype)
+ clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype)
+
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+ w = self.scheduler.w(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+
+ out, _ = net(x_t, t, y)
+ pred_x0 = (x_t + out * sigma)
+ pred_xt = alpha * pred_x0 + noise * sigma
+ weight = self.loss_weight_fn(alpha, sigma)
+ loss = weight*(out - v_t)**2
+ with torch.no_grad():
+ _, real_features = net(x, clean_t, y, classify_layer=self.lpips_encoder_layer)
+ _, fake_features = net(pred_x0, clean_t, y, classify_layer=self.lpips_encoder_layer)
+
+
+ lpips_loss = []
+ for r, f in zip(real_features, fake_features):
+ r = torch.nn.functional.normalize(r, dim=-1)
+ f = torch.nn.functional.normalize(f, dim=-1)
+ lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean())
+ lpips_loss = sum(lpips_loss)
+
+
+ out = dict(
+ lpips_loss=lpips_loss.mean(),
+ fm_loss=loss.mean(),
+ loss=loss.mean() + self.lpips_weight*lpips_loss.mean(),
+ )
+ return out
diff --git a/src/diffusion/stateful_flow_matching/cm_sampling.py b/src/diffusion/stateful_flow_matching/cm_sampling.py
new file mode 100644
index 0000000..5254db5
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/cm_sampling.py
@@ -0,0 +1,78 @@
+import torch
+
+from src.diffusion.base.guidance import *
+from src.diffusion.base.scheduling import *
+from src.diffusion.base.sampling import *
+
+from typing import Callable
+
+
+def shift_respace_fn(t, shift=3.0):
+ return t / (t + (1 - t) * shift)
+
+def ode_step_fn(x, v, dt, s, w):
+ return x + v * dt
+
+
+import logging
+logger = logging.getLogger(__name__)
+
+class CMSampler(BaseSampler):
+ def __init__(
+ self,
+ w_scheduler: BaseScheduler = None,
+ timeshift=1.0,
+ guidance_interval_min: float = 0.0,
+ guidance_interval_max: float = 1.0,
+ state_refresh_rate=1,
+ last_step=None,
+ step_fn=None,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.last_step = last_step
+ self.timeshift = timeshift
+ self.state_refresh_rate = state_refresh_rate
+ self.guidance_interval_min = guidance_interval_min
+ self.guidance_interval_max = guidance_interval_max
+
+ if self.last_step is None or self.num_steps == 1:
+ self.last_step = 1.0 / self.num_steps
+
+ timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
+ timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
+ self.timesteps = shift_respace_fn(timesteps, self.timeshift)
+
+ assert self.last_step > 0.0
+ assert self.scheduler is not None
+
+
+ def _impl_sampling(self, net, noise, condition, uncondition):
+ """
+ sampling process of Euler sampler
+ -
+ """
+ batch_size = noise.shape[0]
+ steps = self.timesteps.to(noise.device)
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
+ x = noise
+ state = None
+ for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
+ cfg_t = t_cur.repeat(batch_size*2)
+ cfg_x = torch.cat([x, x], dim=0)
+ if i % self.state_refresh_rate == 0:
+ state = None
+ out, state = net(cfg_x, cfg_t, cfg_condition, state)
+ if t_cur > self.guidance_interval_min and t_cur < self.guidance_interval_max:
+ out = self.guidance_fn(out, self.guidance)
+ else:
+ out = self.guidance_fn(out, 1.0)
+ v = out
+
+ x0 = x + v * (1-t_cur)
+ alpha_next = self.scheduler.alpha(t_next)
+ sigma_next = self.scheduler.sigma(t_next)
+ x = alpha_next * x0 + sigma_next * torch.randn_like(x)
+ # print(alpha_next, sigma_next)
+ return x
\ No newline at end of file
diff --git a/src/diffusion/stateful_flow_matching/sampling.py b/src/diffusion/stateful_flow_matching/sampling.py
new file mode 100644
index 0000000..5fdfdb2
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/sampling.py
@@ -0,0 +1,103 @@
+import torch
+
+from src.diffusion.base.guidance import *
+from src.diffusion.base.scheduling import *
+from src.diffusion.base.sampling import *
+
+from typing import Callable
+
+
+def shift_respace_fn(t, shift=3.0):
+ return t / (t + (1 - t) * shift)
+
+def ode_step_fn(x, v, dt, s, w):
+ return x + v * dt
+
+def sde_mean_step_fn(x, v, dt, s, w):
+ return x + v * dt + s * w * dt
+
+def sde_step_fn(x, v, dt, s, w):
+ return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x)
+
+def sde_preserve_step_fn(x, v, dt, s, w):
+ return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x)
+
+
+import logging
+logger = logging.getLogger(__name__)
+
+class EulerSampler(BaseSampler):
+ def __init__(
+ self,
+ w_scheduler: BaseScheduler = None,
+ timeshift=1.0,
+ guidance_interval_min: float = 0.0,
+ guidance_interval_max: float = 1.0,
+ state_refresh_rate=1,
+ step_fn: Callable = ode_step_fn,
+ last_step=None,
+ last_step_fn: Callable = ode_step_fn,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.step_fn = step_fn
+ self.last_step = last_step
+ self.last_step_fn = last_step_fn
+ self.w_scheduler = w_scheduler
+ self.timeshift = timeshift
+ self.state_refresh_rate = state_refresh_rate
+ self.guidance_interval_min = guidance_interval_min
+ self.guidance_interval_max = guidance_interval_max
+
+ if self.last_step is None or self.num_steps == 1:
+ self.last_step = 1.0 / self.num_steps
+
+ timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
+ timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
+ self.timesteps = shift_respace_fn(timesteps, self.timeshift)
+
+ assert self.last_step > 0.0
+ assert self.scheduler is not None
+ assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
+ if self.w_scheduler is not None:
+ if self.step_fn == ode_step_fn:
+ logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
+
+ def _impl_sampling(self, net, noise, condition, uncondition):
+ """
+ sampling process of Euler sampler
+ -
+ """
+ batch_size = noise.shape[0]
+ steps = self.timesteps.to(noise.device)
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
+ x = noise
+ state = None
+ for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
+ dt = t_next - t_cur
+ t_cur = t_cur.repeat(batch_size)
+ sigma = self.scheduler.sigma(t_cur)
+ dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur)
+ dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
+ if self.w_scheduler:
+ w = self.w_scheduler.w(t_cur)
+ else:
+ w = 0.0
+
+ cfg_x = torch.cat([x, x], dim=0)
+ cfg_t = t_cur.repeat(2)
+ if i % self.state_refresh_rate == 0:
+ state = None
+ out, state = net(cfg_x, cfg_t, cfg_condition, state)
+ if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
+ out = self.guidance_fn(out, self.guidance)
+ else:
+ out = self.guidance_fn(out, 1.0)
+ v = out
+ s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma)
+ if i < self.num_steps -1 :
+ x = self.step_fn(x, v, dt, s=s, w=w)
+ else:
+ x = self.last_step_fn(x, v, dt, s=s, w=w)
+ return x
\ No newline at end of file
diff --git a/src/diffusion/stateful_flow_matching/scheduling.py b/src/diffusion/stateful_flow_matching/scheduling.py
new file mode 100644
index 0000000..a82cd3a
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/scheduling.py
@@ -0,0 +1,39 @@
+import math
+import torch
+from src.diffusion.base.scheduling import *
+
+
+class LinearScheduler(BaseScheduler):
+ def alpha(self, t) -> Tensor:
+ return (t).view(-1, 1, 1, 1)
+ def sigma(self, t) -> Tensor:
+ return (1-t).view(-1, 1, 1, 1)
+ def dalpha(self, t) -> Tensor:
+ return torch.full_like(t, 1.0).view(-1, 1, 1, 1)
+ def dsigma(self, t) -> Tensor:
+ return torch.full_like(t, -1.0).view(-1, 1, 1, 1)
+
+# SoTA for ImageNet!
+class GVPScheduler(BaseScheduler):
+ def alpha(self, t) -> Tensor:
+ return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
+ def sigma(self, t) -> Tensor:
+ return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
+ def dalpha(self, t) -> Tensor:
+ return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1)
+ def dsigma(self, t) -> Tensor:
+ return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1)
+ def w(self, t):
+ return torch.sin(t)**2
+
+class ConstScheduler(BaseScheduler):
+ def w(self, t):
+ return torch.ones(1, 1, 1, 1).to(t.device, t.dtype)
+
+from src.diffusion.ddpm.scheduling import VPScheduler
+class VPBetaScheduler(VPScheduler):
+ def w(self, t):
+ return self.beta(t).view(-1, 1, 1, 1)
+
+
+
diff --git a/src/diffusion/stateful_flow_matching/sharing_sampling.py b/src/diffusion/stateful_flow_matching/sharing_sampling.py
new file mode 100644
index 0000000..f372028
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/sharing_sampling.py
@@ -0,0 +1,149 @@
+import torch
+
+from src.diffusion.base.guidance import *
+from src.diffusion.base.scheduling import *
+from src.diffusion.base.sampling import *
+
+from typing import Callable
+
+
+def shift_respace_fn(t, shift=3.0):
+ return t / (t + (1 - t) * shift)
+
+def ode_step_fn(x, v, dt, s, w):
+ return x + v * dt
+
+
+import logging
+logger = logging.getLogger(__name__)
+
+class EulerSampler(BaseSampler):
+ def __init__(
+ self,
+ w_scheduler: BaseScheduler = None,
+ timeshift=1.0,
+ guidance_interval_min: float = 0.0,
+ guidance_interval_max: float = 1.0,
+ state_refresh_rate=1,
+ step_fn: Callable = ode_step_fn,
+ last_step=None,
+ last_step_fn: Callable = ode_step_fn,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.step_fn = step_fn
+ self.last_step = last_step
+ self.last_step_fn = last_step_fn
+ self.w_scheduler = w_scheduler
+ self.timeshift = timeshift
+ self.state_refresh_rate = state_refresh_rate
+ self.guidance_interval_min = guidance_interval_min
+ self.guidance_interval_max = guidance_interval_max
+
+ if self.last_step is None or self.num_steps == 1:
+ self.last_step = 1.0 / self.num_steps
+
+ timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps)
+ timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0)
+ self.timesteps = shift_respace_fn(timesteps, self.timeshift)
+
+ assert self.last_step > 0.0
+ assert self.scheduler is not None
+ assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
+ if self.w_scheduler is not None:
+ if self.step_fn == ode_step_fn:
+ logger.warning("current sampler is ODE sampler, but w_scheduler is enabled")
+
+ # init recompute
+ self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate)
+ self.recompute_timesteps = list(range(self.num_steps))
+
+ def sharing_dp(self, net, noise, condition, uncondition):
+ _, C, H, W = noise.shape
+ B = 8
+ template_noise = torch.randn((B, C, H, W), generator=torch.Generator("cuda").manual_seed(0), device=noise.device)
+ template_condition = torch.randint(0, 1000, (B,), generator=torch.Generator("cuda").manual_seed(0), device=condition.device)
+ template_uncondition = torch.full((B, ), 1000, device=condition.device)
+ _, state_list = self._impl_sampling(net, template_noise, template_condition, template_uncondition)
+ states = torch.stack(state_list)
+ N, B, L, C = states.shape
+ states = states.view(N, B*L, C )
+ states = states.permute(1, 0, 2)
+ states = torch.nn.functional.normalize(states, dim=-1)
+ with torch.autocast(device_type="cuda", dtype=torch.float64):
+ sim = torch.bmm(states, states.transpose(1, 2))
+ sim = torch.mean(sim, dim=0).cpu()
+ error_map = (1-sim).tolist()
+
+ # init cum-error
+ for i in range(1, self.num_steps):
+ for j in range(0, i):
+ error_map[i][j] = error_map[i-1][j] + error_map[i][j]
+
+ # init dp and force 0 start
+ C = [[0.0, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
+ P = [[-1, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)]
+ for i in range(1, self.num_steps+1):
+ C[1][i] = error_map[i - 1][0]
+ P[1][i] = 0
+
+ # dp state
+ for step in range(2, self.num_recompute_timesteps+1):
+ for i in range(step, self.num_steps+1):
+ min_value = 99999
+ min_index = -1
+ for j in range(step-1, i):
+ value = C[step-1][j] + error_map[i-1][j]
+ if value < min_value:
+ min_value = value
+ min_index = j
+ C[step][i] = min_value
+ P[step][i] = min_index
+
+ # trace back
+ timesteps = [self.num_steps,]
+ for i in range(self.num_recompute_timesteps, 0, -1):
+ idx = timesteps[-1]
+ timesteps.append(P[i][idx])
+ timesteps.reverse()
+
+ print("recompute timesteps solved by DP: ", timesteps)
+ return timesteps[:-1]
+
+ def _impl_sampling(self, net, noise, condition, uncondition):
+ """
+ sampling process of Euler sampler
+ -
+ """
+ batch_size = noise.shape[0]
+ steps = self.timesteps.to(noise.device)
+ cfg_condition = torch.cat([uncondition, condition], dim=0)
+ x = noise
+ state = None
+ pooled_state_list = []
+ for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
+ dt = t_next - t_cur
+ t_cur = t_cur.repeat(batch_size)
+ cfg_x = torch.cat([x, x], dim=0)
+ cfg_t = t_cur.repeat(2)
+ if i in self.recompute_timesteps:
+ state = None
+ out, state = net(cfg_x, cfg_t, cfg_condition, state)
+ if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max:
+ out = self.guidance_fn(out, self.guidance)
+ else:
+ out = self.guidance_fn(out, 1.0)
+ v = out
+ if i < self.num_steps -1 :
+ x = self.step_fn(x, v, dt, s=0.0, w=0.0)
+ else:
+ x = self.last_step_fn(x, v, dt, s=0.0, w=0.0)
+ pooled_state_list.append(state)
+ return x, pooled_state_list
+
+ def __call__(self, net, noise, condition, uncondition):
+ if len(self.recompute_timesteps) != self.num_recompute_timesteps:
+ self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition)
+ denoised, _ = self._impl_sampling(net, noise, condition, uncondition)
+ return denoised
\ No newline at end of file
diff --git a/src/diffusion/stateful_flow_matching/training.py b/src/diffusion/stateful_flow_matching/training.py
new file mode 100644
index 0000000..4c49e1e
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/training.py
@@ -0,0 +1,55 @@
+import torch
+from typing import Callable
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+class FlowMatchingTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ lognorm_t=False,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size = x.shape[0]
+ if self.lognorm_t:
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
+ else:
+ t = torch.rand(batch_size).to(x.device, x.dtype)
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+ w = self.scheduler.w(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+ out, _ = net(x_t, t, y)
+
+ weight = self.loss_weight_fn(alpha, sigma)
+
+ loss = weight*(out - v_t)**2
+
+ out = dict(
+ loss=loss.mean(),
+ )
+ return out
\ No newline at end of file
diff --git a/src/diffusion/stateful_flow_matching/training_adv.py b/src/diffusion/stateful_flow_matching/training_adv.py
new file mode 100644
index 0000000..4792950
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/training_adv.py
@@ -0,0 +1,122 @@
+import torch
+import math
+from typing import Callable
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+from src.utils.no_grad import no_grad
+from torchmetrics.image.lpip import _NoTrainLpips
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+class Discriminator(nn.Module):
+ def __init__(self, in_channels, hidden_size):
+ super().__init__()
+ self.head = nn.Sequential(
+ nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8
+ nn.GroupNorm(num_groups=32, num_channels=hidden_size),
+ nn.SiLU(),
+ nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4
+ nn.GroupNorm(num_groups=32, num_channels=hidden_size),
+ nn.SiLU(),
+ nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4
+ nn.GroupNorm(num_groups=32, num_channels=hidden_size),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1
+ )
+
+ def forward(self, feature):
+ B, L, C = feature.shape
+ H = W = int(math.sqrt(L))
+ feature = feature.permute(0, 2, 1)
+ feature = feature.view(B, C, H, W)
+ out = self.head(feature).sigmoid().clamp(0.01, 0.99)
+ return out
+
+class AdvTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ lognorm_t=False,
+ adv_weight=1.0,
+ adv_encoder_layer=4,
+ adv_in_channels=768,
+ adv_hidden_size=256,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.adv_weight = adv_weight
+ self.adv_encoder_layer = adv_encoder_layer
+
+ self.dis_head = Discriminator(
+ in_channels=adv_in_channels,
+ hidden_size=adv_hidden_size,
+ )
+
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size = x.shape[0]
+ if self.lognorm_t:
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
+ else:
+ t = torch.rand(batch_size).to(x.device, x.dtype)
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+ w = self.scheduler.w(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+
+ adv_feature = []
+ def forward_hook(net, input, output):
+ adv_feature.append(output)
+ handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook)
+
+ out, _ = net(x_t, t, y)
+ weight = self.loss_weight_fn(alpha, sigma)
+ loss = weight*(out - v_t)**2
+
+ pred_x0 = (x_t + out * sigma)
+ pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma
+ real_feature = adv_feature.pop()
+ net(pred_xt, t, y, classify_layer=self.adv_encoder_layer)
+ fake_feature = adv_feature.pop()
+ handle.remove()
+
+
+ real_score_gan = self.dis_head(real_feature.detach())
+ fake_score_gan = self.dis_head(fake_feature.detach())
+ fake_score = self.dis_head(fake_feature)
+
+ loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan)
+ acc_real = (real_score_gan > 0.5).float()
+ acc_fake = (fake_score_gan < 0.5).float()
+ loss_adv = -torch.log(fake_score)
+ loss_adv_hack = torch.log(fake_score_gan)
+
+ out = dict(
+ adv_loss=loss_adv.mean(),
+ gan_loss=loss_gan.mean(),
+ fm_loss=loss.mean(),
+ loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(),
+ acc_real=acc_real.mean(),
+ acc_fake=acc_fake.mean(),
+ )
+ return out
diff --git a/src/diffusion/stateful_flow_matching/training_distill_dino.py b/src/diffusion/stateful_flow_matching/training_distill_dino.py
new file mode 100644
index 0000000..c6a2937
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/training_distill_dino.py
@@ -0,0 +1,141 @@
+import torch
+import copy
+import timm
+from torch.nn import Parameter
+
+from src.utils.no_grad import no_grad
+from typing import Callable, Iterator, Tuple
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from torchvision.transforms import Normalize
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+
+class DINOv2(nn.Module):
+ def __init__(self, weight_path:str):
+ super(DINOv2, self).__init__()
+ self.encoder = torch.hub.load(
+ '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
+ weight_path,
+ source="local",
+ skip_validation=True
+ )
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.precomputed_pos_embed = dict()
+
+ def fetch_pos(self, h, w):
+ key = (h, w)
+ if key in self.precomputed_pos_embed:
+ return self.precomputed_pos_embed[key]
+ value = timm.layers.pos_embed.resample_abs_pos_embed(
+ self.pos_embed.data, [h, w],
+ )
+ self.precomputed_pos_embed[key] = value
+ return value
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bilinear')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
+ self.encoder.pos_embed.data = pos_embed_data
+ feature = self.encoder.forward_features(x)['x_norm_patchtokens']
+ return feature
+
+
+class DistillDINOTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ feat_loss_weight: float=0.5,
+ lognorm_t=False,
+ encoder_weight_path=None,
+ align_layer=8,
+ proj_denoiser_dim=256,
+ proj_hidden_dim=256,
+ proj_encoder_dim=256,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.feat_loss_weight = feat_loss_weight
+ self.align_layer = align_layer
+ self.encoder = DINOv2(encoder_weight_path)
+ self.proj_encoder_dim = proj_encoder_dim
+ no_grad(self.encoder)
+
+ self.proj = nn.Sequential(
+ nn.Sequential(
+ nn.Linear(proj_denoiser_dim, proj_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(proj_hidden_dim, proj_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(proj_hidden_dim, proj_encoder_dim),
+ )
+ )
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size, c, height, width = x.shape
+ if self.lognorm_t:
+ base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
+ else:
+ base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
+ t = base_t
+
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ sigma = self.scheduler.sigma(t)
+
+ x_t = alpha * x + noise * sigma
+
+ _, s = net(x_t, t, y)
+ src_feature = self.proj(s)
+
+ with torch.no_grad():
+ dst_feature = self.encoder(raw_images)
+
+ if dst_feature.shape[1] != src_feature.shape[1]:
+ dst_length = dst_feature.shape[1]
+ rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
+ dst_height = (dst_length)**0.5 * (height/width)**0.5
+ dst_width = (dst_length)**0.5 * (width/height)**0.5
+ dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
+ dst_feature = dst_feature.permute(0, 3, 1, 2)
+ dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
+ dst_feature = dst_feature.permute(0, 2, 3, 1)
+ dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
+
+ cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
+ cos_loss = 1 - cos_sim
+
+ out = dict(
+ cos_loss=cos_loss.mean(),
+ loss=cos_loss.mean(),
+ )
+ return out
+
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+ self.proj.state_dict(
+ destination=destination,
+ prefix=prefix + "proj.",
+ keep_vars=keep_vars)
+
diff --git a/src/diffusion/stateful_flow_matching/training_lpips.py b/src/diffusion/stateful_flow_matching/training_lpips.py
new file mode 100644
index 0000000..a3cd2a2
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/training_lpips.py
@@ -0,0 +1,71 @@
+import torch
+from typing import Callable
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+from src.utils.no_grad import no_grad
+from torchmetrics.image.lpip import _NoTrainLpips
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+class LPIPSTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ lognorm_t=False,
+ lpips_weight=1.0,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.lpips_weight = lpips_weight
+ self.lpips = _NoTrainLpips(net="vgg")
+ self.lpips = self.lpips.to(torch.bfloat16)
+ # self.lpips = torch.compile(self.lpips)
+ no_grad(self.lpips)
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size = x.shape[0]
+ if self.lognorm_t:
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
+ else:
+ t = torch.rand(batch_size).to(x.device, x.dtype)
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+ w = self.scheduler.w(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+ out, _ = net(x_t, t, y)
+ weight = self.loss_weight_fn(alpha, sigma)
+ loss = weight*(out - v_t)**2
+
+ pred_x0 = (x_t + out*sigma)
+ target_x0 = x
+ # fixbug lpips std
+ lpips = self.lpips(pred_x0*0.5, target_x0*0.5)
+
+ out = dict(
+ lpips_loss=lpips.mean(),
+ fm_loss=loss.mean(),
+ loss=loss.mean() + lpips.mean()*self.lpips_weight,
+ )
+ return out
+
+ def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
+ return
\ No newline at end of file
diff --git a/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py b/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py
new file mode 100644
index 0000000..e0233ea
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py
@@ -0,0 +1,74 @@
+import torch
+from typing import Callable
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+from src.utils.no_grad import no_grad
+from torchmetrics.image.lpip import _NoTrainLpips
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+class LPIPSTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ lognorm_t=False,
+ lpips_weight=1.0,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = False
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.lpips_weight = lpips_weight
+ self.lpips = _NoTrainLpips(net="vgg")
+ self.lpips = self.lpips.to(torch.bfloat16)
+ # self.lpips = torch.compile(self.lpips)
+ no_grad(self.lpips)
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size = x.shape[0]
+ if self.lognorm_t:
+ t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid()
+ else:
+ t = torch.rand(batch_size).to(x.device, x.dtype)
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+ w = self.scheduler.w(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+ out, _ = net(x_t, t, y)
+
+ fm_weight = t*(1-t)**2/0.25
+ lpips_weight = t
+
+ loss = (out - v_t)**2 * fm_weight[:, None, None, None]
+
+ pred_x0 = (x_t + out*sigma)
+ target_x0 = x
+ # fixbug lpips std
+ lpips = self.lpips(pred_x0*0.5, target_x0*0.5)*lpips_weight[:, None, None, None]
+
+ out = dict(
+ lpips_loss=lpips.mean(),
+ fm_loss=loss.mean(),
+ loss=loss.mean() + lpips.mean()*self.lpips_weight,
+ )
+ return out
+
+ def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
+ return
\ No newline at end of file
diff --git a/src/diffusion/stateful_flow_matching/training_repa.py b/src/diffusion/stateful_flow_matching/training_repa.py
new file mode 100644
index 0000000..a5a28db
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/training_repa.py
@@ -0,0 +1,157 @@
+import torch
+import copy
+import timm
+from torch.nn import Parameter
+
+from src.utils.no_grad import no_grad
+from typing import Callable, Iterator, Tuple
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from torchvision.transforms import Normalize
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+
+class DINOv2(nn.Module):
+ def __init__(self, weight_path:str):
+ super(DINOv2, self).__init__()
+ self.encoder = torch.hub.load(
+ '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
+ weight_path,
+ source="local",
+ skip_validation=True
+ )
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.precomputed_pos_embed = dict()
+
+ def fetch_pos(self, h, w):
+ key = (h, w)
+ if key in self.precomputed_pos_embed:
+ return self.precomputed_pos_embed[key]
+ value = timm.layers.pos_embed.resample_abs_pos_embed(
+ self.pos_embed.data, [h, w],
+ )
+ self.precomputed_pos_embed[key] = value
+ return value
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
+ self.encoder.pos_embed.data = pos_embed_data
+ feature = self.encoder.forward_features(x)['x_norm_patchtokens']
+ return feature
+
+
+class REPATrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ feat_loss_weight: float=0.5,
+ lognorm_t=False,
+ encoder_weight_path=None,
+ align_layer=8,
+ proj_denoiser_dim=256,
+ proj_hidden_dim=256,
+ proj_encoder_dim=256,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.feat_loss_weight = feat_loss_weight
+ self.align_layer = align_layer
+ self.encoder = DINOv2(encoder_weight_path)
+ self.proj_encoder_dim = proj_encoder_dim
+ no_grad(self.encoder)
+
+ self.proj = nn.Sequential(
+ nn.Sequential(
+ nn.Linear(proj_denoiser_dim, proj_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(proj_hidden_dim, proj_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(proj_hidden_dim, proj_encoder_dim),
+ )
+ )
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size, c, height, width = x.shape
+ if self.lognorm_t:
+ base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
+ else:
+ base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
+ t = base_t
+
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+ src_feature = []
+ def forward_hook(net, input, output):
+ src_feature.append(output)
+
+ if getattr(net, "blocks", None) is not None:
+ handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
+ else:
+ handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
+
+ out, _ = net(x_t, t, y)
+ src_feature = self.proj(src_feature[0])
+ handle.remove()
+
+ with torch.no_grad():
+ dst_feature = self.encoder(raw_images)
+
+ if dst_feature.shape[1] != src_feature.shape[1]:
+ dst_length = dst_feature.shape[1]
+ rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
+ dst_height = (dst_length)**0.5 * (height/width)**0.5
+ dst_width = (dst_length)**0.5 * (width/height)**0.5
+ dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
+ dst_feature = dst_feature.permute(0, 3, 1, 2)
+ dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
+ dst_feature = dst_feature.permute(0, 2, 3, 1)
+ dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
+
+ cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
+ cos_loss = 1 - cos_sim
+
+ weight = self.loss_weight_fn(alpha, sigma)
+ fm_loss = weight*(out - v_t)**2
+
+ out = dict(
+ fm_loss=fm_loss.mean(),
+ cos_loss=cos_loss.mean(),
+ loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(),
+ )
+ return out
+
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+ self.proj.state_dict(
+ destination=destination,
+ prefix=prefix + "proj.",
+ keep_vars=keep_vars)
+
diff --git a/src/diffusion/stateful_flow_matching/training_repa_lpips.py b/src/diffusion/stateful_flow_matching/training_repa_lpips.py
new file mode 100644
index 0000000..5a11207
--- /dev/null
+++ b/src/diffusion/stateful_flow_matching/training_repa_lpips.py
@@ -0,0 +1,170 @@
+import torch
+import copy
+import timm
+from torch.nn import Parameter
+
+from src.utils.no_grad import no_grad
+from typing import Callable, Iterator, Tuple
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from torchvision.transforms import Normalize
+from src.diffusion.base.training import *
+from src.diffusion.base.scheduling import BaseScheduler
+from torchmetrics.image.lpip import _NoTrainLpips
+
+def inverse_sigma(alpha, sigma):
+ return 1/sigma**2
+def snr(alpha, sigma):
+ return alpha/sigma
+def minsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, min=threshold)
+def maxsnr(alpha, sigma, threshold=5):
+ return torch.clip(alpha/sigma, max=threshold)
+def constant(alpha, sigma):
+ return 1
+
+
+class DINOv2(nn.Module):
+ def __init__(self, weight_path:str):
+ super(DINOv2, self).__init__()
+ self.encoder = torch.hub.load(
+ '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main',
+ weight_path,
+ source="local",
+ skip_validation=True
+ )
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.precomputed_pos_embed = dict()
+
+ def fetch_pos(self, h, w):
+ key = (h, w)
+ if key in self.precomputed_pos_embed:
+ return self.precomputed_pos_embed[key]
+ value = timm.layers.pos_embed.resample_abs_pos_embed(
+ self.pos_embed.data, [h, w],
+ )
+ self.precomputed_pos_embed[key] = value
+ return value
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
+ self.encoder.pos_embed.data = pos_embed_data
+ feature = self.encoder.forward_features(x)['x_norm_patchtokens']
+ return feature
+
+
+class REPALPIPSTrainer(BaseTrainer):
+ def __init__(
+ self,
+ scheduler: BaseScheduler,
+ loss_weight_fn:Callable=constant,
+ feat_loss_weight: float=0.5,
+ lognorm_t=False,
+ lpips_weight=1.0,
+ encoder_weight_path=None,
+ align_layer=8,
+ proj_denoiser_dim=256,
+ proj_hidden_dim=256,
+ proj_encoder_dim=256,
+ *args,
+ **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.lognorm_t = lognorm_t
+ self.scheduler = scheduler
+ self.loss_weight_fn = loss_weight_fn
+ self.feat_loss_weight = feat_loss_weight
+ self.align_layer = align_layer
+ self.encoder = DINOv2(encoder_weight_path)
+ self.proj_encoder_dim = proj_encoder_dim
+ no_grad(self.encoder)
+
+ self.lpips_weight = lpips_weight
+ self.lpips = _NoTrainLpips(net="vgg")
+ self.lpips = self.lpips.to(torch.bfloat16)
+ no_grad(self.lpips)
+
+ self.proj = nn.Sequential(
+ nn.Sequential(
+ nn.Linear(proj_denoiser_dim, proj_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(proj_hidden_dim, proj_hidden_dim),
+ nn.SiLU(),
+ nn.Linear(proj_hidden_dim, proj_encoder_dim),
+ )
+ )
+
+ def _impl_trainstep(self, net, ema_net, raw_images, x, y):
+ batch_size, c, height, width = x.shape
+ if self.lognorm_t:
+ base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid()
+ else:
+ base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype)
+ t = base_t
+
+ noise = torch.randn_like(x)
+ alpha = self.scheduler.alpha(t)
+ dalpha = self.scheduler.dalpha(t)
+ sigma = self.scheduler.sigma(t)
+ dsigma = self.scheduler.dsigma(t)
+
+ x_t = alpha * x + noise * sigma
+ v_t = dalpha * x + dsigma * noise
+
+ src_feature = []
+ def forward_hook(net, input, output):
+ src_feature.append(output)
+ if getattr(net, "blocks", None) is not None:
+ handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
+ else:
+ handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook)
+
+ out, _ = net(x_t, t, y)
+ src_feature = self.proj(src_feature[0])
+ handle.remove()
+
+ with torch.no_grad():
+ dst_feature = self.encoder(raw_images)
+
+ if dst_feature.shape[1] != src_feature.shape[1]:
+ dst_length = dst_feature.shape[1]
+ rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5
+ dst_height = (dst_length)**0.5 * (height/width)**0.5
+ dst_width = (dst_length)**0.5 * (width/height)**0.5
+ dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim)
+ dst_feature = dst_feature.permute(0, 3, 1, 2)
+ dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False)
+ dst_feature = dst_feature.permute(0, 2, 3, 1)
+ dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim)
+
+ cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1)
+ cos_loss = 1 - cos_sim
+
+ weight = self.loss_weight_fn(alpha, sigma)
+ fm_loss = weight*(out - v_t)**2
+
+ pred_x0 = (x_t + out * sigma)
+ target_x0 = x
+ # fixbug lpips std
+ lpips = self.lpips(pred_x0 * 0.5, target_x0 * 0.5)
+
+ out = dict(
+ lpips_loss=lpips.mean(),
+ fm_loss=fm_loss.mean(),
+ cos_loss=cos_loss.mean(),
+ loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + self.lpips_weight*lpips.mean(),
+ )
+ return out
+
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+ self.proj.state_dict(
+ destination=destination,
+ prefix=prefix + "proj.",
+ keep_vars=keep_vars)
+
diff --git a/src/lightning_data.py b/src/lightning_data.py
new file mode 100644
index 0000000..9f75a42
--- /dev/null
+++ b/src/lightning_data.py
@@ -0,0 +1,162 @@
+from typing import Any
+import torch
+import copy
+import lightning.pytorch as pl
+from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
+from torch.utils.data import DataLoader
+from src.data.dataset.randn import RandomNDataset
+from src.data.var_training import VARTransformEngine
+
+def collate_fn(batch):
+ new_batch = copy.deepcopy(batch)
+ new_batch = list(zip(*new_batch))
+ for i in range(len(new_batch)):
+ if isinstance(new_batch[i][0], torch.Tensor):
+ try:
+ new_batch[i] = torch.stack(new_batch[i], dim=0)
+ except:
+ print("Warning: could not stack tensors")
+ return new_batch
+
+class DataModule(pl.LightningDataModule):
+ def __init__(self,
+ train_root,
+ test_nature_root,
+ test_gen_root,
+ train_image_size=64,
+ train_batch_size=64,
+ train_num_workers=8,
+ var_transform_engine: VARTransformEngine = None,
+ train_prefetch_factor=2,
+ train_dataset: str = None,
+ eval_batch_size=32,
+ eval_num_workers=4,
+ eval_max_num_instances=50000,
+ pred_batch_size=32,
+ pred_num_workers=4,
+ pred_seeds:str=None,
+ pred_selected_classes=None,
+ num_classes=1000,
+ latent_shape=(4,64,64),
+ ):
+ super().__init__()
+ pred_seeds = list(map(lambda x: int(x), pred_seeds.strip().split(","))) if pred_seeds is not None else None
+
+ self.train_root = train_root
+ self.train_image_size = train_image_size
+ self.train_dataset = train_dataset
+ # stupid data_convert override, just to make nebular happy
+ self.train_batch_size = train_batch_size
+ self.train_num_workers = train_num_workers
+ self.train_prefetch_factor = train_prefetch_factor
+
+ self.test_nature_root = test_nature_root
+ self.test_gen_root = test_gen_root
+ self.eval_max_num_instances = eval_max_num_instances
+ self.pred_seeds = pred_seeds
+ self.num_classes = num_classes
+ self.latent_shape = latent_shape
+
+ self.eval_batch_size = eval_batch_size
+ self.pred_batch_size = pred_batch_size
+
+ self.pred_num_workers = pred_num_workers
+ self.eval_num_workers = eval_num_workers
+
+ self.pred_selected_classes = pred_selected_classes
+
+ self._train_dataloader = None
+ self.var_transform_engine = var_transform_engine
+
+ def setup(self, stage: str) -> None:
+ if stage == "fit":
+ assert self.train_dataset is not None
+ if self.train_dataset == "pix_imagenet64":
+ from src.data.dataset.imagenet import PixImageNet64
+ self.train_dataset = PixImageNet64(
+ root=self.train_root,
+ )
+ elif self.train_dataset == "pix_imagenet128":
+ from src.data.dataset.imagenet import PixImageNet128
+ self.train_dataset = PixImageNet128(
+ root=self.train_root,
+ )
+ elif self.train_dataset == "imagenet256":
+ from src.data.dataset.imagenet import ImageNet256
+ self.train_dataset = ImageNet256(
+ root=self.train_root,
+ )
+ elif self.train_dataset == "pix_imagenet256":
+ from src.data.dataset.imagenet import PixImageNet256
+ self.train_dataset = PixImageNet256(
+ root=self.train_root,
+ )
+ elif self.train_dataset == "imagenet512":
+ from src.data.dataset.imagenet import ImageNet512
+ self.train_dataset = ImageNet512(
+ root=self.train_root,
+ )
+ elif self.train_dataset == "pix_imagenet512":
+ from src.data.dataset.imagenet import PixImageNet512
+ self.train_dataset = PixImageNet512(
+ root=self.train_root,
+ )
+ else:
+ raise NotImplementedError("no such dataset")
+
+ def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
+ if self.var_transform_engine and self.trainer.training:
+ batch = self.var_transform_engine(batch)
+ return batch
+
+ def train_dataloader(self) -> TRAIN_DATALOADERS:
+ global_rank = self.trainer.global_rank
+ world_size = self.trainer.world_size
+ from torch.utils.data import DistributedSampler
+ sampler = DistributedSampler(self.train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
+ self._train_dataloader = DataLoader(
+ self.train_dataset,
+ self.train_batch_size,
+ timeout=6000,
+ num_workers=self.train_num_workers,
+ prefetch_factor=self.train_prefetch_factor,
+ sampler=sampler,
+ collate_fn=collate_fn,
+ )
+ return self._train_dataloader
+
+ def val_dataloader(self) -> EVAL_DATALOADERS:
+ global_rank = self.trainer.global_rank
+ world_size = self.trainer.world_size
+ self.eval_dataset = RandomNDataset(
+ latent_shape=self.latent_shape,
+ num_classes=self.num_classes,
+ max_num_instances=self.eval_max_num_instances,
+ )
+ from torch.utils.data import DistributedSampler
+ sampler = DistributedSampler(self.eval_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
+ return DataLoader(self.eval_dataset, self.eval_batch_size,
+ num_workers=self.eval_num_workers,
+ prefetch_factor=2,
+ collate_fn=collate_fn,
+ sampler=sampler
+ )
+
+ def predict_dataloader(self) -> EVAL_DATALOADERS:
+ global_rank = self.trainer.global_rank
+ world_size = self.trainer.world_size
+ self.pred_dataset = RandomNDataset(
+ seeds= self.pred_seeds,
+ max_num_instances=50000,
+ num_classes=self.num_classes,
+ selected_classes=self.pred_selected_classes,
+ latent_shape=self.latent_shape,
+ )
+ from torch.utils.data import DistributedSampler
+ sampler = DistributedSampler(self.pred_dataset, num_replicas=world_size, rank=global_rank, shuffle=False)
+ return DataLoader(self.pred_dataset, batch_size=self.pred_batch_size,
+ num_workers=self.pred_num_workers,
+ prefetch_factor=4,
+ collate_fn=collate_fn,
+ sampler=sampler
+ )
diff --git a/src/lightning_model.py b/src/lightning_model.py
new file mode 100644
index 0000000..4602e82
--- /dev/null
+++ b/src/lightning_model.py
@@ -0,0 +1,123 @@
+from typing import Callable, Iterable, Any, Optional, Union, Sequence, Mapping, Dict
+import os.path
+import copy
+import torch
+import torch.nn as nn
+import lightning.pytorch as pl
+from lightning.pytorch.utilities.types import OptimizerLRScheduler, STEP_OUTPUT
+from torch.optim.lr_scheduler import LRScheduler
+from torch.optim import Optimizer
+from lightning.pytorch.callbacks import Callback
+
+
+from src.models.vae import BaseVAE, fp2uint8
+from src.models.conditioner import BaseConditioner
+from src.utils.model_loader import ModelLoader
+from src.callbacks.simple_ema import SimpleEMA
+from src.diffusion.base.sampling import BaseSampler
+from src.diffusion.base.training import BaseTrainer
+from src.utils.no_grad import no_grad, filter_nograd_tensors
+from src.utils.copy import copy_params
+
+EMACallable = Callable[[nn.Module, nn.Module], SimpleEMA]
+OptimizerCallable = Callable[[Iterable], Optimizer]
+LRSchedulerCallable = Callable[[Optimizer], LRScheduler]
+
+
+class LightningModel(pl.LightningModule):
+ def __init__(self,
+ vae: BaseVAE,
+ conditioner: BaseConditioner,
+ denoiser: nn.Module,
+ diffusion_trainer: BaseTrainer,
+ diffusion_sampler: BaseSampler,
+ ema_tracker: Optional[EMACallable] = None,
+ optimizer: OptimizerCallable = None,
+ lr_scheduler: LRSchedulerCallable = None,
+ ):
+ super().__init__()
+ self.vae = vae
+ self.conditioner = conditioner
+ self.denoiser = denoiser
+ self.ema_denoiser = copy.deepcopy(self.denoiser)
+ self.diffusion_sampler = diffusion_sampler
+ self.diffusion_trainer = diffusion_trainer
+ self.ema_tracker = ema_tracker
+ self.optimizer = optimizer
+ self.lr_scheduler = lr_scheduler
+ # self.model_loader = ModelLoader()
+
+ self._strict_loading = False
+
+ def configure_model(self) -> None:
+ self.trainer.strategy.barrier()
+ # self.denoiser = self.model_loader.load(self.denoiser)
+ copy_params(src_model=self.denoiser, dst_model=self.ema_denoiser)
+
+ # self.denoiser = torch.compile(self.denoiser)
+ # disable grad for conditioner and vae
+ no_grad(self.conditioner)
+ no_grad(self.vae)
+ no_grad(self.diffusion_sampler)
+ no_grad(self.ema_denoiser)
+
+ def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
+ ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser)
+ return [ema_tracker]
+
+ def configure_optimizers(self) -> OptimizerLRScheduler:
+ params_denoiser = filter_nograd_tensors(self.denoiser.parameters())
+ params_trainer = filter_nograd_tensors(self.diffusion_trainer.parameters())
+ optimizer: torch.optim.Optimizer = self.optimizer([*params_trainer, *params_denoiser])
+ if self.lr_scheduler is None:
+ return dict(
+ optimizer=optimizer
+ )
+ else:
+ lr_scheduler = self.lr_scheduler(optimizer)
+ return dict(
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler
+ )
+
+ def training_step(self, batch, batch_idx):
+ raw_images, x, y = batch
+ with torch.no_grad():
+ x = self.vae.encode(x)
+ condition, uncondition = self.conditioner(y)
+ loss = self.diffusion_trainer(self.denoiser, self.ema_denoiser, raw_images, x, condition, uncondition)
+ self.log_dict(loss, prog_bar=True, on_step=True, sync_dist=False)
+ return loss["loss"]
+
+ def predict_step(self, batch, batch_idx):
+ xT, y, metadata = batch
+ with torch.no_grad():
+ condition, uncondition = self.conditioner(y)
+ # Sample images:
+ samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition)
+ samples = self.vae.decode(samples)
+ # fp32 -1,1 -> uint8 0,255
+ samples = fp2uint8(samples)
+ return samples
+
+ def validation_step(self, batch, batch_idx):
+ samples = self.predict_step(batch, batch_idx)
+ return samples
+
+ def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
+ if destination is None:
+ destination = {}
+ self._save_to_state_dict(destination, prefix, keep_vars)
+ self.denoiser.state_dict(
+ destination=destination,
+ prefix=prefix+"denoiser.",
+ keep_vars=keep_vars)
+ self.ema_denoiser.state_dict(
+ destination=destination,
+ prefix=prefix+"ema_denoiser.",
+ keep_vars=keep_vars)
+ self.diffusion_trainer.state_dict(
+ destination=destination,
+ prefix=prefix+"diffusion_trainer.",
+ keep_vars=keep_vars)
+ return destination
\ No newline at end of file
diff --git a/src/models/__init__.py b/src/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/models/conditioner.py b/src/models/conditioner.py
new file mode 100644
index 0000000..a68fad3
--- /dev/null
+++ b/src/models/conditioner.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+
+class BaseConditioner(nn.Module):
+ def __init__(self):
+ super(BaseConditioner, self).__init__()
+
+ def _impl_condition(self, y):
+ ...
+ def _impl_uncondition(self, y):
+ ...
+ def __call__(self, y):
+ condition = self._impl_condition(y)
+ uncondition = self._impl_uncondition(y)
+ return condition, uncondition
+
+class LabelConditioner(BaseConditioner):
+ def __init__(self, null_class):
+ super().__init__()
+ self.null_condition = null_class
+
+ def _impl_condition(self, y):
+ return torch.tensor(y).long().cuda()
+
+ def _impl_uncondition(self, y):
+ return torch.full((len(y),), self.null_condition, dtype=torch.long).cuda()
\ No newline at end of file
diff --git a/src/models/denoiser/__init__.py b/src/models/denoiser/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py
new file mode 100644
index 0000000..3581446
--- /dev/null
+++ b/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py
@@ -0,0 +1,383 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+from src.utils.model_loader import ModelLoader
+from src.utils.no_grad import no_grad
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenDiTEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.patch_size = patch_size
+ self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)]
+ else:
+ pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
+ self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
+ return (pos_rope, pos_ape)
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.s_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ def forward(self, x, t, y, mask=None):
+ B, _, H, W = x.shape
+ pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+ c = nn.functional.silu(t + y)
+ s = self.s_embedder(x)
+ # s = s + pos_ape.to(s.dtype)
+ for i in range(self.num_blocks):
+ s = self.blocks[i](s, c, pos_rope, mask)
+ return None, s
+
+
+class FlattenDiTDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.patch_size = patch_size
+ self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
+ self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
+
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)]
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def forward(self, x, t, y, s, mask=None):
+ B, _, H, W = x.shape
+ pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+ s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
+ c = torch.nn.functional.silu(t + y)
+ x = torch.cat([x, s], dim=-1)
+ x = self.x_embedder(x)
+ for i in range(self.num_blocks):
+ x = self.blocks[i](x, c, pos, None)
+ x = self.final_layer(x, c)
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size,
+ stride=self.patch_size)
+ return x
+
+
+
+class FlattenDiT(nn.Module):
+ def __init__(
+ self,
+ encoder:FlattenDiTEncoder,
+ decoder:FlattenDiTDecoder,
+ ):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ ModelLoader().load(encoder)
+ self.encoder = self.encoder.to(torch.bfloat16)
+ no_grad(self.encoder)
+
+ def forward(self, x, t, y, s=None):
+ if s is None:
+ with torch.no_grad():
+ _, s = self.encoder(x, t, y)
+ x = self.decoder(x, t, y, s)
+ return x, s
+
diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py
new file mode 100644
index 0000000..733ce4a
--- /dev/null
+++ b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py
@@ -0,0 +1,447 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+from src.utils.model_loader import ModelLoader
+from src.utils.no_grad import no_grad
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+class ResBlock(nn.Module):
+ def __init__(self, dim:int, groups:int=8, hidden_dim:int=256):
+ super().__init__()
+ self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
+ self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
+ self.norm1 = nn.GroupNorm(groups, dim)
+ self.norm2 = nn.GroupNorm(groups, dim)
+ self.embed_proj = nn.Linear(hidden_dim, dim)
+
+ def forward(self, x, c):
+ c = self.embed_proj(c)[:, :, None, None]
+ residual = x
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = torch.nn.functional.silu(x)
+ x = x * c
+ x = self.conv2(x)
+ x = self.norm2(x)
+ x = torch.nn.functional.silu(x)
+ return residual + x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenDiTEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.patch_size = patch_size
+ self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)]
+ else:
+ pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
+ self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
+ return (pos_rope, pos_ape)
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.s_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ def forward(self, x, t, y, mask=None, classify_layer=None):
+ B, _, H, W = x.shape
+ pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+ c = nn.functional.silu(t + y)
+ s = self.s_embedder(x)
+ # s = s + pos_ape.to(s.dtype)
+ classify_feats = []
+ for i in range(self.num_blocks):
+ s = self.blocks[i](s, c, pos_rope, mask)
+ if classify_layer is not None and i < classify_layer:
+ classify_feats.append(s)
+ if i == classify_layer - 1:
+ return _, classify_feats
+ return None, s
+
+
+class FlattenDiTDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_mid_blocks=18,
+ num_res_blocks=[1, 1, 1],
+ num_res_channels=[64, 384, 768],
+ num_classes=1000,
+ learn_sigma=True,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_mid_blocks = num_mid_blocks
+ self.num_res_blocks = num_res_blocks
+ self.num_res_channels = num_res_channels
+ self.patch_size = 2**(len(num_res_blocks))
+
+ self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+
+ self.down_res_blocks = nn.ModuleList()
+ previous_channel = self.in_channels
+ for num, channels in zip(num_res_blocks, num_res_channels):
+ self.down_res_blocks.append(
+ nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
+ )
+ self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
+ previous_channel = channels
+ self.up_res_blocks = []
+ previous_channel = self.in_channels
+ for num, channels in zip(num_res_blocks, num_res_channels):
+ self.up_res_blocks.append(
+ nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
+ )
+ self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
+ previous_channel = channels
+ self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])
+
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
+ ])
+
+ self.initialize_weights()
+ self.precompute_pos = dict()
+ self.weight_path = weight_path
+ self.load_ema = load_ema
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)]
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ def forward(self, x, t, y, s, mask=None):
+ B, _, H, W = x.shape
+ t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
+ y = self.y_embedder(y).view(B, self.hidden_size)
+ s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
+ c = torch.nn.functional.silu(t + y)
+
+ residual = []
+ for i, block in enumerate(self.down_res_blocks):
+ if isinstance(block, nn.Conv2d):
+ residual.append(x)
+ x = block(x)
+ else:
+ x = block(x, c)
+
+ pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = x.view(B, self.hidden_size, -1).transpose(1, 2)
+ mid_c = torch.nn.functional.silu(t[:, None, :] + s)
+ for i in range(self.num_mid_blocks):
+ x = self.blocks[i](x, mid_c, pos, None)
+ x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size)
+
+ residual[0] = 0.0
+ for i, block in enumerate(self.up_res_blocks):
+ if isinstance(block, nn.ConvTranspose2d):
+ x = block(x) + residual.pop()
+ else:
+ x = block(x, c)
+ return x
+
+
+class FlattenDiT(nn.Module):
+ def __init__(
+ self,
+ encoder:FlattenDiTEncoder,
+ decoder:FlattenDiTDecoder,
+ ):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ ModelLoader().load(encoder)
+ self.encoder = self.encoder.to(torch.bfloat16)
+ no_grad(self.encoder)
+
+ def forward(self, x, t, y, s=None, classify_layer=None):
+ if s is None:
+ _, s = self.encoder(x, t, y, classify_layer=classify_layer)
+ if classify_layer is not None:
+ return None, s
+ x = self.decoder(x, t, y, s)
+ return x, s
+
+class FlattenDiT_jointtraining(nn.Module):
+ def __init__(
+ self,
+ encoder:FlattenDiTEncoder,
+ decoder:FlattenDiTDecoder,
+ ):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+
+ def forward(self, x, t, y, s=None):
+ if s is None:
+ _, s = self.encoder(x, t, y)
+ x = self.decoder(x, t, y, s)
+ return x, s
+
diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py
new file mode 100644
index 0000000..6e9adbc
--- /dev/null
+++ b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py
@@ -0,0 +1,448 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+from src.utils.model_loader import ModelLoader
+from src.utils.no_grad import no_grad
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+class ResBlock(nn.Module):
+ def __init__(self, dim:int, groups:int=8, hidden_dim:int=256):
+ super().__init__()
+ self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
+ self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
+ self.norm1 = nn.GroupNorm(groups, dim)
+ self.norm2 = nn.GroupNorm(groups, dim)
+ self.embed_proj = nn.Linear(hidden_dim, dim)
+
+ def forward(self, x, c):
+ c = self.embed_proj(c)[:, :, None, None]
+ residual = x
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = torch.nn.functional.silu(x)
+ x = x * c
+ x = self.conv2(x)
+ x = self.norm2(x)
+ x = torch.nn.functional.silu(x)
+ return residual + x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenDiTEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.patch_size = patch_size
+ self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)]
+ else:
+ pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
+ self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
+ return (pos_rope, pos_ape)
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.s_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ def forward(self, x, t, y, mask=None, classify_layer=None):
+ B, _, H, W = x.shape
+ pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+ c = nn.functional.silu(t + y)
+ s = self.s_embedder(x)
+ # s = s + pos_ape.to(s.dtype)
+ classify_feats = []
+ for i in range(self.num_blocks):
+ s = self.blocks[i](s, c, pos_rope, mask)
+ if classify_layer is not None and i < classify_layer:
+ classify_feats.append(s)
+ if i == classify_layer - 1:
+ return _, classify_feats
+ return None, s
+
+
+class FlattenDiTDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_mid_blocks=18,
+ num_res_blocks=[1, 1, 1],
+ num_res_channels=[64, 384, 768],
+ num_classes=1000,
+ learn_sigma=True,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_mid_blocks = num_mid_blocks
+ self.num_res_blocks = num_res_blocks
+ self.num_res_channels = num_res_channels
+ self.patch_size = 2**(len(num_res_blocks))
+
+ self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+
+ self.down_res_blocks = nn.ModuleList()
+ previous_channel = self.in_channels
+ for num, channels in zip(num_res_blocks, num_res_channels):
+ self.down_res_blocks.append(
+ nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
+ )
+ self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
+ previous_channel = channels
+ self.up_res_blocks = []
+ previous_channel = self.in_channels
+ for num, channels in zip(num_res_blocks, num_res_channels):
+ self.up_res_blocks.append(
+ nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
+ )
+ self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
+ previous_channel = channels
+ self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])
+
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
+ ])
+
+ self.initialize_weights()
+ self.precompute_pos = dict()
+ self.weight_path = weight_path
+ self.load_ema = load_ema
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)]
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ def forward(self, x, t, y, s, mask=None):
+ B, _, H, W = x.shape
+ t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
+ y = self.y_embedder(y).view(B, self.hidden_size)
+ s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
+ c = torch.nn.functional.silu(t + y)
+
+ residual = []
+ for i, block in enumerate(self.down_res_blocks):
+ if isinstance(block, nn.Conv2d):
+ residual.append(x)
+ x = block(x)
+ else:
+ x = block(x, c)
+
+ pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = x.view(B, self.hidden_size, -1).transpose(1, 2)
+ mid_c = torch.nn.functional.silu(t[:, None, :] + s)
+ for i in range(self.num_mid_blocks):
+ x = self.blocks[i](x, mid_c, pos, None)
+ x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size)
+
+ residual[0] = 0.0
+ for i, block in enumerate(self.up_res_blocks):
+ if isinstance(block, nn.ConvTranspose2d):
+ x = block(x) + residual.pop()
+ else:
+ x = block(x, c)
+ return x
+
+
+class FlattenDiT(nn.Module):
+ def __init__(
+ self,
+ encoder:FlattenDiTEncoder,
+ decoder:FlattenDiTDecoder,
+ ):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ ModelLoader().load(encoder, "encoder.")
+ ModelLoader().load(decoder, "decoder.")
+ self.encoder = self.encoder.to(torch.bfloat16)
+ no_grad(self.encoder)
+
+ def forward(self, x, t, y, s=None, classify_layer=None):
+ if s is None:
+ _, s = self.encoder(x, t, y, classify_layer=classify_layer)
+ if classify_layer is not None:
+ return None, s
+ x = self.decoder(x, t, y, s)
+ return x, s
+
+class FlattenDiT_jointtraining(nn.Module):
+ def __init__(
+ self,
+ encoder:FlattenDiTEncoder,
+ decoder:FlattenDiTDecoder,
+ ):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+
+ def forward(self, x, t, y, s=None):
+ if s is None:
+ _, s = self.encoder(x, t, y)
+ x = self.decoder(x, t, y, s)
+ return x, s
+
diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py
new file mode 100644
index 0000000..537078a
--- /dev/null
+++ b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py
@@ -0,0 +1,464 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+from src.utils.model_loader import ModelLoader
+from src.utils.no_grad import no_grad
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+class ResBlock(nn.Module):
+ def __init__(self, dim:int, groups:int=8, hidden_dim:int=256):
+ super().__init__()
+ self.conv1 = nn.Conv2d(dim, dim, 3, padding=1)
+ self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
+ self.norm1 = nn.GroupNorm(groups, dim)
+ self.norm2 = nn.GroupNorm(groups, dim)
+
+ def forward(self, x):
+ residual = x
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = torch.nn.functional.silu(x)
+ x = self.conv2(x)
+ x = self.norm2(x)
+ x = torch.nn.functional.silu(x)
+ return residual + x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenDiTEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.patch_size = patch_size
+ self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)]
+ else:
+ pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
+ self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
+ return (pos_rope, pos_ape)
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.s_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ def forward(self, x, t, y, mask=None, classify_layer=None):
+ B, _, H, W = x.shape
+ pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+ c = nn.functional.silu(t + y)
+ s = self.s_embedder(x)
+ # s = s + pos_ape.to(s.dtype)
+ classify_feats = []
+ for i in range(self.num_blocks):
+ s = self.blocks[i](s, c, pos_rope, mask)
+ if classify_layer is not None and i < classify_layer:
+ classify_feats.append(s)
+ if i == classify_layer - 1:
+ return _, classify_feats
+ return None, s
+
+
+class FlattenDiTDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_mid_blocks=18,
+ num_res_blocks=[1, 1, 1],
+ num_res_channels=[64, 384, 768],
+ num_classes=1000,
+ learn_sigma=True,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_mid_blocks = num_mid_blocks
+ self.num_res_blocks = num_res_blocks
+ self.num_res_channels = num_res_channels
+ self.patch_size = 2**(len(num_res_blocks))
+
+ self.t_embedder = TimestepEmbedder(hidden_size)
+
+ self.down_res_blocks = nn.ModuleList()
+ previous_channel = self.in_channels
+ for num, channels in zip(num_res_blocks, num_res_channels):
+ self.down_res_blocks.append(
+ nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0),
+ )
+ self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
+ previous_channel = channels
+ self.up_res_blocks = []
+ previous_channel = self.in_channels
+ for num, channels in zip(num_res_blocks, num_res_channels):
+ self.up_res_blocks.append(
+ nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0)
+ )
+ self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)])
+ previous_channel = channels
+ self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1])
+
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks)
+ ])
+
+ self.initialize_weights()
+ self.precompute_pos = dict()
+ self.weight_path = weight_path
+ self.load_ema = load_ema
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)]
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+ # Zero-out adaLN modulation layers in SiT blocks:
+ for block in self.blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ for block in self.down_res_blocks:
+ if isinstance(block, ResBlock):
+ nn.init.constant_(block.conv1.weight, 0)
+ nn.init.constant_(block.conv1.bias, 0)
+ nn.init.constant_(block.norm1.weight, 0)
+ nn.init.constant_(block.norm2.weight, 0)
+ nn.init.constant_(block.conv2.weight, 0)
+ nn.init.constant_(block.conv2.bias, 0)
+
+ for block in self.up_res_blocks:
+ if isinstance(block, ResBlock):
+ nn.init.constant_(block.conv1.weight, 0)
+ nn.init.constant_(block.conv1.bias, 0)
+ nn.init.constant_(block.norm1.weight, 0)
+ nn.init.constant_(block.norm2.weight, 0)
+ nn.init.constant_(block.conv2.weight, 0)
+ nn.init.constant_(block.conv2.bias, 0)
+
+ def forward(self, x, t, y, s, mask=None):
+ B, _, H, W = x.shape
+ t = self.t_embedder(t.view(-1)).view(B, self.hidden_size)
+ s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
+
+ residual = []
+ for i, block in enumerate(self.down_res_blocks):
+ if isinstance(block, nn.Conv2d):
+ residual.append(x)
+ x = block(x)
+ else:
+ x = block(x)
+
+ pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = x.view(B, self.hidden_size, -1).transpose(1, 2)
+ mid_c = torch.nn.functional.silu(t[:, None, :] + s)
+ for i in range(self.num_mid_blocks):
+ x = self.blocks[i](x, mid_c, pos, None)
+ x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size)
+
+ residual[0] = 0.0
+ for i, block in enumerate(self.up_res_blocks):
+ if isinstance(block, nn.ConvTranspose2d):
+ x = block(x) + residual.pop()
+ else:
+ x = block(x)
+ return x
+
+
+class FlattenDiT(nn.Module):
+ def __init__(
+ self,
+ encoder:FlattenDiTEncoder,
+ decoder:FlattenDiTDecoder,
+ ):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ ModelLoader().load(encoder, "encoder.")
+ ModelLoader().load(decoder, "decoder.")
+ self.encoder = self.encoder.to(torch.bfloat16)
+ no_grad(self.encoder)
+
+ def forward(self, x, t, y, s=None, classify_layer=None):
+ if s is None:
+ _, s = self.encoder(x, t, y, classify_layer=classify_layer)
+ if classify_layer is not None:
+ return None, s
+ x = self.decoder(x, t, y, s)
+ return x, s
+
+class FlattenDiT_jointtraining(nn.Module):
+ def __init__(
+ self,
+ encoder:FlattenDiTEncoder,
+ decoder:FlattenDiTDecoder,
+ ):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+
+ def forward(self, x, t, y, s=None):
+ if s is None:
+ _, s = self.encoder(x, t, y)
+ x = self.decoder(x, t, y, s)
+ return x, s
+
diff --git a/src/models/denoiser/condit_dit.py b/src/models/denoiser/condit_dit.py
new file mode 100644
index 0000000..48d6b0e
--- /dev/null
+++ b/src/models/denoiser/condit_dit.py
@@ -0,0 +1,274 @@
+import torch
+import torch.nn as nn
+import math
+
+from numba.cuda.cudadrv.devicearray import lru_cache
+from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+
+from torch.nn.attention import SDPBackend, sdpa_kernel
+
+flex_attention = torch.compile(flex_attention)
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = False,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm = nn.LayerNorm(hidden_size , elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = self.norm(x)
+ x = modulate(x, shift, scale)
+ x = self.linear(x)
+ return x
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ self.fc1 = nn.Linear(dim, hidden_dim)
+ self.fc2 = nn.Linear(hidden_dim, dim)
+ self.act = nn.GELU(approximate="tanh")
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ return x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale: float=16):
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ x_pos = x_pos.reshape(-1)
+ y_pos = y_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ freqs_cis = torch.cat([x_freqs.sin(), x_freqs.cos(), y_freqs.sin(), y_freqs.cos()], dim=1)
+ return freqs_cis
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ qk_norm: bool = False,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = nn.LayerNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ # import pdb; pdb.set_trace()
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q).to(q.dtype)
+ k = self.k_norm(k).to(k.dtype)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+ # x = flex_attention(q, k, v, block_mask=mask)
+ # with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+ x = scaled_dot_product_attention(q, k, v, mask)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class DiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size , elementwise_affine=False, eps=1e-6)
+ self.attn = Attention(hidden_size, num_heads=groups, qkv_bias=True, qk_norm=False)
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class ConDiT(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ out_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ num_cond_blocks=4,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ deep_supervision=0,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.deep_supervision = deep_supervision
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.patch_size = patch_size
+ self.x_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True)
+ self.s_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size)
+ self.final_layer = FinalLayer(hidden_size, out_channels * patch_size ** 2)
+ self.num_cond_blocks = num_cond_blocks
+
+
+ self.weight_path = weight_path
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ DiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+
+ @lru_cache
+ def fetch_pos(self, height, width, device):
+ pos = precompute_freqs_cis_2d(self.hidden_size, height//self.patch_size, width//self.patch_size).to(device)[None, ...]
+ return pos
+
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.s_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def forward(self, x, t, y, s=None):
+ B, _, H, W = x.shape
+ pos = self.fetch_pos(H, W, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+
+ if s is None:
+ # semantic encoder
+ s = self.s_embedder(x) + pos
+ c = nn.functional.silu(t + y)
+ for i in range(self.num_cond_blocks):
+ s = self.blocks[i](s, c, pos)
+ s = nn.functional.silu(t + s)
+
+ x = self.x_embedder(x)
+ for i in range(self.num_cond_blocks, self.num_blocks):
+ x = self.blocks[i](x, s, pos)
+ x = self.final_layer(x, s)
+ x = torch.nn.functional.fold(x.transpose(1, 2), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
+ return x, s
diff --git a/src/models/denoiser/flatten_condit_catdit_fixt.py b/src/models/denoiser/flatten_condit_catdit_fixt.py
new file mode 100644
index 0000000..22a0fd5
--- /dev/null
+++ b/src/models/denoiser/flatten_condit_catdit_fixt.py
@@ -0,0 +1,314 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenConDiT(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ num_cond_blocks=4,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ deep_supervision=0,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.deep_supervision = deep_supervision
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.num_cond_blocks = num_cond_blocks
+ self.patch_size = patch_size
+ self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True)
+ self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+
+ self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
+
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)].to(device)
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.s_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # # Zero-out adaLN modulation layers in SiT blocks:
+ # for block in self.blocks:
+ # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+
+ def forward(self, x, t, y, s=None, mask=None):
+ B, _, H, W = x.shape
+ pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+ c = nn.functional.silu(t + y)
+ if s is None:
+ s = self.s_embedder(x)
+ for i in range(self.num_cond_blocks):
+ s = self.blocks[i](s, c, pos, mask)
+ # s = nn.functional.silu(t + s)
+ s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6)
+ x = torch.cat((x, s), dim=-1)
+ x = self.x_embedder(x)
+ for i in range(self.num_cond_blocks, self.num_blocks):
+ x = self.blocks[i](x, c, pos, None)
+ x = self.final_layer(x, c)
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
+ return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_condit_conv_fixt.py b/src/models/denoiser/flatten_condit_conv_fixt.py
new file mode 100644
index 0000000..219db4c
--- /dev/null
+++ b/src/models/denoiser/flatten_condit_conv_fixt.py
@@ -0,0 +1,340 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+class FlattenConvBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, kernel_size=3):
+ super().__init__()
+ self.hidden_size = hidden_size
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = nn.Conv2d(hidden_size, hidden_size, groups=groups, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ attn_x = modulate(self.norm1(x), shift_msa, scale_msa)
+ attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous()
+ attn_x = self.attn(attn_x)
+ attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2)
+ x = x + gate_msa * attn_x
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenConDiT(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ num_cond_blocks=4,
+ patch_size=2,
+ kernel_size=3,
+ num_classes=1000,
+ learn_sigma=True,
+ deep_supervision=0,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.deep_supervision = deep_supervision
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.num_cond_blocks = num_cond_blocks
+ self.patch_size = patch_size
+ self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+
+ self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
+
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([])
+ for i in range(self.num_cond_blocks):
+ self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
+ for i in range(self.num_blocks-self.num_cond_blocks):
+ self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups, kernel_size))
+
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)].to(device)
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.s_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # # Zero-out adaLN modulation layers in SiT blocks:
+ # for block in self.blocks:
+ # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+
+ def forward(self, x, t, y, s=None):
+ B, _, H, W = x.shape
+ pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+ c = nn.functional.silu(t + y)
+ if s is None:
+ s = self.s_embedder(x)
+ for i in range(self.num_cond_blocks):
+ s = self.blocks[i](s, c, pos, None)
+ s = nn.functional.silu(t + s)
+
+ x = self.x_embedder(x)
+ for i in range(self.num_cond_blocks, self.num_blocks):
+ x = self.blocks[i](x, s, pos, None)
+ x = self.final_layer(x, s)
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
+ return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_condit_convnext_fixt.py b/src/models/denoiser/flatten_condit_convnext_fixt.py
new file mode 100644
index 0000000..cf9c214
--- /dev/null
+++ b/src/models/denoiser/flatten_condit_convnext_fixt.py
@@ -0,0 +1,339 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+class FlattenConvBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = nn.Conv2d(hidden_size, hidden_size, groups=hidden_size, kernel_size=7, stride=1, padding=3)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ attn_x = modulate(self.norm1(x), shift_msa, scale_msa)
+ attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous()
+ attn_x = self.attn(attn_x)
+ attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2)
+ x = x + gate_msa * attn_x
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenConDiT(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ num_cond_blocks=4,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ deep_supervision=0,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.deep_supervision = deep_supervision
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.num_cond_blocks = num_cond_blocks
+ self.patch_size = patch_size
+ self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+
+ self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
+
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([])
+ for i in range(self.num_cond_blocks):
+ self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
+ for i in range(self.num_blocks-self.num_cond_blocks):
+ self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups))
+
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)].to(device)
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.s_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # # Zero-out adaLN modulation layers in SiT blocks:
+ # for block in self.blocks:
+ # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+
+ def forward(self, x, t, y, s=None):
+ B, _, H, W = x.shape
+ pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+ c = nn.functional.silu(t + y)
+ if s is None:
+ s = self.s_embedder(x)
+ for i in range(self.num_cond_blocks):
+ s = self.blocks[i](s, c, pos, None)
+ s = nn.functional.silu(t + s)
+
+ x = self.x_embedder(x)
+ for i in range(self.num_cond_blocks, self.num_blocks):
+ x = self.blocks[i](x, s, pos, None)
+ x = self.final_layer(x, s)
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
+ return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_condit_dit_fixt.py b/src/models/denoiser/flatten_condit_dit_fixt.py
new file mode 100644
index 0000000..15557f3
--- /dev/null
+++ b/src/models/denoiser/flatten_condit_dit_fixt.py
@@ -0,0 +1,313 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenConDiT(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ num_cond_blocks=4,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ deep_supervision=0,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.deep_supervision = deep_supervision
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.num_cond_blocks = num_cond_blocks
+ self.patch_size = patch_size
+ self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+
+ self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
+
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)].to(device)
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.s_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # # Zero-out adaLN modulation layers in SiT blocks:
+ # for block in self.blocks:
+ # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+
+ def forward(self, x, t, y, s=None, mask=None):
+ B, _, H, W = x.shape
+ pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+ c = nn.functional.silu(t + y)
+ if s is None:
+ s = self.s_embedder(x)
+ for i in range(self.num_cond_blocks):
+ s = self.blocks[i](s, c, pos, mask)
+ s = nn.functional.silu(t + s)
+
+ x = self.x_embedder(x)
+ for i in range(self.num_cond_blocks, self.num_blocks):
+ x = self.blocks[i](x, s, pos, None)
+ x = self.final_layer(x, s)
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
+ return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_condit_dit_norm_fixt.py b/src/models/denoiser/flatten_condit_dit_norm_fixt.py
new file mode 100644
index 0000000..28034e3
--- /dev/null
+++ b/src/models/denoiser/flatten_condit_dit_norm_fixt.py
@@ -0,0 +1,314 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenConDiT(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ num_cond_blocks=4,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ deep_supervision=0,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.deep_supervision = deep_supervision
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.num_cond_blocks = num_cond_blocks
+ self.patch_size = patch_size
+ self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+
+ self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
+
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)].to(device)
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.s_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # # Zero-out adaLN modulation layers in SiT blocks:
+ # for block in self.blocks:
+ # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+
+ def forward(self, x, t, y, s=None, mask=None):
+ B, _, H, W = x.shape
+ pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+ c = nn.functional.silu(t + y)
+ if s is None:
+ s = self.s_embedder(x)
+ for i in range(self.num_cond_blocks):
+ s = self.blocks[i](s, c, pos, mask)
+ s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6)
+ s = nn.functional.silu(t + s)
+
+ x = self.x_embedder(x)
+ for i in range(self.num_cond_blocks, self.num_blocks):
+ x = self.blocks[i](x, s, pos, None)
+ x = self.final_layer(x, s)
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
+ return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py b/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py
new file mode 100644
index 0000000..9a5e4fd
--- /dev/null
+++ b/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py
@@ -0,0 +1,429 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+from src.utils.model_loader import ModelLoader
+from src.utils.no_grad import no_grad
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenDiTEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.patch_size = patch_size
+ self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)]
+ else:
+ pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device)
+ self.precompute_pos[(height, width)] = (pos_rope, pos_ape)
+ return (pos_rope, pos_ape)
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.s_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ def forward(self, x, t, y, mask=None):
+ B, _, H, W = x.shape
+ pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+ c = nn.functional.silu(t + y)
+ s = self.s_embedder(x)
+ # s = s + pos_ape.to(s.dtype)
+ for i in range(self.num_blocks):
+ s = self.blocks[i](s, c, pos_rope, mask)
+ return None, s
+
+
+class FlattenDiTDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.patch_size = patch_size
+ self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
+
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)]
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def forward(self, x, t, y, s, mask=None):
+ B, _, H, W = x.shape
+ pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ # s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6)
+ s = torch.nn.functional.silu(t + s)
+ x = self.x_embedder(x)
+ for i in range(self.num_blocks):
+ x = self.blocks[i](x, s, pos, None)
+ x = self.final_layer(x, s)
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size,
+ stride=self.patch_size)
+ return x
+
+
+
+class FlattenDiT(nn.Module):
+ def __init__(
+ self,
+ encoder:FlattenDiTEncoder,
+ decoder:FlattenDiTDecoder,
+ joint_training=False,
+ ):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+
+ ModelLoader().load(encoder)
+ if not joint_training:
+ self.encoder = self.encoder.to(torch.bfloat16)
+ no_grad(self.encoder)
+
+ self.joint_training = joint_training
+
+ def forward(self, x, t, y, s=None):
+ if s is None:
+ _, s = self.encoder(x, t, y)
+ x = self.decoder(x, t, y, s)
+ return x, s
+
+class FlattenDiTScalingEncoder(nn.Module):
+ def __init__(
+ self,
+ encoder:FlattenDiTEncoder,
+ decoder:FlattenDiTDecoder,
+ ):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ no_grad(self.decoder)
+
+ if self.encoder.weight_path:
+ weight = torch.load(self.encoder.weight_path, map_location=torch.device('cpu'))
+ if self.encoder.load_ema:
+ prefix = "ema_denoiser."
+ else:
+ prefix = "denoiser."
+ for k, v in self.encoder.state_dict().items():
+ try:
+ v.copy_(weight["state_dict"][prefix+k])
+ except:
+ print(f"Failed to copy {prefix+k} to denoiser weight")
+
+ if self.decoder.weight_path:
+ weight = torch.load(self.decoder.weight_path, map_location=torch.device('cpu'))
+ if self.decoder.load_ema:
+ prefix = "ema_denoiser."
+ else:
+ prefix = "denoiser."
+ for k, v in self.decoder.state_dict().items():
+ if "blocks." in k:
+ blockid = int(k.split("blocks.")[-1][0])
+ k = k.replace(f"blocks.{blockid}", f"blocks.{int(blockid)+8}")
+ try:
+ v.copy_(weight["state_dict"][prefix+k])
+ except:
+ print(f"Failed to copy {prefix+k} to denoiser weight")
+ self.decoder = decoder.to(torch.bfloat16)
+
+ def forward(self, x, t, y, s=None):
+ if s is None:
+ _, s = self.encoder(x, t, y)
+ x = self.decoder(x, t, y, s)
+ return x, s
+
diff --git a/src/models/denoiser/flatten_condit_mlp_fixt.py b/src/models/denoiser/flatten_condit_mlp_fixt.py
new file mode 100644
index 0000000..40735e4
--- /dev/null
+++ b/src/models/denoiser/flatten_condit_mlp_fixt.py
@@ -0,0 +1,334 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+class FlattenMLPBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = FeedForward(hidden_size, mlp_hidden_dim)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenConDiT(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ num_cond_blocks=4,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ deep_supervision=0,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.deep_supervision = deep_supervision
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.num_cond_blocks = num_cond_blocks
+ self.patch_size = patch_size
+ self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+
+ self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
+
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([])
+ for i in range(self.num_cond_blocks):
+ self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups))
+ for i in range(self.num_blocks-self.num_cond_blocks):
+ self.blocks.append(FlattenMLPBlock(self.hidden_size, self.num_groups))
+
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)].to(device)
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.s_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # # Zero-out adaLN modulation layers in SiT blocks:
+ # for block in self.blocks:
+ # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+
+ def forward(self, x, t, y, s=None):
+ B, _, H, W = x.shape
+ pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+ c = nn.functional.silu(t + y)
+ if s is None:
+ s = self.s_embedder(x)
+ for i in range(self.num_cond_blocks):
+ s = self.blocks[i](s, c, pos, None)
+ s = nn.functional.silu(t + s)
+
+ x = self.x_embedder(x)
+ for i in range(self.num_cond_blocks, self.num_blocks):
+ x = self.blocks[i](x, s, pos, None)
+ x = self.final_layer(x, s)
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
+ return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py b/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py
new file mode 100644
index 0000000..bcf3315
--- /dev/null
+++ b/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py
@@ -0,0 +1,321 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenConDiT(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ num_cond_blocks=4,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ deep_supervision=0,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.deep_supervision = deep_supervision
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.num_cond_blocks = num_cond_blocks
+ self.patch_size = patch_size
+ self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.s_embedder = Embed(in_channels*patch_size**4, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+
+ self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
+
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)].to(device)
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.s_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.s_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # # Zero-out adaLN modulation layers in SiT blocks:
+ # for block in self.blocks:
+ # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+
+ def forward(self, x, t, y, s=None, mask=None):
+ B, _, H, W = x.shape
+ pos_x = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+ c = nn.functional.silu(t + y)
+ if s is None:
+ pos_s = self.fetch_pos(H//self.patch_size//2, W//self.patch_size//2, x.device)
+ s = torch.nn.functional.unfold(x, kernel_size=self.patch_size*2, stride=self.patch_size*2).transpose(1, 2)
+ s = self.s_embedder(s)
+ for i in range(self.num_cond_blocks):
+ s = self.blocks[i](s, c, pos_s, mask)
+ s = s.view(B, H//self.patch_size//2, W//self.patch_size//2, self.hidden_size)
+ s = torch.permute(s, (0, 3, 1, 2))
+ s = torch.nn.functional.interpolate(s, scale_factor=2, mode='bilinear', align_corners=False)
+ s = torch.permute(s, (0, 2, 3, 1))
+ s = s.view(B, -1, self.hidden_size)
+ s = nn.functional.silu(t + s)
+
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ x = self.x_embedder(x)
+ for i in range(self.num_cond_blocks, self.num_blocks):
+ x = self.blocks[i](x, s, pos_x, None)
+ x = self.final_layer(x, s)
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
+ return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_dit_fixt.py b/src/models/denoiser/flatten_dit_fixt.py
new file mode 100644
index 0000000..9412d6e
--- /dev/null
+++ b/src/models/denoiser/flatten_dit_fixt.py
@@ -0,0 +1,306 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return (self.weight * hidden_states).to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenDiT(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ deep_supervision=0,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.deep_supervision = deep_supervision
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.patch_size = patch_size
+ self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+
+ self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
+
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device, dtype):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)].to(device, dtype)
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # # Zero-out adaLN modulation layers in SiT blocks:
+ # for block in self.blocks:
+ # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def forward(self, x, t, y, masks=None):
+ if masks is None:
+ masks = [None, ]*self.num_blocks
+ if isinstance(masks, torch.Tensor):
+ masks = masks.unbind(0)
+ if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks:
+ masks = masks + [None]*(self.num_blocks-len(masks))
+
+ B, _, H, W = x.shape
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ x = self.x_embedder(x)
+ pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
+ B, L, C = x.shape
+ t = self.t_embedder(t.view(-1)).view(B, -1, C)
+ y = self.y_embedder(y).view(B, 1, C)
+ condition = nn.functional.silu(t + y)
+ for i, block in enumerate(self.blocks):
+ x = block(x, condition, pos, masks[i])
+ x = self.final_layer(x, condition)
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
+ return x
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_dit_fixt_xvout.py b/src/models/denoiser/flatten_dit_fixt_xvout.py
new file mode 100644
index 0000000..4df3393
--- /dev/null
+++ b/src/models/denoiser/flatten_dit_fixt_xvout.py
@@ -0,0 +1,311 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return (self.weight * hidden_states).to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenDiT(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ deep_supervision=0,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.deep_supervision = deep_supervision
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.patch_size = patch_size
+ self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+
+ self.final_layer = FinalLayer(hidden_size, 2*in_channels*patch_size**2)
+
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device, dtype):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)].to(device, dtype)
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # # Zero-out adaLN modulation layers in SiT blocks:
+ # for block in self.blocks:
+ # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def forward(self, x, t, y, masks=None):
+ if masks is None:
+ masks = [None, ]*self.num_blocks
+ if isinstance(masks, torch.Tensor):
+ masks = masks.unbind(0)
+ if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks:
+ masks = masks + [None]*(self.num_blocks-len(masks))
+
+ B, _, H, W = x.shape
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ x = self.x_embedder(x)
+ pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype)
+ B, L, C = x.shape
+ t = self.t_embedder(t.view(-1)).view(B, -1, C)
+ y = self.y_embedder(y).view(B, 1, C)
+ condition = nn.functional.silu(t + y)
+ for i, block in enumerate(self.blocks):
+ x = block(x, condition, pos, masks[i])
+ x = self.final_layer(x, condition)
+ x0, v = x.chunk(2, dim=-1)
+ x0 = torch.nn.functional.fold(x0.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
+ v = torch.nn.functional.fold(v.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
+ if self.training:
+ return v, x0
+ else:
+ return v
\ No newline at end of file
diff --git a/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py b/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py
new file mode 100644
index 0000000..4e570b0
--- /dev/null
+++ b/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py
@@ -0,0 +1,308 @@
+import functools
+from typing import Tuple
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from torch.nn.modules.module import T
+
+# from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+from torch.nn.functional import scaled_dot_product_attention
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+class Embed(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class TimestepEmbedder(nn.Module):
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class LabelEmbedder(nn.Module):
+ def __init__(self, num_classes, hidden_size):
+ super().__init__()
+ self.embedding_table = nn.Embedding(num_classes, hidden_size)
+ self.num_classes = num_classes
+
+ def forward(self, labels,):
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 2*hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ return x
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ xq: torch.Tensor,
+ xk: torch.Tensor,
+ freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ freqs_cis = freqs_cis[None, :, None, :]
+ # xq : B N H Hc
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+ return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class RAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ norm_layer: nn.Module = RMSNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k = apply_rotary_emb(q, k, freqs_cis=pos)
+ q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc
+ k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc
+ v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous()
+
+ x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+
+class FlattenDiTBlock(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, ):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c, pos, mask=None):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+class FlattenConDiT(nn.Module):
+ def __init__(
+ self,
+ in_channels=4,
+ num_groups=12,
+ hidden_size=1152,
+ num_blocks=18,
+ num_cond_blocks=4,
+ patch_size=2,
+ num_classes=1000,
+ learn_sigma=True,
+ deep_supervision=0,
+ weight_path=None,
+ load_ema=False,
+ ):
+ super().__init__()
+ self.deep_supervision = deep_supervision
+ self.learn_sigma = learn_sigma
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.num_blocks = num_blocks
+ self.num_cond_blocks = num_cond_blocks
+ self.patch_size = patch_size
+ self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True)
+ self.t_embedder = TimestepEmbedder(hidden_size)
+ self.y_embedder = LabelEmbedder(num_classes+1, hidden_size)
+
+ self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2)
+
+ self.weight_path = weight_path
+
+ self.load_ema = load_ema
+ self.blocks = nn.ModuleList([
+ FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks)
+ ])
+ self.initialize_weights()
+ self.precompute_pos = dict()
+
+ def fetch_pos(self, height, width, device):
+ if (height, width) in self.precompute_pos:
+ return self.precompute_pos[(height, width)].to(device)
+ else:
+ pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device)
+ self.precompute_pos[(height, width)] = pos
+ return pos
+
+ def initialize_weights(self):
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+ w = self.x_embedder.proj.weight.data
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+
+ # Initialize label embedding table:
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+ # Initialize timestep embedding MLP:
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+ # # Zero-out adaLN modulation layers in SiT blocks:
+ # for block in self.blocks:
+ # nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ # nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers:
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+
+ def forward(self, x, t, y, s=None, mask=None):
+ B, _, H, W = x.shape
+ pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
+ x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+ t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
+ y = self.y_embedder(y).view(B, 1, self.hidden_size)
+ c = nn.functional.silu(t + y)
+ if s is None:
+ s = self.x_embedder(x)
+ for i in range(self.num_cond_blocks):
+ s = self.blocks[i](s, c, pos, mask)
+ s = nn.functional.silu(t + s)
+
+ x = self.x_embedder(x)
+ for i in range(self.num_cond_blocks, self.num_blocks):
+ x = self.blocks[i](x, s, pos, None)
+ x = self.final_layer(x, s)
+ x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
+ return x, s
\ No newline at end of file
diff --git a/src/models/denoiser/flowdcn.py b/src/models/denoiser/flowdcn.py
new file mode 100644
index 0000000..92e2237
--- /dev/null
+++ b/src/models/denoiser/flowdcn.py
@@ -0,0 +1,160 @@
+import torch
+import torch.nn as nn
+import math
+
+from torch.nn.init import zeros_
+from src.models.denoiser.base_model import BaseModel
+from src.ops.triton_kernels.function import DCNFunction
+
+def modulate(x, shift, scale):
+ return x * (1 + scale[:, None, None]) + shift[:, None, None]
+
+class PatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer = None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.patch_size = patch_size
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+ def forward(self, x):
+ x = self.proj(x)
+ x = self.norm(x)
+ return x
+
+class RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+ def forward(self, x):
+ b, h, w, c = x.shape
+ x = x.view(b, h*w, c)
+ x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+ x = x.view(b, h, w, c)
+ return x
+
+
+class MultiScaleDCN(nn.Module):
+ def __init__(self, in_channels, groups, channels, kernels, deformable_biass=True):
+ super().__init__()
+ self.in_channels = in_channels
+ self.groups = groups
+ self.channels = channels
+ self.kernels = kernels
+ self.v = nn.Linear(in_channels, groups * channels, bias=True)
+ self.qk_deformables = nn.Linear(in_channels, groups * kernels * 2, bias=True)
+ self.qk_scales = nn.Linear(in_channels, groups * kernels, bias=False)
+ self.qk_weights = nn.Linear(in_channels, groups*kernels, bias=True)
+ self.out = nn.Linear(groups * channels, in_channels)
+ self.deformables_prior = nn.Parameter(torch.randn((1, 1, 1, 1, kernels, 2)), requires_grad=False)
+ self.deformables_scale = nn.Parameter(torch.ones((1, 1, 1, groups, 1, 1)), requires_grad=True)
+ self.max_scale = 6
+ self._init_weights()
+ def _init_weights(self):
+ zeros_(self.qk_deformables.weight.data)
+ zeros_(self.qk_scales.weight.data)
+ zeros_(self.qk_deformables.bias.data)
+ zeros_(self.qk_weights.weight.data)
+ zeros_(self.v.bias.data)
+ zeros_(self.out.bias.data)
+ num_prior = int(self.kernels ** 0.5)
+ dx = torch.linspace(-1, 1, num_prior, device="cuda")
+ dy = torch.linspace(-1, 1, num_prior, device="cuda")
+ dxy = torch.meshgrid([dx, dy], indexing="xy")
+ dxy = torch.stack(dxy, dim=-1)
+ dxy = dxy.view(-1, 2)
+ self.deformables_prior.data[..., :num_prior*num_prior, :] = dxy
+ for i in range(self.groups):
+ scale = (i+1)/self.groups - 0.0001
+ inv_scale = math.log((scale)/(1-scale))
+ self.deformables_scale.data[..., i, :, :] = inv_scale
+ def forward(self, x):
+ B, H, W, _ = x.shape
+ v = self.v(x).view(B, H, W, self.groups, self.channels)
+ deformables = self.qk_deformables(x).view(B, H, W, self.groups, self.kernels, 2)
+ scale = self.qk_scales(x).view(B, H, W, self.groups, self.kernels, 1) + self.deformables_scale
+ deformables = (deformables + self.deformables_prior ) * scale.sigmoid()*self.max_scale
+ weights = self.qk_weights(x).view(B, H, W, self.groups, self.kernels)
+ out = DCNFunction.apply(v, deformables, weights)
+ out = out.view(B, H, W, -1)
+ out = self.out(out)
+ return out
+
+class FlowDCNBlock(nn.Module):
+ def __init__(self, hidden_size, groups, kernels=9, mlp_ratio=4.0, deformable_biass=True):
+ super().__init__()
+ self.norm1 = RMSNorm(hidden_size, eps=1e-6)
+ self.attn = MultiScaleDCN(hidden_size, groups=groups, channels=hidden_size//groups, kernels=kernels, deformable_biass=deformable_biass)
+ self.norm2 = RMSNorm(hidden_size, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
+ x = x + gate_msa[:, None, None] * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+ x = x + gate_mlp[:, None, None] * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+ return x
+
+
+
+
+class FlowDCN(BaseModel):
+ def __init__(self, deformable_biass=True, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.blocks = nn.ModuleList([
+ FlowDCNBlock(self.hidden_size, self.num_groups, kernels=9, deformable_biass=deformable_biass) for _ in range(self.num_blocks)
+ ])
+ self.x_embedder = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, bias=True)
+ self.initialize_weights()
+
+ def forward(self, x, t, y):
+ batch_size, _, height, width = x.shape[0]
+ x = self.x_embedder(x) # (N, D, h, w)
+ x = x.permute(0, 2, 3, 1).reshape(batch_size, height*width//self.patch_size**2, -1)
+ t = self.t_embedder(t) # (N, D)
+ y = self.y_embedder(y, self.training) # (N, D)
+ c = t + y # (N, D)
+ B, L, C = x.shape
+ x = x.view(B, height//self.patch_size, width//self.patch_size, C)
+ for block in self.blocks:
+ x = block(x, c) # (N, T, D)
+ x = x.view(B, L, C)
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
+ x = torch.nn.functional.fold(x.transpose(1, 2), (height, width), kernel_size=self.patch_size, stride=self.patch_size)
+ if self.learn_sigma:
+ x, _ = torch.split(x, self.out_channels // 2, dim=1)
+ return x
\ No newline at end of file
diff --git a/src/models/encoder.py b/src/models/encoder.py
new file mode 100644
index 0000000..8b7f96a
--- /dev/null
+++ b/src/models/encoder.py
@@ -0,0 +1,132 @@
+import torch
+import copy
+import os
+import timm
+import transformers
+import torch.nn as nn
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+from torchvision.transforms import Normalize
+
+class RandViT(nn.Module):
+ def __init__(self, model_id, weight_path:str=None):
+ super(RandViT, self).__init__()
+ self.encoder = timm.create_model(
+ model_id,
+ num_classes=0,
+ )
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.shifts = nn.Parameter(torch.tensor([0.0
+ ]), requires_grad=False)
+ self.scales = nn.Parameter(torch.tensor([1.0
+ ]), requires_grad=False)
+
+ def forward(self, x):
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
+ feature = feature.transpose(1, 2)
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
+ feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
+ return feature
+
+class MAE(nn.Module):
+ def __init__(self, model_id, weight_path:str):
+ super(MAE, self).__init__()
+ if os.path.isdir(weight_path):
+ weight_path = os.path.join(weight_path, "pytorch_model.bin")
+ self.encoder = timm.create_model(
+ model_id,
+ checkpoint_path=weight_path,
+ num_classes=0,
+ )
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.shifts = nn.Parameter(torch.tensor([0.0
+ ]), requires_grad=False)
+ self.scales = nn.Parameter(torch.tensor([1.0
+ ]), requires_grad=False)
+
+ def forward(self, x):
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
+ feature = feature.transpose(1, 2)
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
+ feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
+ return feature
+
+class DINO(nn.Module):
+ def __init__(self, model_id, weight_path:str):
+ super(DINO, self).__init__()
+ if os.path.isdir(weight_path):
+ weight_path = os.path.join(weight_path, "pytorch_model.bin")
+ self.encoder = timm.create_model(
+ model_id,
+ checkpoint_path=weight_path,
+ num_classes=0,
+ )
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.shifts = nn.Parameter(torch.tensor([ 0.0,
+ ]), requires_grad=False)
+ self.scales = nn.Parameter(torch.tensor([ 1.0,
+ ]), requires_grad=False)
+
+ def forward(self, x):
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
+ feature = feature.transpose(1, 2)
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
+ feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
+ return feature
+
+class CLIP(nn.Module):
+ def __init__(self, model_id, weight_path:str):
+ super(CLIP, self).__init__()
+ self.encoder = transformers.CLIPVisionModel.from_pretrained(weight_path)
+ self.patch_size = self.encoder.vision_model.embeddings.patch_embedding.kernel_size
+ self.shifts = nn.Parameter(torch.tensor([0.0,
+ ]), requires_grad=False)
+ self.scales = nn.Parameter(torch.tensor([1.0,
+ ]), requires_grad=False)
+
+ def forward(self, x):
+ x = Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)(x)
+ x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ feature = self.encoder(x)['last_hidden_state'][:, 1:]
+ feature = feature.transpose(1, 2)
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
+ feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
+ return feature
+
+
+
+class DINOv2(nn.Module):
+ def __init__(self, model_id, weight_path:str):
+ super(DINOv2, self).__init__()
+ self.encoder = transformers.Dinov2Model.from_pretrained(weight_path)
+ self.patch_size = self.encoder.embeddings.patch_embeddings.projection.kernel_size
+
+ def forward(self, x):
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ feature = self.encoder.forward(x)['last_hidden_state'][:, 1:]
+ feature = feature.transpose(1, 2)
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
+ return feature
\ No newline at end of file
diff --git a/src/models/vae.py b/src/models/vae.py
new file mode 100644
index 0000000..c47b087
--- /dev/null
+++ b/src/models/vae.py
@@ -0,0 +1,81 @@
+import torch
+import subprocess
+import lightning.pytorch as pl
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+def class_fn_from_str(class_str):
+ class_module, from_class = class_str.rsplit(".", 1)
+ class_module = __import__(class_module, fromlist=[from_class])
+ return getattr(class_module, from_class)
+
+
+class BaseVAE(torch.nn.Module):
+ def __init__(self, scale=1.0, shift=0.0):
+ super().__init__()
+ self.model = torch.nn.Identity()
+ self.scale = scale
+ self.shift = shift
+
+ def encode(self, x):
+ return x/self.scale+self.shift
+
+ def decode(self, x):
+ return (x-self.shift)*self.scale
+
+
+# very bad bugs with nearest sampling
+class DownSampleVAE(BaseVAE):
+ def __init__(self, down_ratio, scale=1.0, shift=0.0):
+ super().__init__()
+ self.model = torch.nn.Identity()
+ self.scale = scale
+ self.shift = shift
+ self.down_ratio = down_ratio
+
+ def encode(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=1/self.down_ratio, mode='bicubic', align_corners=False)
+ return x/self.scale+self.shift
+
+ def decode(self, x):
+ x = (x-self.shift)*self.scale
+ x = torch.nn.functional.interpolate(x, scale_factor=self.down_ratio, mode='bicubic', align_corners=False)
+ return x
+
+
+
+class LatentVAE(BaseVAE):
+ def __init__(self, precompute=False, weight_path:str=None):
+ super().__init__()
+ self.precompute = precompute
+ self.model = None
+ self.weight_path = weight_path
+
+ from diffusers.models import AutoencoderKL
+ setattr(self, "model", AutoencoderKL.from_pretrained(self.weight_path))
+ self.scaling_factor = self.model.config.scaling_factor
+
+ @torch.no_grad()
+ def encode(self, x):
+ assert self.model is not None
+ if self.precompute:
+ return x.mul_(self.scaling_factor)
+ return self.model.encode(x).latent_dist.sample().mul_(self.scaling_factor)
+
+ @torch.no_grad()
+ def decode(self, x):
+ assert self.model is not None
+ return self.model.decode(x.div_(self.scaling_factor)).sample
+
+
+def uint82fp(x):
+ x = x.to(torch.float32)
+ x = (x - 127.5) / 127.5
+ return x
+
+def fp2uint8(x):
+ x = torch.clip_((x + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8)
+ return x
+
diff --git a/src/ops/cuda_kernels/backward.cu b/src/ops/cuda_kernels/backward.cu
new file mode 100644
index 0000000..2e85d86
--- /dev/null
+++ b/src/ops/cuda_kernels/backward.cu
@@ -0,0 +1,346 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+
+namespace cg = cooperative_groups;
+
+template
+__device__ __always_inline int toInt(scalar_t val);
+
+template<>
+__device__ __always_inline int toInt(float val){
+ return static_cast(val);
+}
+template<>
+__device__ __always_inline int toInt(half val){
+ return __half2int_rz(val);
+}
+
+template
+__device__ __always_inline scalar_t fromInt(int val);
+
+template<>
+__device__ __always_inline float fromInt(int val){
+ return static_cast(val);
+}
+
+template<>
+__device__ __always_inline half fromInt(int val){
+ return __int2half_rz(val);
+}
+
+template
+__device__ __always_inline scalar_t constVal(float val);
+
+template<>
+__device__ __always_inline float constVal(float val) {
+ return (float)val;
+}
+
+template<>
+__device__ __always_inline half constVal(float val) {
+ return __float2half(val); // Using float to half conversion
+}
+template<>
+__device__ __always_inline nv_bfloat16 constVal(float val){
+ return __float2bfloat16(val);
+}
+
+
+
+
+
+// B, H, W, C, BLOCK_DIM must be multiple of C
+template
+__global__ void dcn_backward_pipeline_kernel(
+ const int H,
+ const int W,
+ const int G,
+ const int K,
+ const int C,
+ scalar_t* ptr_values,
+ scalar_t* ptr_deformables,
+ scalar_t* ptr_weights,
+ scalar_t* ptr_grad_out,
+ scalar_t* ptr_grad_values,
+ scalar_t* ptr_grad_deformables,
+ scalar_t* ptr_grad_weights
+) {
+ auto block = cg::this_thread_block();
+ auto self_thread = cg::this_thread();
+ auto tile_threads = cg::tiled_partition(block);
+ int local_thread_id = block.thread_rank();
+ int local_tile_id = tile_threads.meta_group_rank();
+ int num_local_tiles = tile_threads.meta_group_size();
+ int global_tile_id = block.group_index().x*num_local_tiles + local_tile_id;
+
+ extern __shared__ int shm[];
+ auto GradBuffer = reinterpret_cast(shm);
+ scalar_t* Buffer = reinterpret_cast(shm) + num_local_tiles*C;
+ if(global_tile_id >= H*W*G) return;
+
+ int bid = block.group_index().y;
+ int gid = global_tile_id % G;
+ int wid = global_tile_id / G % W;
+ int hid = global_tile_id / G / W;
+ int globale_offset = bid*H*W*G*C + global_tile_id*C;
+ cg::memcpy_async(tile_threads, GradBuffer+local_tile_id*C, ptr_grad_out+globale_offset, sizeof(scalar_t)*C);
+
+ int shared_offset[pipeline_stages];
+ for (int s = 0; s < pipeline_stages; ++s) {
+ shared_offset[s] = (s+pipeline_stages*local_thread_id)*(TILE_C*4);
+ }
+
+ auto pipeline = cuda::make_pipeline();
+ const int num_tiles_per_thread = C/TILE_C/TILE_THREADS;
+
+ for(int k=0; k(wid);
+ y = ptr_deformables[offset*2 + 1] + fromInt(hid);
+// x = fromInt(wid);
+// y = fromInt(hid);
+ weight = ptr_weights[offset];
+ }
+ tile_threads.sync();
+ x = tile_threads.shfl(x, 0);
+ y = tile_threads.shfl(y, 0);
+ weight = tile_threads.shfl(weight, 0);
+
+ int floor_x = toInt(x);
+ int floor_y = toInt(y);
+ int ceil_x = floor_x + 1;
+ int ceil_y = floor_y + 1;
+
+
+ scalar_t dodx = constVal(0.0f);
+ scalar_t dody = constVal(0.0f);
+ scalar_t dodw = constVal(0.0f);
+
+ int start_c = tile_threads.thread_rank() * (C / TILE_THREADS);
+
+ bool tl_flag = (floor_x >=0) and (floor_x =0) and (floor_y=0) and (ceil_x =0) and (floor_y=0) and (floor_x =0) and (ceil_y=0) and (ceil_x =0) and (ceil_y(ceil_x) - x) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
+ dodx = dodx + -weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
+ dody = dody + -weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j];
+ dodw = dodw + (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1];
+ dodx = dodx + -weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j+ 1] * GradBuffer[gbuffer_offset + j + 1];
+ dody = dody + -weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1];
+ {
+ vec2_t vtl_di;
+ vtl_di.x = weight* (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j];
+ vtl_di.y = weight* (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j + 1];
+ atomicAdd((vec2_t*)(ptr_grad_values + tl_global_base + compute_n * TILE_C + j), vtl_di);
+ }
+ }
+
+
+ if(tr_flag){
+ // tr
+ dodw = dodw + (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
+ dodx = dodx + weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
+ dody = dody + -weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j];
+ dodw = dodw + (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j+1] * GradBuffer[gbuffer_offset + j+1];
+ dodx = dodx + weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+ 1];
+ dody = dody + -weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+1];
+ {
+ vec2_t vtr_di;
+ vtr_di.x = weight* (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j];
+ vtr_di.y = weight* (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j+1];
+ atomicAdd((vec2_t*)(ptr_grad_values + tr_global_base + compute_n * TILE_C + j), vtr_di);
+ }
+ }
+
+ if(bl_flag){
+ // bl
+ dodw = dodw + (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
+ dodx = dodx + -weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
+ dody = dody + weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j];
+ dodw = dodw + (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
+ dodx = dodx + -weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
+ dody = dody + weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1];
+ {
+ vec2_t vbl_di;
+ vbl_di.x = weight* (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j];
+ vbl_di.y = weight* (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j+1];
+ atomicAdd((vec2_t*)(ptr_grad_values + bl_global_base + compute_n * TILE_C + j), vbl_di);
+ }
+ }
+
+
+ if(br_flag){
+ // tr
+ dodw = dodw + (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
+ dodx = dodx + weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
+ dody = dody + weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j];
+ dodw = dodw + (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
+ dodx = dodx + weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
+ dody = dody + weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1];
+ {
+ vec2_t vbr_di;
+ vbr_di.x = weight* (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j];
+ vbr_di.y = weight* (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j+1];
+ atomicAdd((vec2_t*)(ptr_grad_values + br_global_base + compute_n * TILE_C + j), vbr_di);
+ }
+ }
+ }
+ pipeline.consumer_release();
+ }
+ for (int i = TILE_THREADS>>1; i > 0; i/=2) {
+ dodx = dodx + tile_threads.shfl_down(dodx, i);
+ dody = dody + tile_threads.shfl_down(dody, i);
+ dodw = dodw + tile_threads.shfl_down(dodw, i);
+ }
+ if (tile_threads.thread_rank() == 0) {
+ cuda::memcpy_async(ptr_grad_deformables + offset * 2, &dodx, sizeof(scalar_t), pipeline);
+ cuda::memcpy_async(ptr_grad_deformables + offset * 2 + 1, &dody, sizeof(scalar_t), pipeline);
+ cuda::memcpy_async(ptr_grad_weights + offset, &dodw, sizeof(scalar_t), pipeline);
+ }
+ }
+}
+
+
+using namespace torch;
+template
+void backward(const int B,
+ const int H,
+ const int W,
+ const int G,
+ const int K,
+ const int C,
+ torch::Tensor values,
+ torch::Tensor deformables,
+ torch::Tensor weights,
+ torch::Tensor grad_out,
+ torch::Tensor grad_values,
+ torch::Tensor grad_deformables,
+ torch::Tensor grad_weights
+) {
+ int num_local_tiles =(THREADS/TILE_THREADS);
+ int num_global_tiles = (H*W*G+num_local_tiles-1)/num_local_tiles;
+ dim3 launch_threads_per_block(THREADS);
+ dim3 launch_blocks(num_global_tiles, B);
+
+ int deformable_shm_size = 0;
+ int grad_out_shm_size = num_local_tiles*C;
+ int pipeline_shm_size = (pipeline_stages*TILE_C*4*THREADS);
+
+ int shm_size = deformable_shm_size+grad_out_shm_size+pipeline_shm_size;
+// printf("shm_size: %d\n", shm_size/512);
+// printf("pipeline_size: %d\n", pipeline_shm_size/512);
+// printf("grad_out_size: %d\n", grad_out_shm_size/512);
+
+
+ switch (values.type().scalarType()) {
+ case at::ScalarType::Half:
+ return dcn_backward_pipeline_kernel<<>>(
+ H, W, G, K, C,
+ reinterpret_cast(values.data_ptr()),
+ reinterpret_cast(deformables.data_ptr()),
+ reinterpret_cast(weights.data_ptr()),
+ reinterpret_cast(grad_out.data_ptr()),
+ reinterpret_cast(grad_values.data_ptr()),
+ reinterpret_cast(grad_deformables.data_ptr()),
+ reinterpret_cast(grad_weights.data_ptr())
+ );
+// case at::ScalarType::BFloat16:
+// return dcn_backward_pipeline_kernel<<>>(
+// H, W, G, K, C,
+// reinterpret_cast(values.data_ptr()),
+// reinterpret_cast(deformables.data_ptr()),
+// reinterpret_cast(weights.data_ptr()),
+// reinterpret_cast(grad_out.data_ptr()),
+// reinterpret_cast(grad_values.data_ptr()),
+// reinterpret_cast(grad_deformables.data_ptr()),
+// reinterpret_cast(grad_weights.data_ptr())
+// );
+ default:
+ printf("running error");
+ }
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("backward_p1_c2_tile16_thread128", &backward<1, 2, 16, 128>, "");
+ m.def("backward_p2_c2_tile16_thread128", &backward<2, 2, 16, 128>, "");
+ m.def("backward_p1_c4_tile16_thread128", &backward<1, 4, 16, 128>, "");
+ m.def("backward_p1_c2_tile16_thread256", &backward<1, 2, 16, 256>, "");
+ m.def("backward_p2_c2_tile16_thread256", &backward<2, 2, 16, 256>, "");
+ m.def("backward_p1_c4_tile16_thread256", &backward<1, 4, 16, 256>, "");
+ m.def("backward_p1_c2_tile16_thread384", &backward<1, 2, 16, 384>, "");
+ m.def("backward_p2_c2_tile16_thread384", &backward<2, 2, 16, 384>, "");
+ m.def("backward_p1_c4_tile16_thread384", &backward<1, 4, 16, 384>, "");
+ m.def("backward_p1_c2_tile16_thread512", &backward<1, 2, 16, 512>, "");
+ m.def("backward_p2_c2_tile16_thread512", &backward<2, 2, 16, 512>, "");
+ m.def("backward_p1_c4_tile16_thread512", &backward<1, 4, 16, 512>, "");
+ m.def("backward_p1_c2_tile16_thread768", &backward<1, 2, 16, 768>, "");
+ m.def("backward_p2_c2_tile16_thread768", &backward<2, 2, 16, 768>, "");
+ m.def("backward_p1_c4_tile16_thread768", &backward<1, 4, 16, 768>, "");
+// m.def("backward_p1_c2_tile16_thread1024", &backward<1, 2, 16, 1024>, "");
+// m.def("backward_p2_c2_tile16_thread1024", &backward<2, 2, 16, 1024>, "");
+// m.def("backward_p1_c4_tile16_thread1024", &backward<1, 4, 16, 1024>, "");
+
+ m.def("backward_p1_c2_tile32_thread128", &backward<1, 2, 32, 128>, "");
+ m.def("backward_p1_c2_tile32_thread256", &backward<1, 2, 32, 256>, "");
+ m.def("backward_p1_c2_tile32_thread384", &backward<1, 2, 32, 384>, "");
+ m.def("backward_p1_c2_tile32_thread512", &backward<1, 2, 32, 512>, "");
+}
diff --git a/src/ops/cuda_kernels/bak_forward.cu b/src/ops/cuda_kernels/bak_forward.cu
new file mode 100644
index 0000000..00569f8
--- /dev/null
+++ b/src/ops/cuda_kernels/bak_forward.cu
@@ -0,0 +1,289 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+
+template
+__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
+ #pragma unroll
+ for(int i=0; i
+__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
+ #pragma unroll
+ for(int i=0; i
+__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){
+#pragma unroll
+ for(int i=0; i
+__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){
+#pragma unroll
+ for(int i=0; i
+__global__ void dcn_forward_kernel(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
+ int work_id = threadIdx.x;
+ int bid = blockIdx.z;
+ int gid = blockIdx.y;
+ int G = gridDim.y;
+ int c_blockid = blockIdx.x;
+ int work_load = (H*W/blockDim.x);
+
+ __shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM
+ // __shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM
+ math_t register_bufferA[BLOCK_DIM] = {0};
+ int base_c = c_blockid*BLOCK_DIM;
+
+ int num_transfers = BLOCK_DIM;
+#pragma unroll
+ for(int i=0; i(register_bufferA, 1, BLOCK_DIM);
+ // loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM);
+#pragma unroll
+ for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
+ }
+ // load top right
+ math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
+ if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
+ loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
+ }
+
+ // load bottom left
+ math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
+ if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
+ loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
+ }
+ // load bottom right
+ math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
+ if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
+ loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
+ }
+
+ }
+ // loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM);
+ int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
+#pragma unroll
+ for(int j=0; j
+__global__ void dcn_forward_kernel_16(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
+ int work_id = threadIdx.x;
+ int bid = blockIdx.z;
+ int gid = blockIdx.y;
+ int G = gridDim.y;
+ int c_blockid = blockIdx.x;
+ int work_load = (H*W/blockDim.x);
+
+ __shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM
+ __shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM
+ math_t register_bufferA[BLOCK_DIM] = {0};
+ int base_c = c_blockid*BLOCK_DIM;
+
+ int num_transfers = BLOCK_DIM/transfer_length;
+#pragma unroll
+ for(int i=0; i(register_bufferA, 1, BLOCK_DIM);
+ loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM);
+#pragma unroll
+ for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
+ }
+ // load top right
+ math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
+ if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
+ loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
+ }
+
+ // load bottom left
+ math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
+ if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
+ loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
+ }
+ // load bottom right
+ math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
+ if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
+ loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
+ }
+
+ }
+ loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM);
+
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for(int i=0; i
+void dcn_forward(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
+
+ int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
+ dim3 launch_threads_per_block(THREADS);
+ dim3 launch_blocks(NUM_C_BLOCK, G, B);
+
+ switch (value.type().scalarType()) {
+ case at::ScalarType::Half:
+ return dcn_forward_kernel_16<<>>(
+ H, W, C,
+ value.data_ptr(),
+ deformables.data_ptr(),
+ weights.data_ptr(),
+ out.data_ptr());
+ case at::ScalarType::BFloat16:
+ return dcn_forward_kernel_16<<>>(
+ H, W, C,
+ value.data_ptr(),
+ deformables.data_ptr(),
+ weights.data_ptr(),
+ out.data_ptr());
+ case at::ScalarType::Float:
+ return dcn_forward_kernel<<>>(
+ H, W, C,
+ value.data_ptr(),
+ deformables.data_ptr(),
+ weights.data_ptr(),
+ out.data_ptr());
+ default:
+ printf("running error");
+ }
+}
+
+
+// PyBind11 bindings
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward");
+//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward");
+m.def("dcn_forward_l256_c4", &dcn_forward<256, 4, 256>, "CUDA dcn forward");
+m.def("dcn_forward_l256_c8", &dcn_forward<256, 8, 256>, "CUDA dcn forward");
+m.def("dcn_forward_l256_c16", &dcn_forward<256, 16, 256>, "CUDA dcn forward");
+// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward");
+m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward");
+m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward");
+m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward");
+// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward");
+}
diff --git a/src/ops/cuda_kernels/forward.cu b/src/ops/cuda_kernels/forward.cu
new file mode 100644
index 0000000..ac18308
--- /dev/null
+++ b/src/ops/cuda_kernels/forward.cu
@@ -0,0 +1,309 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+
+template
+__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
+ #pragma unroll
+ for(int i=0; i
+__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){
+ #pragma unroll
+ for(int i=0; i
+__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){
+#pragma unroll
+ for(int i=0; i
+__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){
+#pragma unroll
+ for(int i=0; i
+__global__ void dcn_forward_kernel_register(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
+ int work_id = threadIdx.x;
+ int bid = blockIdx.z;
+ int gid = blockIdx.y;
+ int G = gridDim.y;
+ int c_blockid = blockIdx.x;
+ int work_load = (H*W/blockDim.x);
+
+ extern __shared__ int shm[];
+ math_t* math_buffer = reinterpret_cast(shm);
+
+ math_t register_bufferA[BLOCK_DIM] = {0};
+ int base_c = c_blockid*BLOCK_DIM;
+
+#pragma unroll
+ for(int i=0; i(register_bufferA, 1, BLOCK_DIM);
+#pragma unroll
+ for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
+ }
+ // load top right
+ math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
+ if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
+ loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
+ }
+
+ // load bottom left
+ math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
+ if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
+ loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
+ }
+ // load bottom right
+ math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
+ if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
+ loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
+ }
+
+ }
+ int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c);
+#pragma unroll
+ for(int j=0; j
+__global__ void dcn_forward_kernel_pipeline(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){
+ int work_id = threadIdx.x;
+ int bid = blockIdx.z;
+ int gid = blockIdx.y;
+ int G = gridDim.y;
+ int c_blockid = blockIdx.x;
+ int work_load = (H*W/blockDim.x);
+
+ extern __shared__ int shm[];
+ math_t* math_buffer = reinterpret_cast(shm);
+ scalar_t* io_buffer = reinterpret_cast(shm) + H*W*BLOCK_DIM*sizeof(math_t)/sizeof(scalar_t);
+ math_t register_bufferA[BLOCK_DIM] = {0};
+ int base_c = c_blockid*BLOCK_DIM;
+
+ int num_transfers = BLOCK_DIM/transfer_length;
+#pragma unroll
+ for(int i=0; i(register_bufferA, 1, BLOCK_DIM);
+ loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM);
+#pragma unroll
+ for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM);
+ }
+ // load top right
+ math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight;
+ if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){
+ loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM);
+ }
+
+ // load bottom left
+ math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight;
+ if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){
+ loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM);
+ }
+ // load bottom right
+ math_t br_weight = (x - floor_x)*(y - floor_y)*weight;
+ if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){
+ loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM);
+ }
+
+ }
+ loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM);
+
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for(int i=0; i
+void dcn_forward(const int B, const int G, const int C, const int H, const int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
+
+ int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
+ dim3 launch_threads_per_block(THREADS);
+ dim3 launch_blocks(NUM_C_BLOCK, G, B);
+ int shm_size = H*W*C_BLOCK_DIM*sizeof(at::Half);
+ switch (value.type().scalarType()) {
+ case at::ScalarType::Half:
+ return dcn_forward_kernel_register<<>>(
+ H, W, C,
+ value.data_ptr(),
+ deformables.data_ptr(),
+ weights.data_ptr(),
+ out.data_ptr());
+ case at::ScalarType::Float:
+ return dcn_forward_kernel_register<<>>(
+ H, W, C,
+ value.data_ptr(),
+ deformables.data_ptr(),
+ weights.data_ptr(),
+ out.data_ptr());
+ default:
+ printf("running error");
+ }
+}
+
+template
+void dcn_forward_pipeline(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) {
+
+ int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM;
+ dim3 launch_threads_per_block(THREADS);
+ dim3 launch_blocks(NUM_C_BLOCK, G, B);
+ int shm_size = 2*H*W*C_BLOCK_DIM*sizeof(at::Half);
+ switch (value.type().scalarType()) {
+ case at::ScalarType::Half:
+ return dcn_forward_kernel_pipeline<<>>(
+ H, W, C,
+ value.data_ptr(),
+ deformables.data_ptr(),
+ weights.data_ptr(),
+ out.data_ptr());
+ case at::ScalarType::BFloat16:
+ return dcn_forward_kernel_pipeline<<>>(
+ H, W, C,
+ value.data_ptr(),
+ deformables.data_ptr(),
+ weights.data_ptr(),
+ out.data_ptr());
+ default:
+ printf("running error");
+ }
+}
+
+// PyBind11 bindings
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward");
+//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward");
+m.def("dcn_forward_l256_c4", &dcn_forward<4, 256>, "CUDA dcn forward");
+m.def("dcn_forward_l256_c8", &dcn_forward<8, 256>, "CUDA dcn forward");
+m.def("dcn_forward_l256_c16", &dcn_forward<16, 256>, "CUDA dcn forward");
+m.def("dcn_forward_pipeline_l256_c4", &dcn_forward_pipeline<4, 256>, "CUDA dcn forward");
+m.def("dcn_forward_pipeline_l256_c8", &dcn_forward_pipeline<8, 256>, "CUDA dcn forward");
+m.def("dcn_forward_pipeline_l256_c16", &dcn_forward_pipeline<16, 256>, "CUDA dcn forward");
+// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward");
+// m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward");
+// m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward");
+// m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward");
+// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward");
+}
diff --git a/src/ops/cuda_kernels/forward.py b/src/ops/cuda_kernels/forward.py
new file mode 100644
index 0000000..4ea9c5e
--- /dev/null
+++ b/src/ops/cuda_kernels/forward.py
@@ -0,0 +1,95 @@
+import triton
+import triton.language as tl
+
+@triton.autotune(
+ configs=[
+ # triton.Config({'BLOCK_SIZE': 32,}, num_stages=1, num_warps=1),
+ triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
+ triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1),
+ ],
+ key=['B', 'H', 'W', 'G', 'C', 'K'],
+)
+@triton.jit
+def forward_kernel(
+ B: tl.constexpr,
+ H: tl.constexpr, # image_size_h
+ W: tl.constexpr, # image_size_w
+ G: tl.constexpr, # num_channels_per_group
+ C: tl.constexpr, # num_groups
+ K: tl.constexpr, # kernel size
+ input_ptr, # input features [B, H, W, G, C]
+ deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K]
+ weights_ptr, # weights [B, H, W, G, K]
+ out_ptr, # out [B, H, W, G, C]
+ BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
+):
+ pid = tl.program_id(0)
+ wid = pid % W
+ hid = pid // W % H
+ gid = pid // (W * H) % G
+ bid = pid // (W * H * G)
+
+ id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
+ common_offset = bid*H*W*G + hid*W*G + wid*G + gid
+ batch_base = bid * H * W * G * C
+
+ for block_base in tl.static_range(0, C, BLOCK_SIZE):
+ buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
+ block_offset = tl.arange(0, BLOCK_SIZE) + block_base
+ block_mask = (block_offset < C) & id_mask
+ for k in tl.static_range(K):
+ deformable_offset = (common_offset * K + k) * 2
+
+ x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
+ y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
+
+ floor_x = x.to(tl.int32)
+ floor_y = y.to(tl.int32)
+ ceil_x = floor_x + 1
+ ceil_y = floor_y + 1
+
+ # load top left
+ tl_weight = (ceil_x - x) * (ceil_y - y)
+ tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
+ tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
+
+ # load top right
+ tr_weight = (x - floor_x) * (ceil_y - y)
+ tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
+ tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
+ # load bottom left
+ bl_weight = (ceil_x - x) * (y - floor_y)
+ bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
+ bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
+ # load bottom right
+ br_weight = (x - floor_x) * (y - floor_y)
+ br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
+ br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
+
+ # load dynamic weight and mask
+ weights_offset = common_offset*K + k
+ weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
+
+
+
+ tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
+ tl_block_input = tl_block_input * tl_weight
+
+ # load top right
+ tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
+ tr_block_input = tr_block_input * tr_weight
+ # load bottom left
+ bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
+ bl_block_input = bl_block_input * bl_weight
+ # load bottom right
+ br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
+ br_block_input = br_block_input * br_weight
+
+ # sampled
+ sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input
+
+ weighted_sampled_input = sampled_input * weight
+ buffer = buffer + weighted_sampled_input
+ # store to out_ptr
+ tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask)
+
diff --git a/src/ops/cuda_kernels/function.py b/src/ops/cuda_kernels/function.py
new file mode 100644
index 0000000..9d4bfad
--- /dev/null
+++ b/src/ops/cuda_kernels/function.py
@@ -0,0 +1,126 @@
+import time
+import dcn_cuda_backward
+import dcn_cuda_forward
+
+import math
+import torch
+from typing import Any
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.cuda.amp import custom_fwd, custom_bwd
+from .forward import forward_kernel
+
+
+class DCNFunction(Function):
+ BP_FUNCS = [
+ dcn_cuda_backward.backward_p1_c2_tile16_thread128,
+ dcn_cuda_backward.backward_p1_c4_tile16_thread128,
+ dcn_cuda_backward.backward_p2_c2_tile16_thread128,
+ dcn_cuda_backward.backward_p1_c2_tile16_thread256,
+ dcn_cuda_backward.backward_p1_c4_tile16_thread256,
+ dcn_cuda_backward.backward_p2_c2_tile16_thread256,
+ dcn_cuda_backward.backward_p1_c2_tile16_thread384,
+ dcn_cuda_backward.backward_p1_c4_tile16_thread384,
+ dcn_cuda_backward.backward_p2_c2_tile16_thread384,
+ dcn_cuda_backward.backward_p1_c2_tile16_thread512,
+ dcn_cuda_backward.backward_p1_c4_tile16_thread512,
+ dcn_cuda_backward.backward_p2_c2_tile16_thread512,
+ dcn_cuda_backward.backward_p1_c2_tile16_thread768,
+ dcn_cuda_backward.backward_p1_c4_tile16_thread768,
+ dcn_cuda_backward.backward_p2_c2_tile16_thread768,
+ dcn_cuda_backward.backward_p1_c2_tile32_thread128,
+ dcn_cuda_backward.backward_p1_c2_tile32_thread256,
+ dcn_cuda_backward.backward_p1_c2_tile32_thread384,
+ dcn_cuda_backward.backward_p1_c2_tile32_thread512,
+ ]
+ FW_FUNCS = [
+ dcn_cuda_forward.dcn_forward_l256_c4,
+ dcn_cuda_forward.dcn_forward_l256_c8,
+ dcn_cuda_forward.dcn_forward_l256_c16,
+ ]
+ BP_TABLES = dict()
+ FW_TABLES = dict()
+
+
+ @staticmethod
+ @custom_fwd(cast_inputs=torch.float16)
+ def forward(ctx: Any, values, deformables, weights) -> Any:
+ B, H, W, G, C = values.shape
+ func = DCNFunction.find_fw_funcs(values, deformables, weights)
+ out = torch.zeros_like(values)
+ func(B, G, C, H, W, values, deformables, weights, out)
+ return out
+
+ @staticmethod
+ def find_fw_funcs(values, deformables, weights):
+ B, H, W, G, C = values.shape
+ B, H, W, G, K = weights.shape
+ hash_value = 10000 * B + 100 * H + W + 1000 * G
+ if hash_value in DCNFunction.FW_TABLES.keys():
+ return DCNFunction.FW_TABLES[hash_value]
+ print("missing")
+ candicate_func = None
+ min_t = 999.0
+ outs = torch.zeros_like(values)
+ for func in DCNFunction.FW_FUNCS:
+ t = []
+ for i in range(100):
+ torch.cuda.synchronize()
+ start_t = time.time()
+ func(B, G, C, H, W, values, deformables, weights, outs)
+ torch.cuda.synchronize()
+ t.append(time.time() - start_t)
+ t = t[-50:]
+ t = sum(t) / len(t)
+ if t < min_t:
+ min_t = t
+ DCNFunction.FW_TABLES[hash_value] = func
+ candicate_func = func
+ assert candicate_func is not None
+ print(candicate_func)
+ return candicate_func
+ @staticmethod
+ def find_bp_funcs(values, deformables, weights, grad_out):
+ B, H, W, G, C = values.shape
+ B, H, W, G, K = weights.shape
+ hash_value = 10000 * B + 100 * H + W + 1000 * G
+ if hash_value in DCNFunction.BP_TABLES.keys():
+ return DCNFunction.BP_TABLES[hash_value]
+ print("missing")
+ candicate_func = None
+ min_t = 999.0
+ grad_values = torch.zeros_like(values)
+ grad_deformables = torch.zeros_like(deformables)
+ grad_weights = torch.zeros_like(weights)
+ for func in DCNFunction.BP_FUNCS:
+ t = []
+ for i in range(100):
+ torch.cuda.synchronize()
+ start_t = time.time()
+ func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights)
+ torch.cuda.synchronize()
+ t.append(time.time() - start_t)
+ t = t[-50:]
+ t = sum(t) / len(t)
+ if t < min_t:
+ min_t = t
+ DCNFunction.BP_TABLES[hash_value] = func
+ candicate_func = func
+ assert candicate_func is not None
+ print(candicate_func)
+ return candicate_func
+
+ @staticmethod
+ @once_differentiable
+ @custom_bwd
+ def backward(ctx: Any, *grad_outputs: Any) -> Any:
+ grad_out = grad_outputs[0]
+ values, deformables, weights = ctx.saved_tensors
+ B, H, W, G, C = values.shape
+ B, H, W, G, K = weights.shape
+ func = DCNFunction.find_bp_funcs(values, deformables, weights, grad_out)
+ grad_values = torch.zeros_like(values)
+ grad_deformables = torch.zeros_like(deformables)
+ grad_weights = torch.zeros_like(weights)
+ func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights)
+ return grad_values, grad_deformables, grad_weights
\ No newline at end of file
diff --git a/src/ops/cuda_kernels/setup.py b/src/ops/cuda_kernels/setup.py
new file mode 100644
index 0000000..34079d4
--- /dev/null
+++ b/src/ops/cuda_kernels/setup.py
@@ -0,0 +1,59 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(
+ name='dcn_cuda_forward',
+ ext_modules=[
+ CUDAExtension('dcn_cuda_forward', ['./forward.cu',],
+ extra_compile_args={'cxx': [], 'nvcc': [
+ "-DCUDA_HAS_FP16=1",
+ "-D__CUDA_NO_HALF_OPERATORS__",
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
+ "-D__CUDA_NO_HALF2_OPERATORS__",
+ "--use_fast_math",
+ "-O3",
+ ]}
+ ),
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension
+ }
+)
+
+setup(
+ name='dcn_cuda_backward',
+ ext_modules=[
+ CUDAExtension('dcn_cuda_backward', ['./backward.cu',],
+ extra_compile_args={'cxx': [], 'nvcc': [
+ "-DCUDA_HAS_FP16=1",
+ "-D__CUDA_NO_HALF_OPERATORS__",
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
+ "-D__CUDA_NO_HALF2_OPERATORS__",
+ "--use_fast_math",
+ "-O3",
+ ]}
+ ),
+ ],
+ cmdclass={
+ 'build_ext': BuildExtension
+ }
+)
+
+
+# setup(
+# name='mycuda',
+# ext_modules=[
+# CUDAExtension('mycuda', ['./backward.cu',],
+# extra_compile_args={'cxx': [], 'nvcc': [
+# "-O3",
+# "-DCUDA_HAS_FP16=1",
+# "-D__CUDA_NO_HALF_OPERATORS__",
+# "-D__CUDA_NO_HALF_CONVERSIONS__",
+# "-D__CUDA_NO_HALF2_OPERATORS__",
+# ]}
+# ),
+# ],
+# cmdclass={
+# 'build_ext': BuildExtension
+# }
+# )
\ No newline at end of file
diff --git a/src/ops/triton_kernels/__init__.py b/src/ops/triton_kernels/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/ops/triton_kernels/backward.py b/src/ops/triton_kernels/backward.py
new file mode 100644
index 0000000..e886aa2
--- /dev/null
+++ b/src/ops/triton_kernels/backward.py
@@ -0,0 +1,124 @@
+import triton
+import triton.language as tl
+
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=1),
+ triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
+ ],
+ key=['B', 'H', 'W', 'G', 'C', 'K'],
+)
+@triton.jit
+def backward_kernel(
+ B: tl.constexpr,
+ H: tl.constexpr, # image_size_h
+ W: tl.constexpr, # image_size_w
+ G: tl.constexpr, # num_groups
+ C: tl.constexpr, # num_channels_per_group
+ K: tl.constexpr, # kernel size
+ input_ptr, # input features [B, H, W, G, C]
+ deformable_ptr, # deformable offsets [B, H, W, G, K, 2]
+ weights_ptr, # weights [B, H, W, G, K]
+ grad_ptr, # out [B, H, W, G, C]
+ grad_input_ptr, # input features [B, H, W, G, C]
+ grad_deformable_ptr, # deformable offsets [B, H, W, G, K, 2]
+ grad_weights_ptr, # weights [B, H, W, G, K]
+ BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
+):
+
+ pid = tl.program_id(0)
+ wid = pid % W
+ hid = pid // W % H
+ gid = pid // (W * H) % G
+ bid = pid // (W * H * G)
+
+ id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
+
+ common_offset = bid*H*W*G + hid*W*G + wid*G + gid
+ batch_base = bid * H * W * G * C
+ for k in tl.static_range(K):
+ # load dynamic weight and mask
+ weights_offset = common_offset*K + k
+ weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
+ dodx = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
+ dody = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty)
+ dodw = tl.zeros((1,), dtype=grad_weights_ptr.type.element_ty)
+ deformable_offset = (common_offset * K + k)*2
+ x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
+ y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
+ for block_base in tl.static_range(0, C, BLOCK_SIZE):
+ block_offset = tl.arange(0, BLOCK_SIZE) + block_base
+ block_mask = (block_offset < C) & id_mask
+ grad = tl.load(grad_ptr+common_offset*C + block_offset, mask=block_mask, other=0.0)
+ dods = weight*grad
+
+ floor_x = x.to(tl.int32)
+ floor_y = y.to(tl.int32)
+ ceil_x = floor_x + 1
+ ceil_y = floor_y + 1
+
+ # load top left
+ tl_weight = (ceil_x - x) * (ceil_y - y)
+ tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) + block_offset
+ tl_block_mask = ((floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H))
+ tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0)
+ tl_block_input_dot_grad = tl.sum(tl_block_input*grad, axis=0)
+ dodx = dodx + -1 * tl_block_input_dot_grad * (ceil_y - y)
+ dody = dody + -1 * tl_block_input_dot_grad * (ceil_x - x)
+ dodw = dodw + tl_block_input_dot_grad * tl_weight
+
+ dodtl = dods * tl_weight
+ tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dodtl)
+
+
+ # load top right
+ tr_weight = (x - floor_x) * (ceil_y - y)
+ tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
+ tr_block_mask = ((floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0))
+ tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0)
+ tr_block_input_dot_grad = tl.sum(tr_block_input*grad, axis=0)
+ dodx = dodx + 1 * tr_block_input_dot_grad * (ceil_y - y)
+ dody = dody + -1 * tr_block_input_dot_grad * (x - floor_x)
+ dodw = dodw + tr_block_input_dot_grad*tr_weight
+
+ dodtr = dods * tr_weight
+ tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dodtr)
+
+
+ # load bottom left
+ bl_weight = (ceil_x - x) * (y - floor_y)
+ bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) + block_offset
+ bl_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0))
+ bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0)
+ bl_block_input_dot_grad = tl.sum(bl_block_input*grad, axis=0)
+ dodx = dodx + -1 * bl_block_input_dot_grad * (y - floor_y)
+ dody = dody + 1 * bl_block_input_dot_grad * (ceil_x - x)
+ dodw = dodw + bl_block_input_dot_grad*bl_weight
+
+ dodbl = dods * bl_weight
+ tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dodbl)
+
+
+ # load bottom right
+ br_weight = (x - floor_x) * (y - floor_y)
+ br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) + block_offset
+ br_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0))
+ br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0)
+ br_block_input_dot_grad = tl.sum(br_block_input*grad, axis=0)*br_block_mask
+
+ dodx = dodx + 1 * br_block_input_dot_grad * (y - floor_y)
+ dody = dody + 1 * br_block_input_dot_grad * (x - floor_x)
+ dodw = dodw + br_block_input_dot_grad*br_weight
+
+ dodbr = dods * br_weight
+ tl.atomic_add(grad_input_ptr + br_block_offset, mask=br_block_mask & block_mask, val=dodbr)
+ dodx = dodx * weight
+ dody = dody * weight
+ tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask)
+ tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask)
+ tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, mask=id_mask)
+
+
+
+
+
diff --git a/src/ops/triton_kernels/forward.py b/src/ops/triton_kernels/forward.py
new file mode 100644
index 0000000..cf7c243
--- /dev/null
+++ b/src/ops/triton_kernels/forward.py
@@ -0,0 +1,94 @@
+import triton
+import triton.language as tl
+
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2),
+ # triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1),
+ ],
+ key=['B', 'H', 'W', 'G', 'C', 'K'],
+)
+@triton.jit
+def forward_kernel(
+ B: tl.constexpr,
+ H: tl.constexpr, # image_size_h
+ W: tl.constexpr, # image_size_w
+ G: tl.constexpr, # num_channels_per_group
+ C: tl.constexpr, # num_groups
+ K: tl.constexpr, # kernel size
+ input_ptr, # input features [B, H, W, G, C]
+ deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K]
+ weights_ptr, # weights [B, H, W, G, K]
+ out_ptr, # out [B, H, W, G, C]
+ BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group
+):
+ pid = tl.program_id(0)
+ wid = pid % W
+ hid = pid // W % H
+ gid = pid // (W * H) % G
+ bid = pid // (W * H * G)
+
+ id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B)
+ common_offset = bid*H*W*G + hid*W*G + wid*G + gid
+ batch_base = bid * H * W * G * C
+
+ for block_base in tl.static_range(0, C, BLOCK_SIZE):
+ buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
+ block_offset = tl.arange(0, BLOCK_SIZE) + block_base
+ block_mask = (block_offset < C) & id_mask
+ for k in tl.static_range(K):
+ deformable_offset = (common_offset * K + k) * 2
+
+ x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid
+ y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid
+
+ floor_x = x.to(tl.int32)
+ floor_y = y.to(tl.int32)
+ ceil_x = floor_x + 1
+ ceil_y = floor_y + 1
+
+ # load top left
+ tl_weight = (ceil_x - x) * (ceil_y - y)
+ tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
+ tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)
+
+ # load top right
+ tr_weight = (x - floor_x) * (ceil_y - y)
+ tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
+ tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)
+ # load bottom left
+ bl_weight = (ceil_x - x) * (y - floor_y)
+ bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE
+ bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)
+ # load bottom right
+ br_weight = (x - floor_x) * (y - floor_y)
+ br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE
+ br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)
+
+ # load dynamic weight and mask
+ weights_offset = common_offset*K + k
+ weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0)
+
+
+
+ tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0)
+ tl_block_input = tl_block_input * tl_weight
+
+ # load top right
+ tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0)
+ tr_block_input = tr_block_input * tr_weight
+ # load bottom left
+ bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0)
+ bl_block_input = bl_block_input * bl_weight
+ # load bottom right
+ br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0)
+ br_block_input = br_block_input * br_weight
+
+ # sampled
+ sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input
+
+ weighted_sampled_input = sampled_input * weight
+ buffer = buffer + weighted_sampled_input
+ # store to out_ptr
+ tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask)
+
diff --git a/src/ops/triton_kernels/function.py b/src/ops/triton_kernels/function.py
new file mode 100644
index 0000000..84987a1
--- /dev/null
+++ b/src/ops/triton_kernels/function.py
@@ -0,0 +1,48 @@
+import torch
+import triton
+from typing import Any
+from torch.autograd import Function
+from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd
+from .forward import forward_kernel
+from .backward import backward_kernel
+
+
+
+class DCNFunction(Function):
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx: Any, inputs, deformables, weights) -> Any:
+ B, H, W, G, C = inputs.shape
+ _, _, _, _, K, _ = deformables.shape
+ out = torch.zeros_like(inputs)
+ grid = lambda META: (B * H * W * G,)
+
+ forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out)
+ ctx.save_for_backward(inputs, deformables, weights)
+ return out
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, *grad_outputs: Any) -> Any:
+ grad_output = grad_outputs[0].contiguous()
+
+ inputs, deformables, weights = ctx.saved_tensors
+ B, H, W, G, C = inputs.shape
+ _, _, _, _, K, _ = deformables.shape
+
+ grad_inputs = torch.zeros_like(inputs)
+ grad_deformables = torch.zeros_like(deformables)
+ grad_weights = torch.zeros_like(weights)
+ grid = lambda META: (B * H * W * G,)
+ backward_kernel[grid](
+ B, H, W, G, C, K,
+ inputs,
+ deformables,
+ weights,
+ grad_output,
+ grad_inputs,
+ grad_deformables,
+ grad_weights,
+ )
+ return (grad_inputs, grad_deformables, grad_weights)
\ No newline at end of file
diff --git a/src/plugins/__init__.py b/src/plugins/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/plugins/bd_env.py b/src/plugins/bd_env.py
new file mode 100644
index 0000000..c1900e9
--- /dev/null
+++ b/src/plugins/bd_env.py
@@ -0,0 +1,70 @@
+import torch
+import os
+import socket
+from typing_extensions import override
+from lightning.fabric.utilities.rank_zero import rank_zero_only
+from lightning.fabric.plugins.environments.lightning import LightningEnvironment
+
+
+class BDEnvironment(LightningEnvironment):
+ pass
+ # def __init__(self) -> None:
+ # super().__init__()
+ # self._global_rank: int = 0
+ # self._world_size: int = 1
+ #
+ # @property
+ # @override
+ # def creates_processes_externally(self) -> bool:
+ # """Returns whether the cluster creates the processes or not.
+ #
+ # If at least :code:`LOCAL_RANK` is available as environment variable, Lightning assumes the user acts as the
+ # process launcher/job scheduler and Lightning will not launch new processes.
+ #
+ # """
+ # return "LOCAL_RANK" in os.environ
+ #
+ # @staticmethod
+ # @override
+ # def detect() -> bool:
+ # assert "ARNOLD_WORKER_0_HOST" in os.environ.keys()
+ # assert "ARNOLD_WORKER_0_PORT" in os.environ.keys()
+ # return True
+ #
+ # @override
+ # def world_size(self) -> int:
+ # return self._world_size
+ #
+ # @override
+ # def set_world_size(self, size: int) -> None:
+ # self._world_size = size
+ #
+ # @override
+ # def global_rank(self) -> int:
+ # return self._global_rank
+ #
+ # @override
+ # def set_global_rank(self, rank: int) -> None:
+ # self._global_rank = rank
+ # rank_zero_only.rank = rank
+ #
+ # @override
+ # def local_rank(self) -> int:
+ # return int(os.environ.get("LOCAL_RANK", 0))
+ #
+ # @override
+ # def node_rank(self) -> int:
+ # return int(os.environ.get("ARNOLD_ID"))
+ #
+ # @override
+ # def teardown(self) -> None:
+ # if "WORLD_SIZE" in os.environ:
+ # del os.environ["WORLD_SIZE"]
+ #
+ # @property
+ # def main_address(self) -> str:
+ # return os.environ.get("ARNOLD_WORKER_0_HOST")
+ #
+ # @property
+ # def main_port(self) -> int:
+ # return int(os.environ.get("ARNOLD_WORKER_0_PORT"))
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/utils/copy.py b/src/utils/copy.py
new file mode 100644
index 0000000..62cd89d
--- /dev/null
+++ b/src/utils/copy.py
@@ -0,0 +1,13 @@
+import torch
+
+@torch.no_grad()
+def copy_params(src_model, dst_model):
+ for src_param, dst_param in zip(src_model.parameters(), dst_model.parameters()):
+ dst_param.data.copy_(src_param.data)
+
+@torch.no_grad()
+def swap_tensors(tensor1, tensor2):
+ tmp = torch.empty_like(tensor1)
+ tmp.copy_(tensor1)
+ tensor1.copy_(tensor2)
+ tensor2.copy_(tmp)
\ No newline at end of file
diff --git a/src/utils/model_loader.py b/src/utils/model_loader.py
new file mode 100644
index 0000000..7d99166
--- /dev/null
+++ b/src/utils/model_loader.py
@@ -0,0 +1,29 @@
+from typing import Dict, Any, Optional
+
+import torch
+import torch.nn as nn
+from lightning.fabric.utilities.types import _PATH
+
+
+import logging
+logger = logging.getLogger(__name__)
+
+class ModelLoader:
+ def __init__(self,):
+ super().__init__()
+
+ def load(self, denoiser, prefix=""):
+ if denoiser.weight_path:
+ weight = torch.load(denoiser.weight_path, map_location=torch.device('cpu'))
+
+ if denoiser.load_ema:
+ prefix = "ema_denoiser." + prefix
+ else:
+ prefix = "denoiser." + prefix
+
+ for k, v in denoiser.state_dict().items():
+ try:
+ v.copy_(weight["state_dict"][prefix+k])
+ except:
+ logger.warning(f"Failed to copy {prefix+k} to denoiser weight")
+ return denoiser
\ No newline at end of file
diff --git a/src/utils/no_grad.py b/src/utils/no_grad.py
new file mode 100644
index 0000000..2fd71de
--- /dev/null
+++ b/src/utils/no_grad.py
@@ -0,0 +1,16 @@
+import torch
+
+@torch.no_grad()
+def no_grad(net):
+ for param in net.parameters():
+ param.requires_grad = False
+ net.eval()
+ return net
+
+@torch.no_grad()
+def filter_nograd_tensors(params_list):
+ filtered_params_list = []
+ for param in params_list:
+ if param.requires_grad:
+ filtered_params_list.append(param)
+ return filtered_params_list
\ No newline at end of file
diff --git a/src/utils/patch_bugs.py b/src/utils/patch_bugs.py
new file mode 100644
index 0000000..db9a174
--- /dev/null
+++ b/src/utils/patch_bugs.py
@@ -0,0 +1,17 @@
+import torch
+import lightning.pytorch.loggers.wandb as wandb
+
+setattr(wandb, '_WANDB_AVAILABLE', True)
+torch.set_float32_matmul_precision('medium')
+
+import logging
+logger = logging.getLogger("wandb")
+logger.setLevel(logging.WARNING)
+
+import os
+os.environ["NCCL_DEBUG"] = "WARN"
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+
+import warnings
+warnings.simplefilter(action='ignore', category=FutureWarning)
+warnings.simplefilter(action='ignore', category=UserWarning)
diff --git a/tools/cache_imlatent3.py b/tools/cache_imlatent3.py
new file mode 100644
index 0000000..640cdb0
--- /dev/null
+++ b/tools/cache_imlatent3.py
@@ -0,0 +1,117 @@
+from diffusers import AutoencoderKL
+
+import torch
+from typing import Callable
+from torchvision.datasets import ImageFolder, ImageNet
+
+import cv2
+import os
+import torch
+import numpy as np
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+
+from PIL import Image
+import pathlib
+
+import torch
+import random
+from torchvision.io.image import read_image
+import torchvision.transforms as tvtf
+from torch.utils.data import Dataset
+from torchvision.datasets import ImageNet
+
+IMG_EXTENSIONS = (
+ "*.png",
+ "*.JPEG",
+ "*.jpeg",
+ "*.jpg"
+)
+
+class NewImageFolder(ImageFolder):
+ def __getitem__(self, item):
+ path, target = self.samples[item]
+ sample = self.loader(path)
+ if self.transform is not None:
+ sample = self.transform(sample)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ return sample, target, path
+
+
+import time
+class CenterCrop:
+ def __init__(self, size):
+ self.size = size
+ def __call__(self, image):
+ def center_crop_arr(pil_image, image_size):
+ """
+ Center cropping implementation from ADM.
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+ """
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
+
+ return center_crop_arr(image, self.size)
+
+def save(obj, path):
+ dirname = os.path.dirname(path)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+ torch.save(obj, f'{path}')
+
+if __name__ == "__main__":
+ writer_pool = ThreadPoolExecutor(8)
+ for split in ['train']:
+ train = split == 'train'
+ transforms = tvtf.Compose([
+ CenterCrop(256),
+ # tvtf.RandomHorizontalFlip(p=1),
+ tvtf.ToTensor(),
+ tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+ ])
+ dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
+ # dataset = ImageNet(root='/tmp', split="train", transform=transforms, )
+ B = 256
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=False, prefetch_factor=32, num_workers=16)
+ vae = AutoencoderKL.from_pretrained("/mnt/bn/wangshuai6/models/sd-vae-ft-ema")#.to('cuda')
+
+ from accelerate import Accelerator
+
+ accelerator = Accelerator()
+
+ vae, dataloader = accelerator.prepare(vae, dataloader)
+ rank = accelerator.process_index
+ with torch.no_grad():
+ for i, (image, label, path_list) in enumerate(dataloader):
+ # if i >= 128: break
+ new_path_list = []
+ for p in path_list:
+ p = p + ".pt"
+ p = p.replace("/mnt/bn/wangshuai6/data/ImageNet/train",
+ "/mnt/bn/wangshuai6/data/ImageNet/train_256latent")
+ new_path_list.append(p)
+
+ image = image.to("cuda")
+ distribution = vae.module.encode(image).latent_dist
+ mean = distribution.mean
+ logvar = distribution.logvar
+ for j in range(B):
+ out = dict(
+ mean=mean[j].cpu(),
+ logvar=logvar[j].cpu(),
+ )
+ writer_pool.submit(save, out, new_path_list[j])
+ writer_pool.shutdown(wait=True)
+ accelerator.wait_for_everyone()
\ No newline at end of file
diff --git a/tools/cache_imlatent4.py b/tools/cache_imlatent4.py
new file mode 100644
index 0000000..fc33fa7
--- /dev/null
+++ b/tools/cache_imlatent4.py
@@ -0,0 +1,123 @@
+from diffusers import AutoencoderKL
+
+import torch
+from typing import Callable
+from torchvision.datasets import ImageFolder, ImageNet
+
+import cv2
+import os
+import torch
+import numpy as np
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+
+from PIL import Image
+import pathlib
+
+import torch
+import random
+from torchvision.io.image import read_image
+import torchvision.transforms as tvtf
+from torch.utils.data import Dataset
+from torchvision.datasets import ImageNet
+
+IMG_EXTENSIONS = (
+ "*.png",
+ "*.JPEG",
+ "*.jpeg",
+ "*.jpg"
+)
+
+class NewImageFolder(ImageFolder):
+ def __getitem__(self, item):
+ path, target = self.samples[item]
+ sample = self.loader(path)
+ if self.transform is not None:
+ sample = self.transform(sample)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ return sample, target, path
+
+
+import time
+class CenterCrop:
+ def __init__(self, size):
+ self.size = size
+ def __call__(self, image):
+ def center_crop_arr(pil_image, image_size):
+ """
+ Center cropping implementation from ADM.
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+ """
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
+
+ return center_crop_arr(image, self.size)
+
+def save(obj, path):
+ dirname = os.path.dirname(path)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+ torch.save(obj, f'{path}')
+
+if __name__ == "__main__":
+ writer_pool = ThreadPoolExecutor(8)
+ for split in ['train']:
+ train = split == 'train'
+ transforms = tvtf.Compose([
+ CenterCrop(512),
+ # tvtf.RandomHorizontalFlip(p=1),
+ tvtf.ToTensor(),
+ tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+ ])
+ dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
+ B = 8
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=False, prefetch_factor=16, num_workers=16)
+ vae = AutoencoderKL.from_pretrained("/mnt/bn/wangshuai6/models/sd-vae-ft-ema")#.to('cuda')
+ vae = vae.to(torch.float16)
+ from accelerate import Accelerator
+
+ accelerator = Accelerator()
+
+ vae, dataloader = accelerator.prepare(vae, dataloader)
+ rank = accelerator.process_index
+ with torch.no_grad():
+ for i, (image, label, path_list) in enumerate(dataloader):
+ print(i/len(dataloader))
+ flag = False
+ new_path_list = []
+ for p in path_list:
+ p = p + ".pt"
+ p = p.replace("/mnt/bn/wangshuai6/data/ImageNet/train",
+ "/mnt/bn/wangshuai6/data/ImageNet/train_512_latent")
+ new_path_list.append(p)
+ if not os.path.exists(p):
+ print(p)
+ flag = True
+
+ if flag:
+ image = image.to("cuda")
+ image = image.to(torch.float16)
+ distribution = vae.module.encode(image).latent_dist
+ mean = distribution.mean
+ logvar = distribution.logvar
+
+ for j in range(len(path_list)):
+ out = dict(
+ mean=mean[j].cpu(),
+ logvar=logvar[j].cpu(),
+ )
+ writer_pool.submit(save, out, new_path_list[j])
+ writer_pool.shutdown(wait=True)
+ accelerator.wait_for_everyone()
\ No newline at end of file
diff --git a/tools/cat_images.py b/tools/cat_images.py
new file mode 100644
index 0000000..21734b0
--- /dev/null
+++ b/tools/cat_images.py
@@ -0,0 +1,43 @@
+import cv2
+import numpy as np
+import os
+import pathlib
+import argparse
+
+def group_images(path_list):
+ sorted(path_list)
+ class_id_dict = {}
+ for path in path_list:
+ class_id = str(path.name).split('_')[0]
+ if class_id not in class_id_dict:
+ class_id_dict[class_id] = []
+ class_id_dict[class_id].append(path)
+ return class_id_dict
+
+def cat_images(path_list):
+ imgs = []
+ for path in path_list:
+ img = cv2.imread(str(path))
+ os.remove(path)
+ imgs.append(img)
+ row_cat_images = []
+ row_length = int(len(imgs)**0.5)
+ for i in range(len(imgs)//row_length):
+ row_cat_images.append(np.concatenate(imgs[i*row_length:(i+1)*row_length], axis=1))
+ cat_image = np.concatenate(row_cat_images, axis=0)
+ return cat_image
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--src_dir', type=str, default=None)
+
+ args = parser.parse_args()
+ src_dir = args.src_dir
+ path_list = list(pathlib.Path(src_dir).glob('*.png'))
+ class_id_dict = group_images(path_list)
+ for class_id, path_list in class_id_dict.items():
+ cat_image = cat_images(path_list)
+ cat_path = os.path.join(src_dir, f'cat_{class_id}.jpg')
+ # cat_path = "cat_{}.png".format(class_id)
+ cv2.imwrite(cat_path, cat_image)
+
diff --git a/tools/classifer_training.py b/tools/classifer_training.py
new file mode 100644
index 0000000..b00b435
--- /dev/null
+++ b/tools/classifer_training.py
@@ -0,0 +1,353 @@
+import torch
+import torch.nn as nn
+import timm
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from torchvision.transforms import Normalize
+import copy
+
+# NORMALIZE_DATA = dict(
+# dinov2_vits14a = dict(
+# mean=[-0.28,0.72,-0.64,-2.31,0.54,-0.29,-1.09,0.83,0.86,1.11,0.34,0.29,-0.32,1.02,0.58,-1.27,-1.19,-0.89,0.79,
+# -0.58,0.23,-0.19,1.31,-0.34,0.02,-0.18,-0.64,0.04,1.63,-0.58,-0.89,0.09,1.09,1.12,0.32,-0.41,0.04,0.49,
+# 0.11,1.97,1.06,0.05,-1.15,0.30,0.58,-0.14,-0.26,1.32,2.04,0.50,-0.64,1.18,0.39,0.39,-1.80,0.39,-0.67,
+# 0.55,-0.35,-0.41,2.23,-1.16,-0.57,0.58,-1.29,2.07,0.18,0.62,5.72,-0.55,-0.54,0.17,-0.64,-0.78,-0.25,0.12,
+# -0.58,0.36,-2.03,-2.45,-0.22,0.36,-1.02,-0.19,-0.92,-0.26,-0.27,-0.77,-1.47,-0.64,1.76,-0.03,-0.44,1.43,
+# 1.14,0.67,1.27,1.54,0.88,-1.42,-0.44,3.32,0.21,1.22,1.17,1.15,-0.53,0.04,0.87,-0.76,0.94,-0.11,0.69,-0.61,
+# 0.64,-1.21,-0.82,0.22,-1.12,-0.03,0.68,1.05,0.57,1.13,0.03,0.05,0.42,-0.12,-0.37,-0.76,-0.56,-0.76,-0.23,
+# 1.59,0.54,0.63,-0.43,0.38,1.07,0.04,-1.87,-1.92,-0.06,0.87,-0.69,-1.09,-0.30,0.33,-0.28,0.14,2.65,-0.57,
+# -0.04,0.12,-0.49,-1.60,0.39,0.05,0.12,0.66,-0.70,-0.69,0.47,-0.67,-0.59,-1.30,-0.28,-0.52,-0.98,0.67,1.65,
+# 0.72,0.55,0.05,-0.27,1.67,0.17,-0.31,-1.73,2.04,0.49,1.08,-0.37,1.75,1.31,1.03,0.65,0.43,-0.19,0.00,-1.13,
+# -0.29,-0.38,0.09,-0.24,1.49,1.01,-0.25,-0.94,0.74,0.24,-1.06,1.58,1.08,0.76,0.64,1.34,-1.09,1.54,-0.27,
+# 0.77,0.19,-0.97,0.46,0.20,-0.60,1.48,-2.33,0.43,2.32,1.89,-0.31,-0.48,-0.54,1.52,1.33,0.95,0.12,-0.33,-0.94,
+# -0.67,0.16,1.49,-0.17,-0.42,-0.02,-0.32,0.49,-1.19,0.06,0.19,-0.79,-0.21,-0.38,-0.69,0.52,0.74,0.41,-2.07,
+# -1.01,0.85,-1.41,-0.17,1.11,0.53,1.47,0.66,-0.22,0.93,-0.69,-0.42,0.06,0.11,-0.87,1.58,-0.27,-1.57,-0.56,
+# 0.98,-0.50,0.27,0.38,-1.06,-1.77,0.20,-0.33,-0.95,-0.62,-3.44,-0.67,-0.62,-1.20,0.04,-0.02,-1.15,0.56,-0.50,
+# 0.83,-1.69,0.01,-0.42,1.15,0.22,1.55,-3.02,1.24,0.28,0.40,0.69,-0.35,2.04,0.33,0.10,-1.09,0.50,0.59,1.29,
+# 0.79,0.02,-0.02,-0.49,0.07,0.84,0.55,-0.79,-0.26,-0.06,-0.91,-1.28,0.65,0.30,1.00,-0.09,0.66,-2.51,-0.78,
+# 2.94,0.18,0.24,-0.08,0.76,0.06,0.26,-0.74,0.16,-0.72,0.17,0.21,0.98,0.67,0.14,0.05,0.48,0.54,2.05,1.21,-0.03,
+# -0.85,0.38,-0.11,-0.38,-0.86,0.49,-0.87,-0.29,0.23,0.79,0.05,-0.05,-0.07,0.22,0.03,0.85,-0.63,-0.44,0.02,
+# -0.10,-0.01,0.51,-1.84,0.11,1.06,0.00,1.10,-0.56,0.21,0.44,-0.65,-0.97,-1.03,-0.50,-0.67,-0.27,-1.25,
+# ],
+# std=[1.78,3.78,1.92,2.28,1.97,2.82,1.92,2.55,1.87,1.95,1.90,1.83,1.89,2.00,1.85,1.88,1.78,1.81,3.02,1.94,1.92,
+# 2.26,2.17,6.16,1.84,2.00,1.85,1.88,3.30,2.14,1.85,2.87,3.01,2.05,1.80,1.84,2.20,2.00,1.97,2.02,1.94,1.90,
+# 1.98,2.25,1.97,2.01,2.01,1.95,2.26,2.47,1.95,1.75,1.84,3.02,2.65,2.15,2.01,1.80,2.65,2.37,2.04,2.09,2.03,
+# 1.94,1.84,2.19,1.98,1.97,4.52,2.76,2.18,2.59,1.94,2.07,1.96,1.91,2.13,3.16,1.95,2.43,1.84,2.16,2.33,2.21,
+# 2.10,1.98,1.90,1.90,1.88,1.89,2.15,1.75,1.83,2.36,2.40,2.42,1.89,2.03,1.89,2.00,1.91,2.88,2.10,2.63,2.04,
+# 1.88,1.93,1.74,2.02,1.84,1.96,1.98,1.90,1.80,1.86,2.05,2.21,1.97,1.99,1.77,2.04,2.59,1.85,2.14,1.91,1.68,
+# 1.95,1.86,1.99,2.18,2.76,2.03,1.88,2.47,1.92,3.04,2.02,1.74,2.94,1.92,2.12,1.92,2.17,2.15,1.74,2.26,1.71,
+# 2.03,2.05,1.85,3.43,1.77,1.96,1.88,1.99,2.14,2.30,2.00,1.90,2.01,1.78,1.72,2.42,1.66,1.86,2.08,2.04,1.88,
+# 2.55,2.02,1.83,1.86,1.69,2.06,1.92,2.25,1.74,1.69,2.02,3.88,1.86,2.94,1.82,2.27,2.73,2.05,1.91,1.94,1.86,
+# 1.77,2.16,2.16,1.86,1.88,2.08,2.19,1.94,1.90,2.09,2.57,1.75,1.90,2.05,2.13,1.74,1.99,1.83,2.35,4.48,2.44,
+# 1.88,2.18,2.46,1.84,1.81,2.37,2.45,2.07,1.79,3.65,2.29,2.09,2.09,2.29,1.92,2.34,1.85,2.03,1.72,2.20,2.15,
+# 2.04,2.13,2.07,1.82,1.72,2.06,1.87,2.43,1.94,1.93,1.97,1.83,1.96,2.01,1.89,1.73,2.04,2.63,2.10,2.05,2.49,
+# 2.10,2.27,1.87,2.16,2.22,2.08,1.87,2.26,1.88,2.28,3.87,1.74,3.71,2.03,2.70,2.11,1.92,2.00,2.04,2.02,1.90,
+# 2.61,2.10,2.37,1.96,2.50,1.17,1.95,1.88,2.06,2.22,1.87,1.93,1.88,3.59,1.89,3.66,1.87,1.95,3.13,1.84,2.87,
+# 3.96,2.14,2.01,1.89,1.73,1.98,2.42,2.12,2.28,1.92,1.93,2.54,2.06,1.97,2.02,2.19,2.00,2.04,1.75,1.97,1.81,
+# 1.93,1.83,2.22,2.52,1.83,1.86,2.16,2.08,2.87,3.21,2.78,2.84,2.85,1.88,1.79,1.95,1.98,1.78,1.78,2.21,1.89,
+# 2.57,2.00,2.82,1.90,2.24,2.28,1.91,2.02,2.23,2.62,1.88,2.40,2.40,2.00,1.70,1.82,1.92,1.95,1.99,2.08,1.97,
+# 2.12,1.87,3.65,2.26,1.83,1.96,1.83,1.64,2.07,2.04,2.57,1.85,2.21,1.83,1.90,1.97,2.16,2.12,1.80,1.73,1.96,
+# 2.62,3.23,2.13,2.29,2.24,2.72
+# ]
+# ),
+# dinov2_vitb14a = dict(
+# mean=[
+# 0.23, 0.44, 0.18, -0.26, -0.08, -0.80, -0.22, -0.09, -0.85, 0.44, 0.07, -0.49, 0.39, -0.12, -0.58, -0.82,
+# -0.21, -0.28, -0.40, 0.36, -0.34, 0.08, 0.31, 0.39, -0.22, -1.23, 0.50, 0.81, -0.96, 0.60, -0.45, -0.17,
+# -0.53, 0.08, 0.10, -0.32, -0.22, -0.86, 0.01, 0.19, -0.73, -0.44, -0.57, -0.45, -0.20, -0.34, -0.63, -0.31,
+# -0.80, 0.43, -0.13, 0.18, -0.11, -0.28, -0.15, 0.11, -0.74, -0.01, -0.34, 0.18, 0.37, 0.07, -0.09, -0.42, 0.15,
+# -0.24, 0.68, -0.31, -0.09, -0.62, -0.54, 0.41, -0.42, -0.08, 0.36, -0.14, 0.44, 0.12, 0.49, 0.69, 0.03,
+# -0.24, -0.41, -0.36, -0.60, 0.86, -0.76, 0.54, -0.24, 0.57, -0.40, -0.82, 0.07, 0.05, -0.24, 0.07, 0.54,
+# 1.04, -0.29, 0.67, -0.36, -0.79, 0.11, -0.12, -0.22, -0.20, -0.46, 0.17, -0.15, -0.38, -0.11, 0.24, -0.43,
+# -0.91, 0.04, 0.32, 0.27, -0.58, -0.05, 0.50, -0.47, 0.31, -1.30, 0.07, -0.16, 0.77, 1.07, -0.44, -0.48, 0.26
+# , 0.06, -0.76, -0.27, -0.37, -1.43, -0.50, -0.38, -0.03, -0.43, 0.75, -0.01, -0.16, 0.67, 0.40, 0.33, -0.05,
+# -0.94, -0.40, 0.78, 0.29, -0.60, -0.76, 0.08, -0.08, 0.58, -0.91, -1.09, -0.42, -0.42, 0.29, 0.06, -0.19,
+# -0.75, -0.07, 0.48, -0.30, -0.44, 0.02, 0.11, 0.23, -0.76, -0.76, -0.51, 0.78, -0.58, 0.02, 0.17, -0.36,
+# -0.63, 0.48, 0.09, -0.32, -0.48, -0.09, 0.09, -0.36, 0.11, -0.17, 0.11, -0.80, -0.34, -0.52, 0.10, -0.00, 0.00,
+# -0.15, 0.91, -0.48, 0.64, -0.38, 0.28, 0.56, 0.04, -0.30, 0.14, -0.30, -0.82, 0.47, 0.57, -1.00, -0.14,
+# 0.00, 0.10, 0.01, 0.57, -0.09, -3.56, -0.22, -0.24, -0.13, 0.36, 0.30, 0.20, 0.09, 0.08, 0.66, 0.62, 0.44,
+# 0.38, 0.46, -0.27, 0.21, 0.07, -0.57, 0.93, 0.39, 0.06, -0.47, 0.34, 0.44, -0.00, -0.52, -0.35, 0.23, -0.24,
+# -0.01, -0.15, 0.11, 0.53, -0.23, 0.28, -0.22, 0.57, -0.07, 0.49, 0.74, 0.85, -0.31, -0.44, 0.22, -0.02, 0.25,
+# -0.01, -0.47, -0.23, 0.03, 0.48, -0.19, 1.55, -0.05, 0.24, 0.26, -0.25, 0.38, -0.44, -0.51, 0.34, -0.12,
+# -0.76, -0.13, 0.57, 0.01, 0.63, 0.40, 0.20, -0.33, -0.31, -0.89, 0.65, -0.46, -0.88, -0.22, 0.34, 0.36,
+# 0.95, 0.33, 0.62, -0.49, 0.40, -0.12, -0.07, -0.65, -0.05, -0.58, 0.65, 0.18, -0.81, -0.64, 0.26, -0.10,
+# -0.71, 0.47, -0.05, 0.12, -0.18, 0.77, 0.47, 0.50, 0.48, -0.45, 0.03, 0.16, 0.66, -0.42, -0.05, 0.23, -0.22,
+# -0.46, 0.25, 0.28, 0.18, -0.20, -0.14, -0.93, -0.27, -0.23, 0.15, -0.10, -0.39, -0.20, -0.05, -0.09, 0.28,
+# -0.58, -0.54, 0.09, -0.89, -0.09, 0.03, -0.86, -0.46, -0.70, 0.48, -0.59, -0.56, -0.55, -0.27, -0.50, 0.23,
+# 0.63, -1.45, -0.27, -0.04, -0.17, 0.38, -0.02, 0.28, 0.53, -0.81, -0.60, -0.07, 0.22, 0.23, 0.33, -0.62,
+# 0.09, -0.19, -0.09, -0.28, -0.13, 0.66, 0.37, -0.17, -0.52, -0.15, -0.60, 0.15, -0.25, 0.42, -0.06, 0.26,
+# 0.55, 0.72, 0.48, 0.39, -0.41, -0.76, -0.62, 0.53, 0.18, 0.35, -0.27, -0.20, -0.71, -0.55, 0.16, -0.24, -0.12,
+# 0.38, -0.53, -0.43, 0.21, -0.60, -0.24, -0.11, 1.29, 0.02, -0.05, 0.13, 0.48, 0.39, -0.43, -0.05, 0.07,
+# -0.92, 0.89, -0.21, 0.30, -0.44, 0.04, -0.30, 0.11, -0.36, -0.46, -0.20, 0.10, 0.88, -0.15, 0.28, 0.57,
+# -0.10, 0.48, 0.77, -0.12, 0.17, -0.43, -0.20, 0.22, 0.36, -0.49, -0.54, -0.07, 0.67, 0.40, -0.94, -0.62,
+# 0.46, 0.75, -0.16, -0.32, 0.30, 0.41, 0.03, -0.31, -0.17, -0.47, 0.53, 0.24, -0.77, 0.32, 0.58, -0.08, -0.71, 0.10,
+# -0.14, 0.39, 0.64, -0.08, -0.38, 0.60, 0.02, 0.61, 0.47, 0.32, 0.35, -0.01, -0.03, -0.15, -0.01, 0.51,
+# -0.52, 0.51, -0.82, 0.58, -0.13, 0.07, 0.46, -2.86, 0.36, -0.27, 0.70, 0.54, 0.31, 0.08, -0.67, 0.58, 0.22,
+# -0.40, 1.05, 0.02, 0.41, -0.66, -0.29, 0.68, 0.40, 0.53, 0.09, -0.31, -0.28, 0.20, 0.01, -0.07, -0.25, 0.36,
+# 0.10, -0.79, 0.27, -0.18, 0.18, -1.13, 0.40, -1.07, 0.84, -0.26, -0.09, -0.99, -0.55, 0.20, -0.11, -0.10,
+# 0.49, 0.49, -0.08, -0.13, 1.00, 0.48, -0.17, -0.37, -0.31, -0.24, 0.27, -0.11, 0.21, 0.01, -0.17, -0.02,
+# -0.48, 0.25, -0.44, 0.64, 0.53, -1.02, -0.20, -0.13, -0.19, 0.07, -0.17, 0.66, 1.34, -0.40, -1.09, 0.42,
+# 0.07, -0.02, 0.50, 0.32, -0.03, 0.30, -0.53, 0.19, 0.01, -0.26, -0.54, -0.04, -0.64, -0.31, 0.85, -0.12,
+# -0.07, -0.08, -0.22, 0.27, -0.50, 0.25, 0.40, -0.60, -0.18, 0.36, 0.66, -0.16, 0.91, -0.61, 0.43, 0.31, 0.23, -0.60,
+# -0.13, -0.07, -0.44, -0.03, 0.25, 0.41, 0.08, 0.89, -1.09, -0.12, -0.12, -0.09, 0.13, 0.01, -0.55, -0.35,
+# -0.44, 0.07, -0.19, 0.35, 0.99, 0.01, 0.11, -0.04, 0.50, -0.10, 0.49, 0.61, 0.23, -0.41, 0.11, -0.36, 0.64,
+# -0.97, 0.68, -0.27, 0.30, 0.85, 0.03, 1.84, -0.15, -0.05, 0.46, -0.41, -0.01, 0.03, -0.32, 0.33, 0.14, 0.31
+# , -0.18, -0.30, 0.07, 0.70, -0.64, -0.59, 0.36, 0.39, -0.33, 0.79, 0.47, 0.44, -0.05, -0.03, -0.29, -1.00,
+# -0.04, 1.25, 0.74, 0.08, -0.53, -0.65, 0.17, -0.57, -0.39, 0.34, -0.12, -0.04, -0.63, 0.27, -0.25, -0.73,
+# -4.08, -0.09, -0.64, 0.38, -0.47, -0.36, -0.34, 0.05, 0.12, 0.37, -0.43, -0.39, 0.11, -0.32, -0.81, -0.05,
+# -0.40, -0.31, 2.64, 0.14, -2.08, 0.70, -0.52, -0.55, -0.40, -0.75, -0.20, 0.42, 0.99, -0.27, 0.35, -0.35,
+# -0.46, 0.48, 0.03, 0.64, 0.56, -0.77, -0.37, 0.02, 0.02, -0.60, -0.47, -0.49, -0.19, 0.29, 0.05, 0.17, 0.05,
+# 1.01, 0.05, 0.06, -0.00, -0.64, 0.72, 1.39, -0.45, -0.46, 0.49, -0.58, 0.36, 0.01, -0.14, -0.01, -0.54,
+# -0.46, -1.21, 0.94, -1.31, 0.61, 0.63, -0.53, 0.05, 0.37, -0.18, 1.08, -0.10, -0.80, -0.38, -0.03,
+# ],
+# std=[
+# 1.48, 1.58, 1.56, 1.49, 1.57, 1.96, 1.50, 1.34, 1.46, 1.66, 1.63, 1.44, 1.48, 1.53, 1.49, 1.39, 1.45, 1.40,
+# 1.47, 1.43, 1.65, 1.69, 1.72, 1.56, 1.50, 3.06, 1.48, 1.58, 1.63, 1.41, 1.78, 1.48, 1.64, 1.41, 1.46, 1.39,
+# 1.57, 3.80, 0.16, 1.46, 1.49, 1.51, 1.55, 1.57, 1.43, 1.69, 1.50, 1.53, 1.51, 1.49, 1.42, 1.48, 1.62, 1.56,
+# 1.52, 1.39, 1.95, 1.47, 1.33, 1.42, 1.96, 1.46, 1.54, 1.47, 1.41, 1.41, 1.50, 1.53, 1.55, 2.24, 1.52, 1.73,
+# 1.54, 1.46, 1.47, 1.55, 1.56, 1.46, 1.40, 1.49, 1.42, 1.54, 1.43, 1.48, 1.41, 1.49, 1.56, 1.59, 1.40, 1.49,
+# 1.58, 2.29, 1.58, 1.35, 1.41, 1.45, 1.43, 1.51, 1.48, 1.52, 1.51, 1.52, 1.56, 1.42, 1.44, 1.45, 1.47, 1.42,
+# 1.43, 1.49, 1.54, 1.45, 1.66, 1.48, 1.35, 1.53, 1.45, 2.38, 1.38, 1.32, 1.37, 1.49, 2.00, 1.47, 1.45, 1.47,
+# 1.63, 1.49, 1.59, 2.58, 1.70, 1.52, 1.40, 1.41, 2.57, 1.61, 1.54, 1.47, 1.62, 1.54, 1.41, 1.45, 1.57, 1.49,
+# 1.42, 1.50, 1.67, 1.45, 1.47, 1.43, 1.55, 1.47, 1.53, 1.49, 1.56, 1.58, 2.03, 2.03, 1.57, 1.44, 1.46, 1.05,
+# 1.61, 1.39, 1.47, 1.41, 1.43, 1.38, 1.34, 1.42, 1.41, 1.47, 1.79, 1.44, 1.43, 1.38, 1.39, 1.44, 1.38, 1.46,
+# 1.45, 1.51, 1.52, 1.49, 5.31, 1.41, 1.45, 1.49, 1.43, 1.94, 1.38, 1.35, 1.56, 1.45, 1.37, 1.47, 1.48, 1.67,
+# 1.46, 1.50, 1.40, 1.50, 1.62, 1.48, 1.53, 1.45, 1.51, 1.50, 1.51, 1.52, 1.55, 1.42, 1.84, 1.39, 1.54, 1.42, 4.91,
+# 1.42, 1.47, 1.51, 1.57, 1.37, 1.50, 1.39, 2.40, 1.51, 1.59, 1.44, 1.42, 1.59, 1.73, 1.44, 1.53, 1.61, 1.48,
+# 1.29, 1.47, 1.39, 1.54, 1.44, 1.43, 1.55, 1.45, 1.31, 1.43, 1.44, 1.41, 1.35, 1.62, 1.49, 1.45, 1.50, 1.76,
+# 1.44, 1.80, 1.60, 1.49, 1.43, 1.47, 1.40, 1.40, 1.50, 1.42, 1.51, 1.61, 1.47, 1.45, 1.70, 2.90, 1.51, 1.37,
+# 1.50, 1.55, 1.32, 1.42, 1.76, 1.36, 1.41, 1.61, 1.44, 1.44, 1.44, 1.47, 1.48, 1.45, 1.48, 1.56, 1.58, 1.52,
+# 1.33, 1.37, 1.64, 1.47, 2.49, 1.51, 1.60, 1.58, 1.45, 1.48, 1.81, 1.38, 1.37, 1.53, 1.72, 1.49, 1.47, 1.49, 1.42,
+# 1.44, 1.43, 1.54, 1.59, 1.40, 1.57, 1.45, 1.45, 1.45, 1.55, 1.38, 1.41, 1.46, 2.13, 1.58, 1.46, 1.35, 1.56,
+# 1.47, 1.33, 1.53, 1.62, 1.47, 1.44, 1.45, 1.49, 1.82, 1.51, 1.38, 1.54, 1.38, 1.38, 1.40, 1.40, 1.46, 1.43,
+# 1.45, 1.42, 1.67, 1.37, 1.50, 1.60, 1.42, 1.46, 1.45, 3.29, 1.45, 1.50, 1.49, 1.38, 1.48, 1.52, 2.45, 1.47,
+# 1.50, 1.47, 1.48, 1.44, 1.62, 1.48, 1.52, 1.52, 1.45, 1.51, 1.71, 1.54, 1.59, 1.40, 3.29, 1.45, 1.65, 1.37, 1.54,
+# 1.49, 2.38, 1.62, 1.39, 1.38, 1.41, 1.46, 1.57, 1.38, 2.07, 1.54, 1.40, 1.64, 1.46, 1.45, 1.40, 1.57, 1.49,
+# 1.39, 1.55, 1.67, 1.54, 1.57, 1.55, 1.41, 1.37, 1.44, 1.40, 1.46, 1.59, 1.56, 1.61, 1.44, 1.35, 1.62, 1.59,
+# 1.52, 1.41, 1.44, 1.74, 1.40, 1.40, 1.89, 1.44, 1.46, 1.62, 1.43, 1.42, 1.39, 1.37, 1.43, 1.44, 1.60, 1.52,
+# 1.44, 1.41, 1.43, 1.34, 1.54, 1.46, 1.57, 1.53, 1.40, 1.41, 1.36, 1.45, 1.42, 1.37, 1.47, 1.37, 1.40, 1.55,
+# 1.48, 1.91, 1.44, 1.54, 1.49, 1.42, 1.48, 1.54, 1.49, 1.39, 1.47, 1.50, 1.43, 1.59, 1.58, 1.78, 1.49, 1.55,
+# 1.56, 1.52, 1.56, 1.49, 1.61, 1.51, 1.35, 1.46, 1.69, 1.35, 1.38, 1.48, 1.39, 1.40, 1.35, 1.45, 1.34, 1.38,
+# 1.44, 1.46, 1.45, 1.63, 1.52, 1.44, 1.39, 1.46, 1.70, 1.41, 1.49, 1.64, 1.54, 1.33, 1.45, 1.54, 1.49, 1.38,
+# 1.42, 1.75, 1.28, 1.52, 1.62, 1.47, 1.66, 1.51, 1.50, 1.51, 1.42, 1.42, 1.60, 1.24, 1.54, 1.42, 1.44, 1.34, 1.53,
+# 1.46, 1.46, 1.65, 1.56, 1.52, 2.12, 1.58, 1.44, 1.60, 1.48, 1.51, 1.41, 1.51, 1.68, 2.10, 1.50, 1.39, 1.49,
+# 1.43, 1.53, 1.46, 1.53, 1.43, 1.78, 1.32, 1.54, 1.47, 1.55, 1.58, 1.41, 1.57, 1.39, 1.36, 1.74, 1.50, 4.41,
+# 1.50, 1.45, 1.34, 1.44, 1.50, 1.50, 1.82, 1.28, 1.76, 1.38, 1.58, 1.56, 3.73, 1.48, 1.53, 1.48, 1.63, 1.43,
+# 1.57, 3.43, 1.75, 1.45, 1.45, 1.48, 1.93, 1.47, 1.47, 1.38, 1.42, 1.56, 1.66, 1.39, 1.74, 4.76, 1.53, 1.68,
+# 1.55, 1.47, 1.57, 1.53, 1.50, 1.40, 1.57, 1.48, 1.44, 1.36, 1.32, 1.71, 1.44, 1.46, 1.47, 1.54, 1.51, 1.47,
+# 1.36, 1.29, 1.44, 1.43, 1.46, 1.40, 1.64, 1.48, 1.42, 1.32, 1.52, 1.49, 3.04, 1.52, 1.38, 1.43, 1.42, 1.43,
+# 1.48, 1.49, 1.59, 1.55, 1.62, 2.04, 1.53, 1.42, 1.89, 1.43, 1.41, 3.84, 1.48, 1.51, 1.48, 1.58, 1.54, 1.54,
+# 1.54, 1.55, 1.45, 1.49, 1.46, 2.25, 1.43, 1.62, 1.66, 1.80, 1.37, 1.64, 1.49, 1.50, 1.39, 1.41, 1.41, 1.46, 1.44,
+# 1.69, 1.47, 1.56, 1.65, 1.51, 1.52, 1.43, 1.53, 1.51, 1.46, 1.62, 1.46, 1.53, 1.68, 1.61, 1.56, 1.42, 4.69,
+# 1.31, 1.48, 1.50, 1.82, 1.45, 1.54, 1.56, 1.53, 1.58, 1.59, 1.82, 1.45, 1.54, 1.58, 1.45, 1.40, 1.49, 2.50,
+# 1.52, 2.54, 1.51, 1.41, 1.48, 1.46, 1.55, 1.63, 1.42, 1.53, 1.47, 1.47, 1.62, 1.49, 2.09, 1.42, 1.48, 1.33,
+# 1.62, 1.41, 1.41, 1.45, 1.50, 1.78, 1.53, 1.56, 1.49, 1.51, 2.31, 1.40, 1.58, 1.39, 1.49, 1.51, 1.55, 1.58,
+# 1.93, 1.47, 1.41, 1.47, 1.52, 1.52, 1.39, 1.48, 1.64, 1.49, 1.47, 1.53, 1.50, 3.58, 1.54, 1.70, 1.50, 1.47,
+# 1.35, 1.51, 1.70, 1.59, 1.60, 1.56, 1.29
+# ]
+# )
+# )
+
+class DINOv2a(nn.Module):
+ def __init__(self, weight_path:str):
+ super(DINOv2a, self).__init__()
+ self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.precomputed_pos_embed = dict()
+ # self.shifts = nn.Parameter(torch.tensor(NORMALIZE_DATA[weight_path+'a']["mean"]), requires_grad=False)
+ # self.scales = nn.Parameter(torch.tensor(NORMALIZE_DATA[weight_path+'a']["std"]), requires_grad=False)
+
+ def fetch_pos(self, h, w):
+ key = (h, w)
+ if key in self.precomputed_pos_embed:
+ return self.precomputed_pos_embed[key]
+ value = timm.layers.pos_embed.resample_abs_pos_embed(
+ self.pos_embed.data, [h, w],
+ )
+ self.precomputed_pos_embed[key] = value
+ return value
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
+ self.encoder.pos_embed.data = pos_embed_data
+ feature = self.encoder.forward_features(x)['x_norm_patchtokens']
+ # feature = (feature - self.shifts.view(1, 1, -1)) / self.scales.view(1, 1, -1)
+ feature = feature.transpose(1, 2)
+ feature = torch.nn.functional.fold(feature, (patch_num_h*2, patch_num_w*2), kernel_size=2, stride=2)
+ return feature
+
+
+
+from torchvision.datasets import ImageFolder, ImageNet
+
+import os
+import numpy as np
+
+from PIL import Image
+import torch
+import torchvision.transforms as tvtf
+
+IMG_EXTENSIONS = (
+ "*.png",
+ "*.JPEG",
+ "*.jpeg",
+ "*.jpg"
+)
+
+class NewImageFolder(ImageFolder):
+ def __getitem__(self, item):
+ path, target = self.samples[item]
+ sample = self.loader(path)
+ if self.transform is not None:
+ sample = self.transform(sample)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ return sample, target, path
+
+
+import time
+class CenterCrop:
+ def __init__(self, size):
+ self.size = size
+ def __call__(self, image):
+ def center_crop_arr(pil_image, image_size):
+ """
+ Center cropping implementation from ADM.
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+ """
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
+
+ return center_crop_arr(image, self.size)
+
+def save(obj, path):
+ dirname = os.path.dirname(path)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+ torch.save(obj, f'{path}')
+
+import math
+class TimestepEmbedder(nn.Module):
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10):
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
+ )
+ args = t[..., None].float() * freqs[None, ...]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+class Classifer(nn.Module):
+ def __init__(self, in_channels=192, hidden_size=256, num_classes=1000):
+ super(Classifer, self).__init__()
+ self.in_channels = in_channels
+ self.feature_x = nn.Sequential(
+ nn.Conv2d(kernel_size=2, in_channels=in_channels, out_channels=num_classes, stride=2, padding=0),
+ nn.AdaptiveAvgPool2d(1),
+ )
+ def forward(self, xt):
+ xt = xt[:, :self.in_channels]
+ score = self.feature_x(xt).squeeze(-1).squeeze(-1)
+ # score = (feature_xt).clamp(-5, 5)
+ score = torch.softmax(score, dim=1)
+ return score
+
+
+
+if __name__ == "__main__":
+ torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
+ os.environ["TORCH_HOME"] = torch_hub_dir
+ torch.hub.set_dir(torch_hub_dir)
+
+ transforms = tvtf.Compose([
+ CenterCrop(256),
+ tvtf.ToTensor(),
+ ])
+ dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
+ B = 64
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=4, num_workers=4, drop_last=True)
+ dino = DINOv2a("dinov2_vitb14")
+ from accelerate import Accelerator
+
+ accelerator = Accelerator()
+
+ rank = accelerator.process_index
+
+ classifer = Classifer(in_channels=32)
+ classifer.train()
+ optimizer = torch.optim.Adam(classifer.parameters(), lr=0.0001)
+
+ dino, dataloader, classifer, optimizer = accelerator.prepare(dino, dataloader, classifer, optimizer)
+
+ # fake_file_dir = "/mnt/bn/wangshuai6/data/gan_guidance"
+ # fake_file_names = os.listdir(fake_file_dir)
+
+ for epoch in range(100):
+ for i, (true_images, true_labels, path_list) in enumerate(dataloader):
+ batch_size = true_images.shape[0]
+ true_labels = true_labels.to(accelerator.device)
+ true_labels = torch.nn.functional.one_hot(true_labels, num_classes=1000)
+ with torch.no_grad():
+ true_dino_feature = dino(true_images)
+ # t = torch.rand((batch_size, 1, 1, 1), device=accelerator.device)
+ # true_x_t = t * true_dino_feature + (1-t) * noise
+
+ true_x_t = true_dino_feature
+ true_score = classifer(true_x_t)
+
+ # ind = i % len(fake_file_names)
+ # fake_file = torch.load(os.path.join(fake_file_dir, fake_file_names[ind]))
+ # import pdb; pdb.set_trace()
+ # ind = torch.randint(0, 50, size=(4,))
+ # fake_x_t = fake_file['trajs'][ind].view(-1, 196, 32, 32)[:, 4:, :, :]
+ # fake_labels = fake_file['condition'].repeat(4)
+ # fake_score = classifer(fake_x_t)
+
+ loss_true = -torch.log(true_score)*true_labels
+ loss = loss_true.sum()/batch_size
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+
+ acc = torch.sum(torch.argmax(true_score, dim=1) == torch.argmax(true_labels, dim=1))/batch_size
+ if accelerator.is_main_process:
+ print("epoch:{}".format(epoch), "iter:{}".format(i), "loss:{}".format(loss.item()), "acc:{}".format(acc.item()))
+ if accelerator.is_main_process:
+ torch.save(classifer.state_dict(), f'{epoch}.pth')
+
+ accelerator.wait_for_everyone()
\ No newline at end of file
diff --git a/tools/debug_env.sh b/tools/debug_env.sh
new file mode 100644
index 0000000..d29dc46
--- /dev/null
+++ b/tools/debug_env.sh
@@ -0,0 +1,4 @@
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/compat
+pip3 install -r requirements.txt
+git branch --set-upstream-to=origin/master master
+git pull
\ No newline at end of file
diff --git a/tools/dino_scale.py b/tools/dino_scale.py
new file mode 100644
index 0000000..439caaf
--- /dev/null
+++ b/tools/dino_scale.py
@@ -0,0 +1,173 @@
+import torch
+import torch.nn as nn
+import timm
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from torchvision.transforms import Normalize
+import copy
+
+class DINOv2(nn.Module):
+ def __init__(self, weight_path:str):
+ super(DINOv2, self).__init__()
+ self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.precomputed_pos_embed = dict()
+
+ def fetch_pos(self, h, w):
+ key = (h, w)
+ if key in self.precomputed_pos_embed:
+ return self.precomputed_pos_embed[key]
+ value = timm.layers.pos_embed.resample_abs_pos_embed(
+ self.pos_embed.data, [h, w],
+ )
+ self.precomputed_pos_embed[key] = value
+ return value
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
+ self.encoder.pos_embed.data = pos_embed_data
+ feature = self.encoder.forward_features(x)['x_norm_patchtokens']
+ feature = feature.transpose(1, 2)
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
+ return feature
+
+class MAE(nn.Module):
+ def __init__(self, model_id, weight_path:str):
+ super(MAE, self).__init__()
+ if os.path.isdir(weight_path):
+ weight_path = os.path.join(weight_path, "pytorch_model.bin")
+ self.encoder = timm.create_model(
+ model_id,
+ checkpoint_path=weight_path,
+ num_classes=0,
+ )
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.shifts = nn.Parameter(torch.tensor([0.0
+ ]), requires_grad=False)
+ self.scales = nn.Parameter(torch.tensor([1.0
+ ]), requires_grad=False)
+
+ def forward(self, x):
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
+ feature = feature.transpose(1, 2)
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
+ feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
+ return feature
+
+
+from diffusers import AutoencoderKL
+
+from torchvision.datasets import ImageFolder, ImageNet
+
+import cv2
+import os
+import numpy as np
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+
+from PIL import Image
+import torch
+import torchvision.transforms as tvtf
+
+IMG_EXTENSIONS = (
+ "*.png",
+ "*.JPEG",
+ "*.jpeg",
+ "*.jpg"
+)
+
+class NewImageFolder(ImageFolder):
+ def __getitem__(self, item):
+ path, target = self.samples[item]
+ sample = self.loader(path)
+ if self.transform is not None:
+ sample = self.transform(sample)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ return sample, target, path
+
+
+import time
+class CenterCrop:
+ def __init__(self, size):
+ self.size = size
+ def __call__(self, image):
+ def center_crop_arr(pil_image, image_size):
+ """
+ Center cropping implementation from ADM.
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+ """
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
+
+ return center_crop_arr(image, self.size)
+
+def save(obj, path):
+ dirname = os.path.dirname(path)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+ torch.save(obj, f'{path}')
+
+if __name__ == "__main__":
+ torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
+ os.environ["TORCH_HOME"] = torch_hub_dir
+ torch.hub.set_dir(torch_hub_dir)
+
+ for split in ['train']:
+ train = split == 'train'
+ transforms = tvtf.Compose([
+ CenterCrop(256),
+ tvtf.ToTensor(),
+ # tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+ ])
+ dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
+ B = 4096
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0)
+ dino = DINOv2("dinov2_vitb14")
+ # dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae")
+ # dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai")
+ from accelerate import Accelerator
+
+ accelerator = Accelerator()
+ dino, dataloader = accelerator.prepare(dino, dataloader)
+ rank = accelerator.process_index
+
+ acc_mean = torch.zeros((768, ), device=accelerator.device)
+ acc_num = 0
+ with torch.no_grad():
+ for i, (images, labels, path_list) in enumerate(dataloader):
+ acc_num += len(images)
+ feature = dino(images)
+ stds = torch.std(feature, dim=[0, 2, 3]).tolist()
+ for std in stds:
+ print("{:.2f},".format(std), end='')
+ print()
+ means = torch.mean(feature, dim=[0, 2, 3]).tolist()
+ for mean in means:
+ print("{:.2f},".format(mean), end='')
+ break
+
+ accelerator.wait_for_everyone()
\ No newline at end of file
diff --git a/tools/dino_scale2.py b/tools/dino_scale2.py
new file mode 100644
index 0000000..336c196
--- /dev/null
+++ b/tools/dino_scale2.py
@@ -0,0 +1,168 @@
+import torch
+import torch.nn as nn
+import timm
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from torchvision.transforms import Normalize
+import copy
+
+class DINOv2(nn.Module):
+ def __init__(self, weight_path:str):
+ super(DINOv2, self).__init__()
+ self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.precomputed_pos_embed = dict()
+
+ def fetch_pos(self, h, w):
+ key = (h, w)
+ if key in self.precomputed_pos_embed:
+ return self.precomputed_pos_embed[key]
+ value = timm.layers.pos_embed.resample_abs_pos_embed(
+ self.pos_embed.data, [h, w],
+ )
+ self.precomputed_pos_embed[key] = value
+ return value
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
+ self.encoder.pos_embed.data = pos_embed_data
+ feature = self.encoder.forward_features(x)['x_norm_patchtokens']
+ feature = feature.transpose(1, 2)
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
+ return feature
+
+class MAE(nn.Module):
+ def __init__(self, model_id, weight_path:str):
+ super(MAE, self).__init__()
+ if os.path.isdir(weight_path):
+ weight_path = os.path.join(weight_path, "pytorch_model.bin")
+ self.encoder = timm.create_model(
+ model_id,
+ checkpoint_path=weight_path,
+ num_classes=0,
+ )
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.shifts = nn.Parameter(torch.tensor([0.0
+ ]), requires_grad=False)
+ self.scales = nn.Parameter(torch.tensor([1.0
+ ]), requires_grad=False)
+
+ def forward(self, x):
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
+ feature = feature.transpose(1, 2)
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
+ feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
+ return feature
+
+
+from diffusers import AutoencoderKL
+
+from torchvision.datasets import ImageFolder, ImageNet
+
+import cv2
+import os
+import numpy as np
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+
+from PIL import Image
+import torch
+import torchvision.transforms as tvtf
+
+IMG_EXTENSIONS = (
+ "*.png",
+ "*.JPEG",
+ "*.jpeg",
+ "*.jpg"
+)
+
+class NewImageFolder(ImageFolder):
+ def __getitem__(self, item):
+ path, target = self.samples[item]
+ sample = self.loader(path)
+ if self.transform is not None:
+ sample = self.transform(sample)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ return sample, target, path
+
+
+import time
+class CenterCrop:
+ def __init__(self, size):
+ self.size = size
+ def __call__(self, image):
+ def center_crop_arr(pil_image, image_size):
+ """
+ Center cropping implementation from ADM.
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+ """
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
+
+ return center_crop_arr(image, self.size)
+
+def save(obj, path):
+ dirname = os.path.dirname(path)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+ torch.save(obj, f'{path}')
+
+if __name__ == "__main__":
+ torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
+ os.environ["TORCH_HOME"] = torch_hub_dir
+ torch.hub.set_dir(torch_hub_dir)
+
+ for split in ['train']:
+ train = split == 'train'
+ transforms = tvtf.Compose([
+ CenterCrop(256),
+ tvtf.ToTensor(),
+ tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+ ])
+ dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
+ B = 2048
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0)
+ dino = DINOv2("dinov2_vitb14")
+ # dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae")
+ # dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai")
+ from accelerate import Accelerator
+
+ accelerator = Accelerator()
+ dino, dataloader = accelerator.prepare(dino, dataloader)
+ rank = accelerator.process_index
+
+ with torch.no_grad():
+ for i, (images, labels, path_list) in enumerate(dataloader):
+ feature = dino(images)
+ b, c, h, w = feature.shape
+ feature = feature.view(b, c, h*w).transpose(1, 2)
+ feature = feature.reshape(-1, c)
+ U, S, V = torch.pca_lowrank(feature, 64, )
+ import pdb; pdb.set_trace()
+ feature = torch.matmul(feature, V)
+ break
+ accelerator.wait_for_everyone()
\ No newline at end of file
diff --git a/tools/dp.py b/tools/dp.py
new file mode 100644
index 0000000..9f89d75
--- /dev/null
+++ b/tools/dp.py
@@ -0,0 +1,64 @@
+import matplotlib.pyplot as plt
+print(len([0, 3, 6, 9, 12, 16, 20, 24, 28, 33, 38, 43, 48, 53, 57, 62, 67, 72, 78, 83, 87, 91, 95, 98, 102, 106, 110, 115, 120, 125, 130, 135, 141, 146, 152, 158, 164, 171, 179, 185, 191, 197, 203, 209, 216, 223, 229, 234, 240, 245, 250]))
+print(len(list(range(0, 251, 5))))
+exit()
+plt.plot()
+plt.plot()
+plt.show()
+exit()
+
+
+
+
+import torch
+
+num_steps = 10
+num_recompute_timesteps = 4
+sim = torch.randint(0, 100, (num_steps, num_steps))
+sim[:5, :5] = 100
+for i in range(num_steps):
+ sim[i, i] = 100
+
+error_map = (100-sim).tolist()
+
+
+# init
+for i in range(1, num_steps):
+ for j in range(0, i):
+ error_map[i][j] = error_map[i-1][j] + error_map[i][j]
+
+C = [[0, ] * (num_steps + 1) for _ in range(num_recompute_timesteps+1)]
+P = [[-1, ] * (num_steps + 1) for _ in range(num_recompute_timesteps+1)]
+
+for i in range(1, num_steps+1):
+ C[1][i] = error_map[i-1][0]
+ P[1][i] = 0
+
+
+# dp
+for step in range(2, num_recompute_timesteps+1):
+ for i in range(step, num_steps+1):
+ min_value = 99999
+ min_index = -1
+ for j in range(step-1, i):
+ value = C[step-1][j] + error_map[i-1][j]
+ if value < min_value:
+ min_value = value
+ min_index = j
+ C[step][i] = min_value
+ P[step][i] = min_index
+
+# trace back
+tracback_end_index = num_steps
+# min_value = 99999
+# for i in range(num_recompute_timesteps-1, num_steps):
+# if C[-1][i] < min_value:
+# min_value = C[-1][i]
+# tracback_end_index = i
+
+timesteps = [tracback_end_index, ]
+for i in range(num_recompute_timesteps, 0, -1):
+ idx = timesteps[-1]
+ timesteps.append(P[i][idx])
+timesteps.reverse()
+print(timesteps)
\ No newline at end of file
diff --git a/tools/figures/base++.py b/tools/figures/base++.py
new file mode 100644
index 0000000..9715321
--- /dev/null
+++ b/tools/figures/base++.py
@@ -0,0 +1,64 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+is_data = {
+ "4encoder8decoder":[46.01, 61.47, 69.73, 74.26],
+ "6encoder6decoder":[53.11, 71.04, 79.83, 83.85],
+ "8encoder4decoder":[54.06, 72.96, 80.49, 85.94],
+ "10encoder2decoder": [49.25, 67.59, 76.00, 81.12],
+}
+
+fid_data = {
+ "4encoder8decoder":[31.40, 22.80, 20.13, 18.61],
+ "6encoder6decoder":[27.61, 20.42, 17.95, 16.86],
+ "8encoder4decoder":[27.12, 19.90, 17.78, 16.32],
+ "10encoder2decoder": [29.70, 21.75, 18.95, 17.65],
+}
+
+sfid_data = {
+ "4encoder8decoder":[6.88, 6.44, 6.56, 6.56],
+ "6encoder4decoder":[6.83, 6.50, 6.49, 6.63],
+ "8encoder4decoder":[6.76, 6.70, 6.83, 6.63],
+ "10encoder2decoder": [6.81, 6.61, 6.53, 6.60],
+}
+
+pr_data = {
+ "4encoder8decoder":[0.55006, 0.59538, 0.6063, 0.60922],
+ "6encoder6decoder":[0.56436, 0.60246, 0.61668, 0.61702],
+ "8encoder4decoder":[0.56636, 0.6038, 0.61832, 0.62132],
+ "10encoder2decoder": [0.55612, 0.59846, 0.61092, 0.61686],
+}
+
+recall_data = {
+ "4encoder8decoder":[0.6347, 0.6495, 0.6559, 0.662],
+ "6encoder6decoder":[0.6477, 0.6497, 0.6594, 0.6589],
+ "8encoder4decoder":[0.6403, 0.653, 0.6505, 0.6618],
+ "10encoder2decoder": [0.6342, 0.6492, 0.6536, 0.6569],
+}
+
+x = [100, 200, 300, 400]
+# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
+
+colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
+
+metric_data = {
+ "FID50K" : fid_data,
+ # "SFID" : sfid_data,
+ "InceptionScore" : is_data,
+ "Precision" : pr_data,
+ "Recall" : recall_data,
+}
+
+for key, data in metric_data.items():
+ # plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
+ for i, (name, v) in enumerate(data.items()):
+ name = name.replace("encoder", "En")
+ name = name.replace("decoder", "De")
+ plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=10)
+ plt.legend(fontsize="14")
+ plt.xticks([100, 150, 200, 250, 300, 350, 400])
+ plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
+ plt.ylabel(key, weight="bold")
+ plt.xlabel("Training iterations(K steps)", weight="bold")
+ plt.savefig("output/base++_{}.pdf".format(key), bbox_inches='tight',)
+ plt.close()
\ No newline at end of file
diff --git a/tools/figures/base.py b/tools/figures/base.py
new file mode 100644
index 0000000..46acd3c
--- /dev/null
+++ b/tools/figures/base.py
@@ -0,0 +1,57 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+
+
+fid_data = {
+ "4encoder8decoder":[64.16, 48.04, 39.88, 35.41],
+ "6encoder4decoder":[67.71, 48.26, 39.30, 34.91],
+ "8encoder4decoder":[69.4, 49.7, 41.56, 36.76],
+}
+
+sfid_data = {
+ "4encoder8decoder":[7.86, 7.48, 7.15, 7.07],
+ "6encoder4decoder":[8.54, 8.11, 7.40, 7.40],
+ "8encoder4decoder":[8.42, 8.27, 8.10, 7.69],
+}
+
+is_data = {
+ "4encoder8decoder":[20.37, 29.41, 36.88, 41.32],
+ "6encoder4decoder":[20.04, 30.13, 38.17, 43.84],
+ "8encoder4decoder":[19.98, 29.54, 35.93, 42.025],
+}
+
+pr_data = {
+ "4encoder8decoder":[0.3935, 0.4687, 0.5047, 0.5271],
+ "6encoder4decoder":[0.3767, 0.4686, 0.50876, 0.5266],
+ "8encoder4decoder":[0.37, 0.45676, 0.49602, 0.5162],
+}
+
+recall_data = {
+ "4encoder8decoder":[0.5604, 0.5941, 0.6244, 0.6338],
+ "6encoder4decoder":[0.5295, 0.595, 0.6287, 0.6378],
+ "8encoder4decoder":[0.51, 0.596, 0.6242, 0.6333],
+}
+
+x = [100, 200, 300, 400]
+colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
+metric_data = {
+ "FID" : fid_data,
+ # "SFID" : sfid_data,
+ "InceptionScore" : is_data,
+ "Precision" : pr_data,
+ "Recall" : recall_data,
+}
+
+for key, data in metric_data.items():
+ for i, (name, v) in enumerate(data.items()):
+ name = name.replace("encoder", "En")
+ name = name.replace("decoder", "De")
+ plt.plot(x, v, label=name, color=colors[i], linewidth=3, marker="o")
+ plt.legend()
+ plt.xticks(x)
+ plt.ylabel(key, weight="bold")
+ plt.xlabel("Training iterations(K steps)", weight="bold")
+ plt.savefig("output/base_{}.pdf".format(key), bbox_inches='tight')
+ plt.close()
\ No newline at end of file
diff --git a/tools/figures/cfg.py b/tools/figures/cfg.py
new file mode 100644
index 0000000..4cd8855
--- /dev/null
+++ b/tools/figures/cfg.py
@@ -0,0 +1,32 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+cfg_data = {
+ "[0, 1]":{
+ "cfg":[1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
+ "FID":[9.23, 6.61, 5.08, 4.46, 4.32, 4.52, 4.86, 5.38, 5.97, 6.57, 7.13],
+ },
+ "[0.2, 1]":{
+ "cfg": [1.2, 1.4, 1.6, 1.8, 2.0],
+ "FID": [5.87, 4.44, 3.96, 4.01, 4.26]
+ },
+ "[0.3, 1]":{
+ "cfg": [1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4],
+ "FID": [4.31, 4.11, 3.98, 3.89, 3.87, 3.88, 3.91, 3.96, 4.03]
+ },
+ "[0.35, 1]":{
+ "cfg": [1.6, 1.8, 2.0, 2.1, 2.2, 2.3, 2.4, 2.6],
+ "FID": [4.68, 4.22, 3.98, 3.92, 3.90, 3.88, 3.88, 3.94]
+ }
+}
+
+colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
+
+for i, (name, data) in enumerate(cfg_data.items()):
+ plt.plot(data["cfg"], data["FID"], label="Interval: " +name, color=colors[i], linewidth=3.5, marker="o")
+
+plt.title("Classifer-free guidance with intervals", weight="bold")
+plt.ylabel("FID10K", weight="bold")
+plt.xlabel("CFG values", weight="bold")
+plt.legend()
+plt.savefig("./output/cfg.pdf", bbox_inches="tight")
\ No newline at end of file
diff --git a/tools/figures/feat_vis.py b/tools/figures/feat_vis.py
new file mode 100644
index 0000000..55f2045
--- /dev/null
+++ b/tools/figures/feat_vis.py
@@ -0,0 +1,42 @@
+import torch
+
+states = torch.load("./output/state.pt", map_location="cpu").to(dtype=torch.float32)
+states = states.permute(1, 2, 0, 3)
+print(states.shape)
+states = states.view(-1, 49, 1152)
+states = torch.nn.functional.normalize(states, dim=-1)
+sim = torch.bmm(states, states.transpose(1, 2))
+mean_sim = torch.mean(sim, dim=0, keepdim=False)
+
+mean_sim = mean_sim.numpy()
+import numpy as np
+import matplotlib
+import matplotlib.pyplot as plt
+timesteps = np.linspace(0, 1, 5)
+# plt.rc('axes.spines', **{'bottom':False, 'left':False, 'right':False, 'top':False})
+cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["#7400b8","#5e60ce","#4ea8de", "#64dfdf", "#80ffdb"])
+plt.imshow(mean_sim, cmap="inferno")
+plt.xticks([])
+plt.yticks([])
+# plt.show()
+plt.colorbar()
+plt.savefig("./output/mean_sim.png", pad_inches=0, bbox_inches="tight")
+# cos_sim = torch.nn.functional.cosine_similarity(states, states)
+
+
+# for i in range(49):
+# cos_sim = torch.nn.functional.cosine_similarity(states[i], states[i + 1])
+# cos_sim = cos_sim.min()
+# print(cos_sim)
+# state = torch.max(states, dim=-1)[1]
+# # state = torch.softmax(state, dim=-1)
+# state = state.view(-1, 16, 16)
+#
+# state = state.numpy()
+#
+# import numpy as np
+# import matplotlib.pyplot as plt
+# for i in range(0, 49):
+# print(i)
+# plt.imshow(state[i])
+# plt.savefig("./output2/{}.png".format(i))
\ No newline at end of file
diff --git a/tools/figures/large++.py b/tools/figures/large++.py
new file mode 100644
index 0000000..070d13f
--- /dev/null
+++ b/tools/figures/large++.py
@@ -0,0 +1,63 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+is_data = {
+ "10encoder14decoder":[80.48, 104.48, 113.01, 117.29],
+ "12encoder12decoder":[85.52, 109.91, 118.18, 121.77],
+ "16encoder8decoder":[92.72, 116.30, 124.32, 126.37],
+ "20encoder4decoder":[94.95, 117.84, 125.66, 128.30],
+}
+
+fid_data = {
+ "10encoder14decoder":[15.17, 10.40, 9.32, 8.66],
+ "12encoder12decoder":[13.79, 9.67, 8.64, 8.21],
+ "16encoder8decoder":[12.41, 8.99, 8.18, 8.03],
+ "20encoder4decoder":[12.04, 8.94, 8.03, 7.98],
+}
+
+sfid_data = {
+ "10encoder14decoder":[5.49, 5.00, 5.09, 5.14],
+ "12encoder12decoder":[5.37, 5.01, 5.07, 5.09],
+ "16encoder8decoder":[5.43, 5.11, 5.20, 5.31],
+ "20encoder4decoder":[5.36, 5.23, 5.21, 5.50],
+}
+
+pr_data = {
+ "10encoder14decoder":[0.6517, 0.67914, 0.68274, 0.68104],
+ "12encoder12decoder":[0.66144, 0.68146, 0.68564, 0.6823],
+ "16encoder8decoder":[0.6659, 0.68342, 0.68338, 0.67912],
+ "20encoder4decoder":[0.6716, 0.68088, 0.68798, 0.68098],
+}
+
+recall_data = {
+ "10encoder14decoder":[0.6427, 0.6512, 0.6572, 0.6679],
+ "12encoder12decoder":[0.6429, 0.6561, 0.6622, 0.6693],
+ "16encoder8decoder":[0.6457, 0.6547, 0.6665, 0.6773],
+ "20encoder4decoder":[0.6483, 0.6612, 0.6684, 0.6711],
+}
+
+x = [100, 200, 300, 400]
+# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
+colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
+
+metric_data = {
+ "FID50K" : fid_data,
+ # "SFID" : sfid_data,
+ "InceptionScore" : is_data,
+ "Precision" : pr_data,
+ "Recall" : recall_data,
+}
+
+for key, data in metric_data.items():
+ # plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
+ for i, (name, v) in enumerate(data.items()):
+ name = name.replace("encoder", "En")
+ name = name.replace("decoder", "De")
+ plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=8)
+ plt.legend(fontsize="14")
+ plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
+ plt.xticks([100, 150, 200, 250, 300, 350, 400])
+ plt.ylabel(key, weight="bold")
+ plt.xlabel("Training iterations(K steps)", weight="bold")
+ plt.savefig("output/large++_{}.pdf".format(key), bbox_inches='tight')
+ plt.close()
\ No newline at end of file
diff --git a/tools/figures/log_snr.py b/tools/figures/log_snr.py
new file mode 100644
index 0000000..f9d1b28
--- /dev/null
+++ b/tools/figures/log_snr.py
@@ -0,0 +1,18 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+t = np.linspace(0.001, 0.999, 100)
+def snr(t):
+ return np.log((1-t)/t)
+def pds(t):
+ return np.clip(((1-t)/t)**2, a_max=0.5, a_min=0.0)
+print(pds(t))
+plt.figure(figsize=(16, 4))
+plt.plot(t, snr(t), color="#ff70a6", linewidth=3, marker="o")
+# plt.plot(t, pds(t), color="#ff9770", linewidth=3, marker="o")
+plt.ylabel("log-SNR", weight="bold")
+plt.xlabel("Timesteps", weight="bold")
+plt.xticks([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])
+plt.gca().invert_xaxis()
+plt.show()
+# plt.savefig("output/logsnr.pdf", bbox_inches='tight')
\ No newline at end of file
diff --git a/tools/figures/output/base++_FID.pdf b/tools/figures/output/base++_FID.pdf
new file mode 100644
index 0000000..4aa3a98
Binary files /dev/null and b/tools/figures/output/base++_FID.pdf differ
diff --git a/tools/figures/output/base++_FID50K.pdf b/tools/figures/output/base++_FID50K.pdf
new file mode 100644
index 0000000..779138f
Binary files /dev/null and b/tools/figures/output/base++_FID50K.pdf differ
diff --git a/tools/figures/output/base++_InceptionScore.pdf b/tools/figures/output/base++_InceptionScore.pdf
new file mode 100644
index 0000000..03eceb5
Binary files /dev/null and b/tools/figures/output/base++_InceptionScore.pdf differ
diff --git a/tools/figures/output/base++_Precision.pdf b/tools/figures/output/base++_Precision.pdf
new file mode 100644
index 0000000..38f410a
Binary files /dev/null and b/tools/figures/output/base++_Precision.pdf differ
diff --git a/tools/figures/output/base++_Recall.pdf b/tools/figures/output/base++_Recall.pdf
new file mode 100644
index 0000000..e01ed34
Binary files /dev/null and b/tools/figures/output/base++_Recall.pdf differ
diff --git a/tools/figures/output/base_FID.pdf b/tools/figures/output/base_FID.pdf
new file mode 100644
index 0000000..5541a6b
Binary files /dev/null and b/tools/figures/output/base_FID.pdf differ
diff --git a/tools/figures/output/base_InceptionScore.pdf b/tools/figures/output/base_InceptionScore.pdf
new file mode 100644
index 0000000..ad8ae23
Binary files /dev/null and b/tools/figures/output/base_InceptionScore.pdf differ
diff --git a/tools/figures/output/base_Precision.pdf b/tools/figures/output/base_Precision.pdf
new file mode 100644
index 0000000..dc16562
Binary files /dev/null and b/tools/figures/output/base_Precision.pdf differ
diff --git a/tools/figures/output/base_Recall.pdf b/tools/figures/output/base_Recall.pdf
new file mode 100644
index 0000000..7164c9e
Binary files /dev/null and b/tools/figures/output/base_Recall.pdf differ
diff --git a/tools/figures/output/cfg.pdf b/tools/figures/output/cfg.pdf
new file mode 100644
index 0000000..90eec86
Binary files /dev/null and b/tools/figures/output/cfg.pdf differ
diff --git a/tools/figures/output/large++_FID.pdf b/tools/figures/output/large++_FID.pdf
new file mode 100644
index 0000000..ec04806
Binary files /dev/null and b/tools/figures/output/large++_FID.pdf differ
diff --git a/tools/figures/output/large++_FID50K.pdf b/tools/figures/output/large++_FID50K.pdf
new file mode 100644
index 0000000..b620e46
Binary files /dev/null and b/tools/figures/output/large++_FID50K.pdf differ
diff --git a/tools/figures/output/large++_InceptionScore.pdf b/tools/figures/output/large++_InceptionScore.pdf
new file mode 100644
index 0000000..0c44d4b
Binary files /dev/null and b/tools/figures/output/large++_InceptionScore.pdf differ
diff --git a/tools/figures/output/large++_Precision.pdf b/tools/figures/output/large++_Precision.pdf
new file mode 100644
index 0000000..f5698ad
Binary files /dev/null and b/tools/figures/output/large++_Precision.pdf differ
diff --git a/tools/figures/output/large++_Recall.pdf b/tools/figures/output/large++_Recall.pdf
new file mode 100644
index 0000000..2671f07
Binary files /dev/null and b/tools/figures/output/large++_Recall.pdf differ
diff --git a/tools/figures/output/logsnr.pdf b/tools/figures/output/logsnr.pdf
new file mode 100644
index 0000000..3d46a63
Binary files /dev/null and b/tools/figures/output/logsnr.pdf differ
diff --git a/tools/figures/output/mean_sim.png b/tools/figures/output/mean_sim.png
new file mode 100644
index 0000000..f4b7967
Binary files /dev/null and b/tools/figures/output/mean_sim.png differ
diff --git a/tools/figures/output/sota.pdf b/tools/figures/output/sota.pdf
new file mode 100644
index 0000000..80b7263
Binary files /dev/null and b/tools/figures/output/sota.pdf differ
diff --git a/tools/figures/output/timeshift.pdf b/tools/figures/output/timeshift.pdf
new file mode 100644
index 0000000..b2cb2f9
Binary files /dev/null and b/tools/figures/output/timeshift.pdf differ
diff --git a/tools/figures/output/timeshift_fid.pdf b/tools/figures/output/timeshift_fid.pdf
new file mode 100644
index 0000000..1b3b7c7
Binary files /dev/null and b/tools/figures/output/timeshift_fid.pdf differ
diff --git a/tools/figures/sota.py b/tools/figures/sota.py
new file mode 100644
index 0000000..58b0ec9
--- /dev/null
+++ b/tools/figures/sota.py
@@ -0,0 +1,95 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+data = {
+ "SiT-XL/2" : {
+ "size": 675,
+ "epochs": 1400,
+ "FID": 2.06,
+ "color": "#ff99c8"
+ },
+ "DiT-XL/2" : {
+ "size": 675,
+ "epochs": 1400,
+ "FID": 2.27,
+ "color": "#fcf6bd"
+ },
+ "REPA-XL/2" : {
+ "size": 675,
+ "epochs": 800,
+ "FID": 1.42,
+ "color": "#d0f4de"
+ },
+ # "MAR-H" : {
+ # "size": 973,
+ # "epochs": 800,
+ # "FID": 1.55,
+ # },
+ "MDTv2" : {
+ "size": 675,
+ "epochs": 920,
+ "FID": 1.58,
+ "color": "#e4c1f9"
+ },
+ # "VAVAE+LightningDiT" : {
+ # "size": 675,
+ # "epochs": [64, 800],
+ # "FID": [2.11, 1.35],
+ # },
+ "DDT-XL/2": {
+ "size": 675,
+ "epochs": [80, 256],
+ "FID": [1.52, 1.31],
+ "color": "#38a3a5"
+ },
+ "DDT-L/2": {
+ "size": 400,
+ "epochs": 80,
+ "FID": 1.64,
+ "color": "#5bc0be"
+ },
+}
+
+fig = plt.figure()
+ax = fig.add_subplot(1, 1, 1)
+for k, spec in data.items():
+ plt.scatter(
+ # spec["size"],
+ spec["epochs"],
+ spec["FID"],
+ label=k,
+ marker="o",
+ s=spec["size"],
+ color=spec["color"],
+ )
+ x = spec["epochs"]
+ y = spec["FID"]
+ if isinstance(spec["FID"], list):
+ x = spec["epochs"][-1]
+ y = spec["FID"][-1]
+ plt.plot(
+ spec["epochs"],
+ spec["FID"],
+ color=spec["color"],
+ linestyle="dotted",
+ linewidth=4
+ )
+ # plt.annotate("",
+ # xytext=(spec["epochs"][0], spec["FID"][0]),
+ # xy=(spec["epochs"][1], spec["FID"][1]), arrowprops=dict(arrowstyle="--"), weight="bold")
+ plt.text(x+80, y-0.05, k, fontsize=13)
+
+plt.text(200, 1.45, "4x Training Acc", fontsize=12, color="#38a3a5", weight="bold")
+# plt.arrow(200, 1.42, 520, 0, linewidth=2, fc='black', ec='black', hatch="x", head_width=0.05, head_length=0.05)
+
+plt.annotate("",
+ xy=(700, 1.42), xytext=(200, 1.42),
+ arrowprops=dict(arrowstyle='<->', color='black', linewidth=2),
+ )
+ax.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
+plt.gca().set_xlim(0, 1800)
+plt.gca().set_ylim(1.15, 2.5)
+plt.xticks([80, 256, 800, 1000, 1200, 1400, 1600, ])
+plt.xlabel("Training Epochs", weight="bold")
+plt.ylabel("FID50K on ImageNet256x256", weight="bold")
+plt.savefig("output/sota.pdf", bbox_inches="tight")
\ No newline at end of file
diff --git a/tools/figures/timeshift.py b/tools/figures/timeshift.py
new file mode 100644
index 0000000..766fac7
--- /dev/null
+++ b/tools/figures/timeshift.py
@@ -0,0 +1,26 @@
+import scipy
+import numpy as np
+import matplotlib.pyplot as plt
+
+def timeshift(t, s=1.0):
+ return t/(t+(1-t)*s)
+
+# colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
+colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
+# plt.rc('axes.spines', **{'bottom':True, 'left':True, 'right':False, 'top':False})
+t = np.linspace(0, 1, 100)
+shifts = [1.0, 1.5, 2, 3]
+for i , shift in enumerate(shifts):
+ plt.plot(t, timeshift(t, shift), color=colors[i], label=f"shift {shift}", linewidth=4)
+
+# plt.annotate("", xytext=(0, 0), xy=(0.0, 1.05), arrowprops=dict(arrowstyle="->"), weight="bold")
+# plt.annotate("", xytext=(0, 0), xy=(1.05, 0.0), arrowprops=dict(arrowstyle="->"), weight="bold")
+# plt.title("Respaced timesteps with various shift value", weight="bold")
+# plt.gca().set_xlim(0, 1.0)
+# plt.gca().set_ylim(0, 1.0)
+plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
+
+plt.ylabel("Respaced Timesteps", weight="bold")
+plt.xlabel("Uniform Timesteps", weight="bold")
+plt.legend(loc="upper left", fontsize="12")
+plt.savefig("output/timeshift.pdf", bbox_inches="tight", pad_inches=0)
\ No newline at end of file
diff --git a/tools/figures/timeshift_fid.py b/tools/figures/timeshift_fid.py
new file mode 100644
index 0000000..07a591c
--- /dev/null
+++ b/tools/figures/timeshift_fid.py
@@ -0,0 +1,29 @@
+import scipy
+import numpy as np
+import matplotlib.pyplot as plt
+
+def timeshift(t, s=1.0):
+ return t/(t+(1-t)*s)
+
+data = {
+ "shift 1.0": [8.99, 6.36, 5.03, 4.21, 3.6, 3.23, 2.80],
+ "shift 1.5": [6.08, 4.26, 3.43, 2.99, 2.73, 2.54, 2.33],
+ "shift 2.0": [5.57, 3.81, 3.11, 2.75, 2.54, 2.43, 2.26],
+ "shift 3.0": [7.26, 4.48, 3.43, 2.97, 2.72, 2.57, 2.38],
+}
+# plt.rc('axes.spines', **{'bottom':True, 'left':True, 'right':False, 'top':False})
+
+# colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
+
+colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
+steps = [5, 6, 7, 8, 9, 10, 12]
+for i ,(k, v)in enumerate(data.items()):
+ plt.plot(steps, v, color=colors[i], label=k, linewidth=4, marker="o")
+
+# plt.title("FID50K of different steps of different timeshift", weight="bold")
+plt.ylabel("FID50K", weight="bold")
+plt.xlabel("Num of inference steps", weight="bold")
+plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
+# plt.legend()
+# plt.legend()
+plt.savefig("output/timeshift_fid.pdf", bbox_inches="tight", pad_inches=0)
\ No newline at end of file
diff --git a/tools/fm_images.py b/tools/fm_images.py
new file mode 100644
index 0000000..6657523
--- /dev/null
+++ b/tools/fm_images.py
@@ -0,0 +1,21 @@
+IMG_EXTENSIONS = (
+ "*.png",
+ "*.JPEG",
+ "*.jpeg",
+ "*.jpg"
+)
+
+PATH = "/mnt/bn/wangshuai6/neural_sampling_workdirs/expbaseline_adam2_timeshift1.5"
+
+import os
+import pathlib
+PATH = pathlib.Path(PATH)
+images = []
+
+# find images
+for ext in IMG_EXTENSIONS:
+ images.extend(PATH.rglob(ext))
+
+for image in images:
+ os.system(f"rm -f {image}")
+
diff --git a/tools/mm.py b/tools/mm.py
new file mode 100644
index 0000000..25bba93
--- /dev/null
+++ b/tools/mm.py
@@ -0,0 +1,23 @@
+import torch
+import time
+import torch.nn as nn
+import accelerate
+
+
+if __name__ == "__main__":
+ model = nn.Linear(512, 512)
+ for p in model.parameters():
+ p.requires_grad = False
+ accelerator = accelerate.Accelerator()
+ model = accelerator.prepare_model(model)
+ model.to(accelerator.device)
+ data = torch.randn(1024, 512).to(accelerator.device)
+ while True:
+ time.sleep(0.01)
+ accelerator.wait_for_everyone()
+ if torch.cuda.utilization() < 1.5:
+ with torch.no_grad():
+ model(data)
+ else:
+ time.sleep(1)
+ # print(f"rank:{accelerator.process_index}->usage:{torch.cuda.utilization()}")
\ No newline at end of file
diff --git a/tools/sigmoid.py b/tools/sigmoid.py
new file mode 100644
index 0000000..90cec55
--- /dev/null
+++ b/tools/sigmoid.py
@@ -0,0 +1,20 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+def lw(x, b=0):
+ x = np.clip(x, a_min=0.001, a_max=0.999)
+ snr = x/(1-x)
+ logsnr = np.log(snr)
+ # print(logsnr)
+ # return logsnr
+ weight = 1 / (1 + np.exp(-logsnr - b))#*(1-x)**2
+ return weight #/weight.max()
+
+x = np.arange(0.2, 0.8, 0.001)
+print(1/(x*(1-x)))
+for b in [0, 1, 2, 3]:
+ y = lw(x, b)
+ plt.plot(x, y, label=f"b={b}")
+plt.legend()
+plt.show()
\ No newline at end of file
diff --git a/tools/vae2dino.py b/tools/vae2dino.py
new file mode 100644
index 0000000..136f5ff
--- /dev/null
+++ b/tools/vae2dino.py
@@ -0,0 +1,173 @@
+import torch
+import torch.nn as nn
+import timm
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from torchvision.transforms import Normalize
+import copy
+
+class DINOv2(nn.Module):
+ def __init__(self, weight_path:str):
+ super(DINOv2, self).__init__()
+ self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.precomputed_pos_embed = dict()
+
+ def fetch_pos(self, h, w):
+ key = (h, w)
+ if key in self.precomputed_pos_embed:
+ return self.precomputed_pos_embed[key]
+ value = timm.layers.pos_embed.resample_abs_pos_embed(
+ self.pos_embed.data, [h, w],
+ )
+ self.precomputed_pos_embed[key] = value
+ return value
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
+ self.encoder.pos_embed.data = pos_embed_data
+ feature = self.encoder.forward_features(x)['x_norm_patchtokens']
+ feature = feature.transpose(1, 2)
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
+ return feature
+
+class MAE(nn.Module):
+ def __init__(self, model_id, weight_path:str):
+ super(MAE, self).__init__()
+ if os.path.isdir(weight_path):
+ weight_path = os.path.join(weight_path, "pytorch_model.bin")
+ self.encoder = timm.create_model(
+ model_id,
+ checkpoint_path=weight_path,
+ num_classes=0,
+ )
+ self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
+ self.encoder.head = torch.nn.Identity()
+ self.patch_size = self.encoder.patch_embed.patch_size
+ self.shifts = nn.Parameter(torch.tensor([0.0
+ ]), requires_grad=False)
+ self.scales = nn.Parameter(torch.tensor([1.0
+ ]), requires_grad=False)
+
+ def forward(self, x):
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+ x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
+ b, c, h, w = x.shape
+ patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
+ feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
+ feature = feature.transpose(1, 2)
+ feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
+ feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
+ return feature
+
+
+from diffusers import AutoencoderKL
+
+from torchvision.datasets import ImageFolder, ImageNet
+
+import cv2
+import os
+import numpy as np
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
+
+from PIL import Image
+import torch
+import torchvision.transforms as tvtf
+
+IMG_EXTENSIONS = (
+ "*.png",
+ "*.JPEG",
+ "*.jpeg",
+ "*.jpg"
+)
+
+class NewImageFolder(ImageFolder):
+ def __getitem__(self, item):
+ path, target = self.samples[item]
+ sample = self.loader(path)
+ if self.transform is not None:
+ sample = self.transform(sample)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ return sample, target, path
+
+
+import time
+class CenterCrop:
+ def __init__(self, size):
+ self.size = size
+ def __call__(self, image):
+ def center_crop_arr(pil_image, image_size):
+ """
+ Center cropping implementation from ADM.
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+ """
+ while min(*pil_image.size) >= 2 * image_size:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = image_size / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image)
+ crop_y = (arr.shape[0] - image_size) // 2
+ crop_x = (arr.shape[1] - image_size) // 2
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
+
+ return center_crop_arr(image, self.size)
+
+def save(obj, path):
+ dirname = os.path.dirname(path)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+ torch.save(obj, f'{path}')
+
+if __name__ == "__main__":
+ torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
+ os.environ["TORCH_HOME"] = torch_hub_dir
+ torch.hub.set_dir(torch_hub_dir)
+
+ for split in ['train']:
+ train = split == 'train'
+ transforms = tvtf.Compose([
+ CenterCrop(256),
+ tvtf.ToTensor(),
+ tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+ ])
+ dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
+ B = 4
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0)
+ # dino = DINOv2("dinov2_vitb14")
+ dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae")
+ # dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai")
+ from accelerate import Accelerator
+
+ accelerator = Accelerator()
+ dino, dataloader = accelerator.prepare(dino, dataloader)
+ rank = accelerator.process_index
+
+ acc_mean = torch.zeros((768, ), device=accelerator.device)
+ acc_num = 0
+ with torch.no_grad():
+ for i, (images, labels, path_list) in enumerate(dataloader):
+ acc_num += len(images)
+ feature = dino(images)
+ stds = torch.std(feature, dim=[0, 2, 3]).tolist()
+ for std in stds:
+ print("{:.2f},".format(std), end='')
+ print()
+ means = torch.mean(feature, dim=[0, 2, 3]).tolist()
+ for mean in means:
+ print("{:.2f},".format(mean), end='')
+ break
+
+ accelerator.wait_for_everyone()
\ No newline at end of file
diff --git a/tools/vis_timeshift.py b/tools/vis_timeshift.py
new file mode 100644
index 0000000..496b922
--- /dev/null
+++ b/tools/vis_timeshift.py
@@ -0,0 +1,23 @@
+import scipy
+import numpy as np
+import matplotlib.pyplot as plt
+
+def timeshift(t, s=1.0):
+ return t/(t+(1-t)*s)
+
+def gaussian(t):
+ gs = 1+scipy.special.erf((t-t.mean())/t.std())
+
+def rs2(t, s=2.0):
+ factor1 = 1.0 #s/(s+(1-s)*t)**2
+ factor2 = np.log(t.clip(0.001, 0.999)/(1-t).clip(0.001, 0.999))
+ return factor1*factor2
+
+
+t = np.linspace(0, 1, 100)
+# plt.plot(t, timeshift(t, 1.0))
+respaced_t = timeshift(t, s=5)
+delats = (respaced_t[1:] - respaced_t[:-1])
+# plt.plot(t, timeshift(t, 1.5))
+plt.plot(rs2(t))
+plt.show()
\ No newline at end of file