diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/DDT.iml b/.idea/DDT.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/DDT.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..b19c6c8 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/configs/repa_flatten_condit22_fixt_xl.yaml b/configs/repa_flatten_condit22_fixt_xl.yaml new file mode 100644 index 0000000..6f3a6ce --- /dev/null +++ b/configs/repa_flatten_condit22_fixt_xl.yaml @@ -0,0 +1,108 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp repa_flatten_condit22_dit6_fixt_xl +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: bf16-mixed + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_flow + name: *exp + num_sanity_val_steps: 0 + max_steps: 4000000 + val_check_interval: 4000000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 10000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.vae.LatentVAE + init_args: + precompute: true + weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/ + denoiser: + class_path: src.models.denoiser.flatten_condit_dit_fixt.FlattenConDiT + init_args: + in_channels: 4 + patch_size: 2 + num_groups: 16 + hidden_size: &hidden_dim 1152 + num_blocks: 28 + num_cond_blocks: 22 + num_classes: 1000 + conditioner: + class_path: src.models.conditioner.LabelConditioner + init_args: + null_class: 1000 + diffusion_trainer: + class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + encoder_weight_path: dinov2_vitb14 + align_layer: 8 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler + init_args: + num_steps: 250 + guidance: 2.0 + timeshift: 1.5 + state_refresh_rate: 1 + guidance_interval_min: 0.3 + guidance_interval_max: 1.0 + scheduler: *scheduler + w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler + guidance_fn: 
src.diffusion.base.guidance.simple_guidance_fn + last_step: 0.04 + step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-4 + betas: + - 0.9 + - 0.95 + weight_decay: 0.0 +data: + train_dataset: imagenet256 + train_root: /mnt/bn/wangshuai6/data/ImageNet/train + train_image_size: 256 + train_batch_size: 16 + eval_max_num_instances: 50000 + pred_batch_size: 64 + pred_num_workers: 4 + pred_seeds: null + pred_selected_classes: null + num_classes: 1000 + latent_shape: + - 4 + - 32 + - 32 \ No newline at end of file diff --git a/configs/repa_flatten_condit22_fixt_xl512.yaml b/configs/repa_flatten_condit22_fixt_xl512.yaml new file mode 100644 index 0000000..ce59338 --- /dev/null +++ b/configs/repa_flatten_condit22_fixt_xl512.yaml @@ -0,0 +1,108 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp res512_fromscratch_repa_flatten_condit22_dit6_fixt_xl +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: bf16-mixed + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_flow + name: *exp + num_sanity_val_steps: 0 + max_steps: 4000000 + val_check_interval: 4000000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 10000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.vae.LatentVAE + init_args: + precompute: true + weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/ + denoiser: + class_path: src.models.denoiser.flatten_condit_dit_fixt.FlattenConDiT + init_args: + in_channels: 4 + patch_size: 2 + num_groups: 16 + hidden_size: &hidden_dim 1152 + num_blocks: 28 + num_cond_blocks: 22 + num_classes: 1000 + conditioner: + class_path: src.models.conditioner.LabelConditioner + init_args: + null_class: 1000 + diffusion_trainer: + class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + encoder_weight_path: dinov2_vitb14 + align_layer: 8 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler + init_args: + num_steps: 250 + guidance: 3.0 + state_refresh_rate: 1 + guidance_interval_min: 0.3 + guidance_interval_max: 1.0 + timeshift: 1.0 + last_step: 0.04 + scheduler: *scheduler + w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-4 + betas: + - 0.9 + - 0.95 + weight_decay: 0.0 +data: + train_dataset: imagenet512 + train_root: /mnt/bn/wangshuai6/data/ImageNet/train + train_image_size: 512 + 
train_batch_size: 16 + eval_max_num_instances: 50000 + pred_batch_size: 32 + pred_num_workers: 4 + pred_seeds: null + pred_selected_classes: null + num_classes: 1000 + latent_shape: + - 4 + - 64 + - 64 \ No newline at end of file diff --git a/configs/repa_flatten_dit_fixt_large.yaml b/configs/repa_flatten_dit_fixt_large.yaml new file mode 100644 index 0000000..7dc9611 --- /dev/null +++ b/configs/repa_flatten_dit_fixt_large.yaml @@ -0,0 +1,99 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp repa_flatten_dit_fixt_large +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: bf16-mixed + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_flow + name: *exp + num_sanity_val_steps: 0 + max_steps: 400000 + val_check_interval: 100000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 10000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.vae.LatentVAE + init_args: + precompute: true + weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/ + denoiser: + class_path: src.models.denoiser.flatten_dit_fixt.FlattenDiT + init_args: + in_channels: 4 + patch_size: 2 + num_groups: 16 + hidden_size: &hidden_dim 1024 + num_blocks: 24 + num_classes: 1000 + conditioner: + class_path: src.models.conditioner.LabelConditioner + init_args: + null_class: 1000 + diffusion_trainer: + class_path: src.diffusion.flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + encoder_weight_path: dinov2_vitb14 + align_layer: 8 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.flow_matching.sampling.EulerSampler + init_args: + num_steps: 250 + guidance: 1.00 + scheduler: *scheduler + w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-4 + weight_decay: 0.0 +data: + train_dataset: imagenet256 + train_root: /mnt/bn/wangshuai6/data/ImageNet/train + train_image_size: 256 + train_batch_size: 32 + eval_max_num_instances: 50000 + pred_batch_size: 64 + pred_num_workers: 4 + pred_seeds: null + pred_selected_classes: null + num_classes: 1000 + latent_shape: + - 4 + - 32 + - 32 \ No newline at end of file diff --git a/configs/repa_flatten_dit_fixt_xl.yaml b/configs/repa_flatten_dit_fixt_xl.yaml new file mode 100644 index 0000000..2bc8606 --- /dev/null +++ b/configs/repa_flatten_dit_fixt_xl.yaml @@ -0,0 +1,99 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp repa_flatten_dit_fixt_xl +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs + accelerator: auto + strategy: 
auto + devices: auto + num_nodes: 1 + precision: bf16-mixed + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_flow + name: *exp + num_sanity_val_steps: 0 + max_steps: 400000 + val_check_interval: 100000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 10000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.vae.LatentVAE + init_args: + precompute: true + weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/ + denoiser: + class_path: src.models.denoiser.flatten_dit_fixt.FlattenDiT + init_args: + in_channels: 4 + patch_size: 2 + num_groups: 16 + hidden_size: &hidden_dim 1152 + num_blocks: 28 + num_classes: 1000 + conditioner: + class_path: src.models.conditioner.LabelConditioner + init_args: + null_class: 1000 + diffusion_trainer: + class_path: src.diffusion.flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + encoder_weight_path: dinov2_vitb14 + align_layer: 8 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.flow_matching.sampling.EulerSampler + init_args: + num_steps: 250 + guidance: 1.00 + scheduler: *scheduler + w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-4 + weight_decay: 0.0 +data: + train_dataset: imagenet256 + train_root: /mnt/bn/wangshuai6/data/ImageNet/train + train_image_size: 256 + train_batch_size: 32 + eval_max_num_instances: 50000 + pred_batch_size: 64 + pred_num_workers: 4 + pred_seeds: null + pred_selected_classes: null + num_classes: 1000 + latent_shape: + - 4 + - 32 + - 32 \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..85bf7cb --- /dev/null +++ b/main.py @@ -0,0 +1,86 @@ +import time +from typing import Any, Union + +import pylab as pl + +from src.utils.patch_bugs import * + +import os +import torch +from lightning import Trainer, LightningModule +from src.lightning_data import DataModule +from src.lightning_model import LightningModel +from lightning.pytorch.cli import LightningCLI, LightningArgumentParser, SaveConfigCallback + +import logging +logger = logging.getLogger("lightning.pytorch") +# log_path = os.path.join( f"log.txt") +# logger.addHandler(logging.FileHandler(log_path)) + +class ReWriteRootSaveConfigCallback(SaveConfigCallback): + def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: + stamp = time.strftime('%y%m%d%H%M') + file_path = os.path.join(trainer.default_root_dir, f"config-{stage}-{stamp}.yaml") + self.parser.save( + self.config, file_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile + ) + + +class ReWriteRootDirCli(LightningCLI): + def before_instantiate_classes(self) -> None: + super().before_instantiate_classes() + config_trainer = self._get(self.config, 
"trainer", default={}) + + # predict path & logger check + if self.subcommand == "predict": + config_trainer.logger = None + + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + class TagsClass: + def __init__(self, exp:str): + ... + parser.add_class_arguments(TagsClass, nested_key="tags") + + def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + super().add_default_arguments_to_parser(parser) + parser.add_argument("--torch_hub_dir", type=str, default=None, help=("torch hub dir"),) + parser.add_argument("--huggingface_cache_dir", type=str, default=None, help=("huggingface hub dir"),) + + def instantiate_trainer(self, **kwargs: Any) -> Trainer: + config_trainer = self._get(self.config_init, "trainer", default={}) + default_root_dir = config_trainer.get("default_root_dir", None) + + if default_root_dir is None: + default_root_dir = os.path.join(os.getcwd(), "workdirs") + + dirname = "" + for v, k in self._get(self.config, "tags", default={}).items(): + dirname += f"{v}_{k}" + default_root_dir = os.path.join(default_root_dir, dirname) + is_resume = self._get(self.config_init, "ckpt_path", default=None) + if os.path.exists(default_root_dir) and "debug" not in default_root_dir: + if os.listdir(default_root_dir) and self.subcommand != "predict" and not is_resume: + raise FileExistsError(f"{default_root_dir} already exists") + + config_trainer.default_root_dir = default_root_dir + trainer = super().instantiate_trainer(**kwargs) + if trainer.is_global_zero: + os.makedirs(default_root_dir, exist_ok=True) + return trainer + + def instantiate_classes(self) -> None: + torch_hub_dir = self._get(self.config, "torch_hub_dir") + huggingface_cache_dir = self._get(self.config, "huggingface_cache_dir") + if huggingface_cache_dir is not None: + os.environ["HUGGINGFACE_HUB_CACHE"] = huggingface_cache_dir + if torch_hub_dir is not None: + os.environ["TORCH_HOME"] = torch_hub_dir + torch.hub.set_dir(torch_hub_dir) + super().instantiate_classes() + +if __name__ == "__main__": + + cli = ReWriteRootDirCli(LightningModel, DataModule, + auto_configure_optimizers=False, + save_config_callback=ReWriteRootSaveConfigCallback, + save_config_kwargs={"overwrite": True}) \ No newline at end of file diff --git a/main.sh b/main.sh new file mode 100644 index 0000000..39f803e --- /dev/null +++ b/main.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +export NCCL_HOSTID=${MY_POD_NAME} +export MASTER_ADDR=${ARNOLD_WORKER_0_HOST} +export MASTER_PORT=${ARNOLD_WORKER_0_PORT} +export NODE_RANK=${ARNOLD_ID} +export NUM_NODES=${ARNOLD_WORKER_NUM} + +python3 main.py fit -c $1 --trainer.num_nodes $NUM_NODES +# for pid in $(ps -ef | grep "yaml" | grep -v "grep" | awk '{print $2}'); do kill -9 $pid; done \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9061a84 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +lightning==2.5.0.post0 +omegaconf==2.3.0 +jsonargparse[signatures]>=4.27.7 \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/callbacks/__init__.py b/src/callbacks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/callbacks/grad.py b/src/callbacks/grad.py new file mode 100644 index 0000000..f9155b6 --- /dev/null +++ b/src/callbacks/grad.py @@ -0,0 +1,22 @@ +import torch +import lightning.pytorch as pl +from lightning.pytorch.utilities import grad_norm +from torch.optim import Optimizer + +class 
GradientMonitor(pl.Callback): + """Logs the gradient norm""" + + def __init__(self, norm_type: int = 2): + norm_type = float(norm_type) + if norm_type <= 0: + raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}") + self.norm_type = norm_type + + def on_before_optimizer_step( + self, trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + optimizer: Optimizer + ) -> None: + norms = grad_norm(pl_module, norm_type=self.norm_type) + max_grad = torch.tensor([v for k, v in norms.items() if k != f"grad_{self.norm_type}_norm_total"]).max() + pl_module.log_dict({'train/grad/max': max_grad, 'train/grad/total': norms[f"grad_{self.norm_type}_norm_total"]}) \ No newline at end of file diff --git a/src/callbacks/model_checkpoint.py b/src/callbacks/model_checkpoint.py new file mode 100644 index 0000000..019454e --- /dev/null +++ b/src/callbacks/model_checkpoint.py @@ -0,0 +1,25 @@ +import os.path +from typing import Optional, Dict, Any + +import lightning.pytorch as pl +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint +from soupsieve.util import lower + + +class CheckpointHook(ModelCheckpoint): + """Save checkpoint with only the incremental part of the model""" + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + self.dirpath = trainer.default_root_dir + self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt") + pl_module.strict_loading = False + + def on_save_checkpoint( + self, trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + checkpoint: Dict[str, Any] + ) -> None: + del checkpoint["callbacks"] + + # def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: + # if not "debug" in self.exception_ckpt_path: + # trainer.save_checkpoint(self.exception_ckpt_path) \ No newline at end of file diff --git a/src/callbacks/save_images.py b/src/callbacks/save_images.py new file mode 100644 index 0000000..c6cd32b --- /dev/null +++ b/src/callbacks/save_images.py @@ -0,0 +1,105 @@ +import lightning.pytorch as pl +from lightning.pytorch import Callback + + +import os.path +import numpy +from PIL import Image +from typing import Sequence, Any, Dict +from concurrent.futures import ThreadPoolExecutor + +from lightning.pytorch.utilities.types import STEP_OUTPUT +from lightning_utilities.core.rank_zero import rank_zero_info + +def process_fn(image, path): + Image.fromarray(image).save(path) + +class SaveImagesHook(Callback): + def __init__(self, save_dir="val", max_save_num=0, compressed=True): + self.save_dir = save_dir + self.max_save_num = max_save_num + self.compressed = compressed + + def save_start(self, target_dir): + self.target_dir = target_dir + self.executor_pool = ThreadPoolExecutor(max_workers=8) + if not os.path.exists(self.target_dir): + os.makedirs(self.target_dir, exist_ok=True) + else: + if os.listdir(target_dir) and "debug" not in str(target_dir): + raise FileExistsError(f'{self.target_dir} already exists and not empty!') + self.samples = [] + self._have_saved_num = 0 + rank_zero_info(f"Save images to {self.target_dir}") + + def save_image(self, images, filenames): + images = images.permute(0, 2, 3, 1).cpu().numpy() + for sample, filename in zip(images, filenames): + if isinstance(filename, Sequence): + filename = filename[0] + path = f'{self.target_dir}/{filename}' + if self._have_saved_num >= self.max_save_num: + break + self.executor_pool.submit(process_fn, sample, path) + self._have_saved_num += 1 + + 
def process_batch( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + samples: STEP_OUTPUT, + batch: Any, + ) -> None: + b, c, h, w = samples.shape + xT, y, metadata = batch + all_samples = pl_module.all_gather(samples).view(-1, c, h, w) + self.save_image(samples, metadata) + if trainer.is_global_zero: + all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy() + self.samples.append(all_samples) + + def save_end(self): + if self.compressed and len(self.samples) > 0: + samples = numpy.concatenate(self.samples) + numpy.savez(f'{self.target_dir}/output.npz', arr_0=samples) + self.executor_pool.shutdown(wait=True) + self.samples = [] + self.target_dir = None + self._have_saved_num = 0 + self.executor_pool = None + + def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}") + self.save_start(target_dir) + + def on_validation_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + return self.process_batch(trainer, pl_module, outputs, batch) + + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.save_end() + + def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict") + self.save_start(target_dir) + + def on_predict_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + samples: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + return self.process_batch(trainer, pl_module, samples, batch) + + def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.save_end() \ No newline at end of file diff --git a/src/callbacks/simple_ema.py b/src/callbacks/simple_ema.py new file mode 100644 index 0000000..28bf476 --- /dev/null +++ b/src/callbacks/simple_ema.py @@ -0,0 +1,79 @@ +from typing import Any, Dict + +import torch +import torch.nn as nn +import threading +import lightning.pytorch as pl +from lightning.pytorch import Callback +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from src.utils.copy import swap_tensors + +class SimpleEMA(Callback): + def __init__(self, net:nn.Module, ema_net:nn.Module, + decay: float = 0.9999, + every_n_steps: int = 1, + eval_original_model:bool = False + ): + super().__init__() + self.decay = decay + self.every_n_steps = every_n_steps + self.eval_original_model = eval_original_model + self._stream = torch.cuda.Stream() + + self.net_params = list(net.parameters()) + self.ema_params = list(ema_net.parameters()) + + def swap_model(self): + for ema_p, p, in zip(self.ema_params, self.net_params): + swap_tensors(ema_p, p) + + def ema_step(self): + @torch.no_grad() + def ema_update(ema_model_tuple, current_model_tuple, decay): + torch._foreach_mul_(ema_model_tuple, decay) + torch._foreach_add_( + ema_model_tuple, current_model_tuple, alpha=(1.0 - decay), + ) + + if self._stream is not None: + self._stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._stream): + ema_update(self.ema_params, self.net_params, self.decay) + + + def on_train_batch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + ) -> None: + if 
trainer.global_step % self.every_n_steps == 0: + self.ema_step() + + def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self.eval_original_model: + self.swap_model() + + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self.eval_original_model: + self.swap_model() + + def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self.eval_original_model: + self.swap_model() + + def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self.eval_original_model: + self.swap_model() + + + def state_dict(self) -> Dict[str, Any]: + return { + "decay": self.decay, + "every_n_steps": self.every_n_steps, + "eval_original_model": self.eval_original_model, + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.decay = state_dict["decay"] + self.every_n_steps = state_dict["every_n_steps"] + self.eval_original_model = state_dict["eval_original_model"] + diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1 @@ + diff --git a/src/data/dataset/__init__.py b/src/data/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/data/dataset/celeba.py b/src/data/dataset/celeba.py new file mode 100644 index 0000000..30f5d3f --- /dev/null +++ b/src/data/dataset/celeba.py @@ -0,0 +1,11 @@ +from typing import Callable +from torchvision.datasets import CelebA + + +class LocalDataset(CelebA): + def __init__(self, root:str, ): + super(LocalDataset, self).__init__(root, "train") + + def __getitem__(self, idx): + data = super().__getitem__(idx) + return data \ No newline at end of file diff --git a/src/data/dataset/imagenet.py b/src/data/dataset/imagenet.py new file mode 100644 index 0000000..59d0547 --- /dev/null +++ b/src/data/dataset/imagenet.py @@ -0,0 +1,82 @@ +import torch +from PIL import Image +from torchvision.datasets import ImageFolder +from torchvision.transforms.functional import to_tensor +from torchvision.transforms import Normalize + +from src.data.dataset.metric_dataset import CenterCrop + +class LocalCachedDataset(ImageFolder): + def __init__(self, root, resolution=256): + super().__init__(root) + self.transform = CenterCrop(resolution) + self.cache_root = None + + def load_latent(self, latent_path): + pk_data = torch.load(latent_path) + mean = pk_data['mean'].to(torch.float32) + logvar = pk_data['logvar'].to(torch.float32) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + latent = mean + torch.randn_like(mean) * std + return latent + + def __getitem__(self, idx: int): + image_path, target = self.samples[idx] + latent_path = image_path.replace(self.root, self.cache_root) + ".pt" + + raw_image = Image.open(image_path).convert('RGB') + raw_image = self.transform(raw_image) + raw_image = to_tensor(raw_image) + if self.cache_root is not None: + latent = self.load_latent(latent_path) + else: + latent = raw_image + return raw_image, latent, target + +class ImageNet256(LocalCachedDataset): + def __init__(self, root, ): + super().__init__(root, 256) + self.cache_root = root + "_256_latent" + +class ImageNet512(LocalCachedDataset): + def __init__(self, root, ): + super().__init__(root, 512) + self.cache_root = root + "_512_latent" + +class PixImageNet(ImageFolder): + def __init__(self, root, resolution=256): + super().__init__(root) + 
self.transform = CenterCrop(resolution) + self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + def __getitem__(self, idx: int): + image_path, target = self.samples[idx] + raw_image = Image.open(image_path).convert('RGB') + raw_image = self.transform(raw_image) + raw_image = to_tensor(raw_image) + + normalized_image = self.normalize(raw_image) + return raw_image, normalized_image, target + +class PixImageNet64(PixImageNet): + def __init__(self, root, ): + super().__init__(root, 64) + +class PixImageNet128(PixImageNet): + def __init__(self, root, ): + super().__init__(root, 128) + + +class PixImageNet256(PixImageNet): + def __init__(self, root, ): + super().__init__(root, 256) + +class PixImageNet512(PixImageNet): + def __init__(self, root, ): + super().__init__(root, 512) + + + + + diff --git a/src/data/dataset/metric_dataset.py b/src/data/dataset/metric_dataset.py new file mode 100644 index 0000000..cbe7d66 --- /dev/null +++ b/src/data/dataset/metric_dataset.py @@ -0,0 +1,82 @@ +import pathlib + +import torch +import random +import numpy as np +from torchvision.io.image import read_image +import torchvision.transforms as tvtf +from torch.utils.data import Dataset + +class CenterCrop: + def __init__(self, size): + self.size = size + def __call__(self, image): + def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + return center_crop_arr(image, self.size) + + +from PIL import Image +IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +def test_collate(batch): + return torch.stack(batch) + +class ImageDataset(Dataset): + def __init__(self, root, image_size=(224, 224)): + self.root = pathlib.Path(root) + images = [] + for ext in IMG_EXTENSIONS: + images.extend(self.root.rglob(ext)) + random.shuffle(images) + self.images = list(map(lambda x: str(x), images)) + self.transform = tvtf.Compose( + [ + CenterCrop(image_size[0]), + tvtf.ToTensor(), + tvtf.Lambda(lambda x: (x*255).to(torch.uint8)), + tvtf.Lambda(lambda x: x.expand(3, -1, -1)) + ] + ) + self.size = image_size + + def __getitem__(self, idx): + try: + image = Image.open(self.images[idx]) + image = self.transform(image) + except Exception as e: + print(self.images[idx]) + image = torch.zeros(3, self.size[0], self.size[1], dtype=torch.uint8) + + # print(image) + metadata = dict( + path = self.images[idx], + root = self.root, + ) + return image #, metadata + + def __len__(self): + return len(self.images) \ No newline at end of file diff --git a/src/data/dataset/randn.py b/src/data/dataset/randn.py new file mode 100644 index 0000000..f9ec772 --- /dev/null +++ b/src/data/dataset/randn.py @@ -0,0 +1,41 @@ +import os.path +import random + +import torch +from torch.utils.data import Dataset + + + +class RandomNDataset(Dataset): + def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, selected_classes:list=None, seeds=None, 
max_num_instances=50000, ): + self.selected_classes = selected_classes + if selected_classes is not None: + num_classes = len(selected_classes) + max_num_instances = 10*num_classes + self.num_classes = num_classes + self.seeds = seeds + if seeds is not None: + self.max_num_instances = len(seeds)*num_classes + self.num_seeds = len(seeds) + else: + self.num_seeds = (max_num_instances + num_classes - 1) // num_classes + self.max_num_instances = self.num_seeds*num_classes + + self.latent_shape = latent_shape + + + def __getitem__(self, idx): + label = idx // self.num_seeds + if self.selected_classes: + label = self.selected_classes[label] + seed = random.randint(0, 1<<31) #idx % self.num_seeds + if self.seeds is not None: + seed = self.seeds[idx % self.num_seeds] + + # cls_dir = os.path.join(self.root, f"{label}") + filename = f"{label}_{seed}.png", + generator = torch.Generator().manual_seed(seed) + latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32) + return latent, label, filename + def __len__(self): + return self.max_num_instances \ No newline at end of file diff --git a/src/data/var_training.py b/src/data/var_training.py new file mode 100644 index 0000000..de7fb74 --- /dev/null +++ b/src/data/var_training.py @@ -0,0 +1,145 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +import concurrent.futures +from concurrent.futures import ProcessPoolExecutor +from typing import List +from PIL import Image +import torch +import random +import numpy as np +import copy +import torchvision.transforms.functional as tvtf +from src.models.vae import uint82fp + + +def center_crop_arr(pil_image, width, height): + """ + Center cropping implementation from ADM. 
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while pil_image.size[0] >= 2 * width and pil_image.size[1] >= 2 * height: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = max(width / pil_image.size[0], height / pil_image.size[1]) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + arr = np.array(pil_image) + crop_y = random.randint(0, (arr.shape[0] - height)) + crop_x = random.randint(0, (arr.shape[1] - width)) + return Image.fromarray(arr[crop_y: crop_y + height, crop_x: crop_x + width]) + +def process_fn(width, height, data, hflip=0.5): + image, label = data + if random.uniform(0, 1) > hflip: # hflip + image = tvtf.hflip(image) + image = center_crop_arr(image, width, height) # crop + image = np.array(image).transpose(2, 0, 1) + return image, label + +class VARCandidate: + def __init__(self, aspect_ratio, width, height, buffer, max_buffer_size=1024): + self.aspect_ratio = aspect_ratio + self.width = int(width) + self.height = int(height) + self.buffer = buffer + self.max_buffer_size = max_buffer_size + + def add_sample(self, data): + self.buffer.append(data) + self.buffer = self.buffer[-self.max_buffer_size:] + + def ready(self, batch_size): + return len(self.buffer) >= batch_size + + def get_batch(self, batch_size): + batch = self.buffer[:batch_size] + self.buffer = self.buffer[batch_size:] + batch = [copy.deepcopy(b.result()) for b in batch] + x, y = zip(*batch) + x = torch.stack([torch.from_numpy(im).cuda() for im in x], dim=0) + x = list(map(uint82fp, x)) + return x, y + +class VARTransformEngine: + def __init__(self, + base_image_size, + num_aspect_ratios, + min_aspect_ratio, + max_aspect_ratio, + num_workers = 8, + ): + self.base_image_size = base_image_size + self.num_aspect_ratios = num_aspect_ratios + self.min_aspect_ratio = min_aspect_ratio + self.max_aspect_ratio = max_aspect_ratio + self.aspect_ratios = np.linspace(self.min_aspect_ratio, self.max_aspect_ratio, self.num_aspect_ratios) + self.aspect_ratios = self.aspect_ratios.tolist() + self.candidates_pool = [] + for i in range(self.num_aspect_ratios): + candidate = VARCandidate( + aspect_ratio=self.aspect_ratios[i], + width=int(self.base_image_size * self.aspect_ratios[i] ** 0.5 // 16 * 16), + height=int(self.base_image_size * self.aspect_ratios[i] ** -0.5 // 16 * 16), + buffer=[], + max_buffer_size=1024 + ) + self.candidates_pool.append(candidate) + self.default_candidate = VARCandidate( + aspect_ratio=1.0, + width=self.base_image_size, + height=self.base_image_size, + buffer=[], + max_buffer_size=1024, + ) + self.executor_pool = ProcessPoolExecutor(max_workers=num_workers) + self._prefill_count = 100 + + def find_candidate(self, data): + image = data[0] + aspect_ratio = image.size[0] / image.size[1] + min_distance = 1000000 + min_candidate = None + for candidate in self.candidates_pool: + dis = abs(aspect_ratio - candidate.aspect_ratio) + if dis < min_distance: + min_distance = dis + min_candidate = candidate + return min_candidate + + + def __call__(self, batch_data): + self._prefill_count -= 1 + if isinstance(batch_data[0], torch.Tensor): + batch_data[0] = batch_data[0].unbind(0) + + batch_data = list(zip(*batch_data)) + for data in batch_data: + candidate = self.find_candidate(data) + future = self.executor_pool.submit(process_fn, candidate.width, candidate.height, data) + candidate.add_sample(future) + if 
self._prefill_count >= 0: + future = self.executor_pool.submit(process_fn, + self.default_candidate.width, + self.default_candidate.height, + data) + self.default_candidate.add_sample(future) + + batch_size = len(batch_data) + random.shuffle(self.candidates_pool) + for candidate in self.candidates_pool: + if candidate.ready(batch_size=batch_size): + return candidate.get_batch(batch_size=batch_size) + + # fallback to default 256 + for data in batch_data: + future = self.executor_pool.submit(process_fn, + self.default_candidate.width, + self.default_candidate.height, + data) + self.default_candidate.add_sample(future) + return self.default_candidate.get_batch(batch_size=batch_size) \ No newline at end of file diff --git a/src/diffusion/__init__.py b/src/diffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/diffusion/base/guidance.py b/src/diffusion/base/guidance.py new file mode 100644 index 0000000..07b4754 --- /dev/null +++ b/src/diffusion/base/guidance.py @@ -0,0 +1,60 @@ +import torch + +def simple_guidance_fn(out, cfg): + uncondition, condtion = out.chunk(2, dim=0) + out = uncondition + cfg * (condtion - uncondition) + return out + +def c3_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condtion = out.chunk(2, dim=0) + out = condtion + out[:, :3] = uncondition[:, :3] + cfg * (condtion[:, :3] - uncondition[:, :3]) + return out + +def c4_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def c4_p05_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def c4_p10_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.10 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def c4_p15_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.15 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def c4_p20_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.20 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def p4_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? 
+ uncondition, condtion = out.chunk(2, dim=0) + out = condtion + out[:, 4:] = uncondition[:, 4:] + cfg * (condtion[:, 4:] - uncondition[:, 4:]) + return out diff --git a/src/diffusion/base/sampling.py b/src/diffusion/base/sampling.py new file mode 100644 index 0000000..d8f9776 --- /dev/null +++ b/src/diffusion/base/sampling.py @@ -0,0 +1,31 @@ +from typing import Union, List + +import torch +import torch.nn as nn +from typing import Callable +from src.diffusion.base.scheduling import BaseScheduler + +class BaseSampler(nn.Module): + def __init__(self, + scheduler: BaseScheduler = None, + guidance_fn: Callable = None, + num_steps: int = 250, + guidance: Union[float, List[float]] = 1.0, + *args, + **kwargs + ): + super(BaseSampler, self).__init__() + self.num_steps = num_steps + self.guidance = guidance + self.guidance_fn = guidance_fn + self.scheduler = scheduler + + + def _impl_sampling(self, net, noise, condition, uncondition): + raise NotImplementedError + + def __call__(self, net, noise, condition, uncondition): + denoised = self._impl_sampling(net, noise, condition, uncondition) + return denoised + + diff --git a/src/diffusion/base/scheduling.py b/src/diffusion/base/scheduling.py new file mode 100644 index 0000000..05c7fb1 --- /dev/null +++ b/src/diffusion/base/scheduling.py @@ -0,0 +1,32 @@ +import torch +from torch import Tensor + +class BaseScheduler: + def alpha(self, t) -> Tensor: + ... + def sigma(self, t) -> Tensor: + ... + + def dalpha(self, t) -> Tensor: + ... + def dsigma(self, t) -> Tensor: + ... + + def dalpha_over_alpha(self, t) -> Tensor: + return self.dalpha(t) / self.alpha(t) + + def dsigma_mul_sigma(self, t) -> Tensor: + return self.dsigma(t)*self.sigma(t) + + def drift_coefficient(self, t): + alpha, sigma = self.alpha(t), self.sigma(t) + dalpha, dsigma = self.dalpha(t), self.dsigma(t) + return dalpha/(alpha + 1e-6) + + def diffuse_coefficient(self, t): + alpha, sigma = self.alpha(t), self.sigma(t) + dalpha, dsigma = self.dalpha(t), self.dsigma(t) + return dsigma*sigma - dalpha/(alpha + 1e-6)*sigma**2 + + def w(self, t): + return self.sigma(t) diff --git a/src/diffusion/base/training.py b/src/diffusion/base/training.py new file mode 100644 index 0000000..8f6d0e0 --- /dev/null +++ b/src/diffusion/base/training.py @@ -0,0 +1,29 @@ +import time + +import torch +import torch.nn as nn + +class BaseTrainer(nn.Module): + def __init__(self, + null_condition_p=0.1, + log_var=False, + ): + super(BaseTrainer, self).__init__() + self.null_condition_p = null_condition_p + self.log_var = log_var + + def preproprocess(self, raw_iamges, x, condition, uncondition): + bsz = x.shape[0] + if self.null_condition_p > 0: + mask = torch.rand((bsz), device=condition.device) < self.null_condition_p + mask = mask.expand_as(condition) + condition[mask] = uncondition[mask] + return raw_iamges, x, condition + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + raise NotImplementedError + + def __call__(self, net, ema_net, raw_images, x, condition, uncondition): + raw_images, x, condition = self.preproprocess(raw_images, x, condition, uncondition) + return self._impl_trainstep(net, ema_net, raw_images, x, condition) + diff --git a/src/diffusion/ddpm/ddim_sampling.py b/src/diffusion/ddpm/ddim_sampling.py new file mode 100644 index 0000000..0db2a1d --- /dev/null +++ b/src/diffusion/ddpm/ddim_sampling.py @@ -0,0 +1,40 @@ +import torch +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + +import logging +logger = 
logging.getLogger(__name__) + +class DDIMSampler(BaseSampler): + def __init__( + self, + train_num_steps=1000, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.train_num_steps = train_num_steps + assert self.scheduler is not None + + def _impl_sampling(self, net, noise, condition, uncondition): + batch_size = noise.shape[0] + steps = torch.linspace(0.0, self.train_num_steps-1, self.num_steps, device=noise.device) + steps = torch.flip(steps, dims=[0]) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = x0 = noise + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + t_cur = t_cur.repeat(batch_size) + t_next = t_next.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + alpha = self.scheduler.alpha(t_cur) + sigma_next = self.scheduler.sigma(t_next) + alpha_next = self.scheduler.alpha(t_next) + cfg_x = torch.cat([x, x], dim=0) + t = t_cur.repeat(2) + out = net(cfg_x, t, cfg_condition) + out = self.guidance_fn(out, self.guidance) + x0 = (x - sigma * out) / alpha + x = alpha_next * x0 + sigma_next * out + return x0 \ No newline at end of file diff --git a/src/diffusion/ddpm/scheduling.py b/src/diffusion/ddpm/scheduling.py new file mode 100644 index 0000000..aff1523 --- /dev/null +++ b/src/diffusion/ddpm/scheduling.py @@ -0,0 +1,102 @@ +import math +import torch +from src.diffusion.base.scheduling import * + + +class DDPMScheduler(BaseScheduler): + def __init__( + self, + beta_min=0.0001, + beta_max=0.02, + num_steps=1000, + ): + super().__init__() + self.beta_min = beta_min + self.beta_max = beta_max + self.num_steps = num_steps + + self.betas_table = torch.linspace(self.beta_min, self.beta_max, self.num_steps, device="cuda") + self.alphas_table = torch.cumprod(1-self.betas_table, dim=0) + self.sigmas_table = 1-self.alphas_table + + + def beta(self, t) -> Tensor: + t = t.to(torch.long) + return self.betas_table[t].view(-1, 1, 1, 1) + + def alpha(self, t) -> Tensor: + t = t.to(torch.long) + return self.alphas_table[t].view(-1, 1, 1, 1)**0.5 + + def sigma(self, t) -> Tensor: + t = t.to(torch.long) + return self.sigmas_table[t].view(-1, 1, 1, 1)**0.5 + + def dsigma(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def dalpha_over_alpha(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dsigma_mul_sigma(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dalpha(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def drift_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def diffuse_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def w(self, t): + raise NotImplementedError("wrong usage") + + +class VPScheduler(BaseScheduler): + def __init__( + self, + beta_min=0.1, + beta_max=20, + ): + super().__init__() + self.beta_min = beta_min + self.beta_d = beta_max - beta_min + def beta(self, t) -> Tensor: + t = torch.clamp(t, min=1e-3, max=1) + return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1) + + def sigma(self, t) -> Tensor: + t = torch.clamp(t, min=1e-3, max=1) + inter_beta:Tensor = 0.5*self.beta_d*t**2 + self.beta_min* t + return (1-torch.exp_(-inter_beta)).sqrt().view(-1, 1, 1, 1) + + def dsigma(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def dalpha_over_alpha(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dsigma_mul_sigma(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dalpha(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def alpha(self, t) -> 
Tensor: + t = torch.clamp(t, min=1e-3, max=1) + inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t + return torch.exp(-0.5*inter_beta).view(-1, 1, 1, 1) + + def drift_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def diffuse_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def w(self, t): + return self.diffuse_coefficient(t) + + + diff --git a/src/diffusion/ddpm/training.py b/src/diffusion/ddpm/training.py new file mode 100644 index 0000000..3e0d0ec --- /dev/null +++ b/src/diffusion/ddpm/training.py @@ -0,0 +1,83 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class VPTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + train_max_t=1000, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.train_max_t = train_max_t + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + sigma = self.scheduler.sigma(t) + x_t = alpha * x + noise * sigma + out = net(x_t, t*self.train_max_t, y) + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - noise)**2 + + out = dict( + loss=loss.mean(), + ) + return out + + +class DDPMTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn: Callable = constant, + train_max_t=1000, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.train_max_t = train_max_t + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + t = torch.randint(0, self.train_max_t, (batch_size,)) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + sigma = self.scheduler.sigma(t) + x_t = alpha * x + noise * sigma + out = net(x_t, t, y) + weight = self.loss_weight_fn(alpha, sigma) + loss = weight * (out - noise) ** 2 + + out = dict( + loss=loss.mean(), + ) + return out \ No newline at end of file diff --git a/src/diffusion/ddpm/vp_sampling.py b/src/diffusion/ddpm/vp_sampling.py new file mode 100644 index 0000000..250b32d --- /dev/null +++ b/src/diffusion/ddpm/vp_sampling.py @@ -0,0 +1,59 @@ +import torch + +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * +from typing import Callable + +def ode_step_fn(x, eps, beta, sigma, dt): + return x + (-0.5*beta*x + 0.5*eps*beta/sigma)*dt + +def sde_step_fn(x, eps, beta, sigma, dt): + return x + (-0.5*beta*x + eps*beta/sigma)*dt + torch.sqrt(dt.abs()*beta)*torch.randn_like(x) + +import logging +logger = logging.getLogger(__name__) + +class VPEulerSampler(BaseSampler): + def __init__( + self, + train_max_t=1000, + guidance_fn: Callable = None, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: 
Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.guidance_fn = guidance_fn + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.train_max_t = train_max_t + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + assert self.last_step > 0.0 + assert self.scheduler is not None + + def _impl_sampling(self, net, noise, condition, uncondition): + batch_size = noise.shape[0] + steps = torch.linspace(1.0, self.last_step, self.num_steps, device=noise.device) + steps = torch.cat([steps, torch.tensor([0.0], device=noise.device)], dim=0) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + beta = self.scheduler.beta(t_cur) + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + out = net(cfg_x, cfg_t*self.train_max_t, cfg_condition) + eps = self.guidance_fn(out, self.guidance) + if i < self.num_steps -1 : + x0 = self.last_step_fn(x, eps, beta, sigma, -t_cur[0]) + x = self.step_fn(x, eps, beta, sigma, dt) + else: + x = x0 = self.last_step_fn(x, eps, beta, sigma, -self.last_step) + return x \ No newline at end of file diff --git a/src/diffusion/flow_matching/adam_sampling.py b/src/diffusion/flow_matching/adam_sampling.py new file mode 100644 index 0000000..15d0c78 --- /dev/null +++ b/src/diffusion/flow_matching/adam_sampling.py @@ -0,0 +1,107 @@ +import math +from src.diffusion.base.sampling import * +from src.diffusion.base.scheduling import * +from src.diffusion.pre_integral import * + +from typing import Callable, List, Tuple + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + +def t2snr(t): + if isinstance(t, torch.Tensor): + return (t.clip(min=1e-8)/(1-t + 1e-8)) + if isinstance(t, List) or isinstance(t, Tuple): + return [t2snr(t) for t in t] + t = max(t, 1e-8) + return (t/(1-t + 1e-8)) + +def t2logsnr(t): + if isinstance(t, torch.Tensor): + return torch.log(t.clip(min=1e-3)/(1-t + 1e-3)) + if isinstance(t, List) or isinstance(t, Tuple): + return [t2logsnr(t) for t in t] + t = max(t, 1e-3) + return math.log(t/(1-t + 1e-3)) + +def t2isnr(t): + return 1/t2snr(t) + +def nop(t): + return t + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +import logging +logger = logging.getLogger(__name__) + +class AdamLMSampler(BaseSampler): + def __init__( + self, + order: int = 2, + timeshift: float = 1.0, + lms_transform_fn: Callable = nop, + w_scheduler: BaseScheduler = None, + step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.w_scheduler = w_scheduler + + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + self.order = order + self.lms_transform_fn = lms_transform_fn + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, timeshift) + self.timedeltas = timesteps[1:] - self.timesteps[:-1] + self._reparameterize_coeffs() + + def _reparameterize_coeffs(self): + solver_coeffs = [[] for _ in range(self.num_steps)] + for i in range(0, self.num_steps): + pre_vs = [1.0, ]*(i+1) + pre_ts = self.lms_transform_fn(self.timesteps[:i+1]) + int_t_start = self.lms_transform_fn(self.timesteps[i]) + int_t_end = 
self.lms_transform_fn(self.timesteps[i+1]) + + order_annealing = self.order #self.num_steps - i + order = min(self.order, i + 1, order_annealing) + + _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end) + solver_coeffs[i] = coeffs + self.solver_coeffs = solver_coeffs + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = x0 = noise + pred_trajectory = [] + t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype) + timedeltas = self.timedeltas + solver_coeffs = self.solver_coeffs + for i in range(self.num_steps): + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + out = net(cfg_x, cfg_t, cfg_condition) + out = self.guidance_fn(out, self.guidances[i]) + pred_trajectory.append(out) + out = torch.zeros_like(out) + order = len(self.solver_coeffs[i]) + for j in range(order): + out += solver_coeffs[i][j] * pred_trajectory[-order:][j] + v = out + dt = timedeltas[i] + x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0) + x = self.step_fn(x, v, dt, s=0, w=0) + t_cur += dt + return x \ No newline at end of file diff --git a/src/diffusion/flow_matching/sampling.py b/src/diffusion/flow_matching/sampling.py new file mode 100644 index 0000000..62bdd8b --- /dev/null +++ b/src/diffusion/flow_matching/sampling.py @@ -0,0 +1,179 @@ +import torch + +from src.diffusion.base.guidance import * +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + +def sde_mean_step_fn(x, v, dt, s, w): + return x + v * dt + s * w * dt + +def sde_step_fn(x, v, dt, s, w): + return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x) + +def sde_preserve_step_fn(x, v, dt, s, w): + return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x) + + +import logging +logger = logging.getLogger(__name__) + +class EulerSampler(BaseSampler): + def __init__( + self, + w_scheduler: BaseScheduler = None, + timeshift=1.0, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.w_scheduler = w_scheduler + self.timeshift = timeshift + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + assert self.last_step > 0.0 + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + if self.w_scheduler is not None: + if self.step_fn == ode_step_fn: + logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + dalpha_over_alpha = 
self.scheduler.dalpha_over_alpha(t_cur) + dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur) + if self.w_scheduler: + w = self.w_scheduler.w(t_cur) + else: + w = 0.0 + + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + out = net(cfg_x, cfg_t, cfg_condition) + out = self.guidance_fn(out, self.guidance) + v = out + s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma) + if i < self.num_steps -1 : + x = self.step_fn(x, v, dt, s=s, w=w) + else: + x = self.last_step_fn(x, v, dt, s=s, w=w) + return x + + +class HeunSampler(BaseSampler): + def __init__( + self, + scheduler: BaseScheduler = None, + w_scheduler: BaseScheduler = None, + exact_henu=False, + timeshift=1.0, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.scheduler = scheduler + self.exact_henu = exact_henu + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.w_scheduler = w_scheduler + self.timeshift = timeshift + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + assert self.last_step > 0.0 + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + if self.w_scheduler is not None: + if self.step_fn == ode_step_fn: + logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Henu sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + v_hat, s_hat = 0.0, 0.0 + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + alpha_over_dalpha = 1/self.scheduler.dalpha_over_alpha(t_cur) + dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur) + t_hat = t_next + t_hat = t_hat.repeat(batch_size) + sigma_hat = self.scheduler.sigma(t_hat) + alpha_over_dalpha_hat = 1 / self.scheduler.dalpha_over_alpha(t_hat) + dsigma_mul_sigma_hat = self.scheduler.dsigma_mul_sigma(t_hat) + + if self.w_scheduler: + w = self.w_scheduler.w(t_cur) + else: + w = 0.0 + if i == 0 or self.exact_henu: + cfg_x = torch.cat([x, x], dim=0) + cfg_t_cur = t_cur.repeat(2) + out = net(cfg_x, cfg_t_cur, cfg_condition) + out = self.guidance_fn(out, self.guidance) + v = out + s = ((alpha_over_dalpha)*v - x)/(sigma**2 - (alpha_over_dalpha)*dsigma_mul_sigma) + else: + v = v_hat + s = s_hat + x_hat = self.step_fn(x, v, dt, s=s, w=w) + # henu correct + if i < self.num_steps -1: + cfg_x_hat = torch.cat([x_hat, x_hat], dim=0) + cfg_t_hat = t_hat.repeat(2) + out = net(cfg_x_hat, cfg_t_hat, cfg_condition) + out = self.guidance_fn(out, self.guidance) + v_hat = out + s_hat = ((alpha_over_dalpha_hat)* v_hat - x_hat) / (sigma_hat ** 2 - (alpha_over_dalpha_hat) * dsigma_mul_sigma_hat) + v = (v + v_hat) / 2 + s = (s + s_hat) / 2 + x = self.step_fn(x, v, dt, s=s, w=w) + else: + x = self.last_step_fn(x, v, dt, s=s, w=w) + return x \ No newline at end of file diff --git a/src/diffusion/flow_matching/scheduling.py b/src/diffusion/flow_matching/scheduling.py new file mode 100644 index 
0000000..a82cd3a --- /dev/null +++ b/src/diffusion/flow_matching/scheduling.py @@ -0,0 +1,39 @@ +import math +import torch +from src.diffusion.base.scheduling import * + + +class LinearScheduler(BaseScheduler): + def alpha(self, t) -> Tensor: + return (t).view(-1, 1, 1, 1) + def sigma(self, t) -> Tensor: + return (1-t).view(-1, 1, 1, 1) + def dalpha(self, t) -> Tensor: + return torch.full_like(t, 1.0).view(-1, 1, 1, 1) + def dsigma(self, t) -> Tensor: + return torch.full_like(t, -1.0).view(-1, 1, 1, 1) + +# SoTA for ImageNet! +class GVPScheduler(BaseScheduler): + def alpha(self, t) -> Tensor: + return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) + def sigma(self, t) -> Tensor: + return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) + def dalpha(self, t) -> Tensor: + return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) + def dsigma(self, t) -> Tensor: + return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) + def w(self, t): + return torch.sin(t)**2 + +class ConstScheduler(BaseScheduler): + def w(self, t): + return torch.ones(1, 1, 1, 1).to(t.device, t.dtype) + +from src.diffusion.ddpm.scheduling import VPScheduler +class VPBetaScheduler(VPScheduler): + def w(self, t): + return self.beta(t).view(-1, 1, 1, 1) + + + diff --git a/src/diffusion/flow_matching/training.py b/src/diffusion/flow_matching/training.py new file mode 100644 index 0000000..55c964d --- /dev/null +++ b/src/diffusion/flow_matching/training.py @@ -0,0 +1,55 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class FlowMatchingTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + out = net(x_t, t, y) + + weight = self.loss_weight_fn(alpha, sigma) + + loss = weight*(out - v_t)**2 + + out = dict( + loss=loss.mean(), + ) + return out \ No newline at end of file diff --git a/src/diffusion/flow_matching/training_cos.py b/src/diffusion/flow_matching/training_cos.py new file mode 100644 index 0000000..aff30a7 --- /dev/null +++ b/src/diffusion/flow_matching/training_cos.py @@ -0,0 +1,59 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) 
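# Illustrative sketch (not used elsewhere in this patch) of how these weights behave
# under the LinearScheduler in scheduling.py (alpha(t)=t, sigma(t)=1-t, hence
# snr = t/(1-t)): maxsnr caps the ratio at `threshold` so timesteps near t=1 cannot
# dominate the loss, while minsnr floors it so timesteps near t=0 are not weighted
# down toward zero.
def _scalar_weight_example(t: float, threshold: float = 5.0) -> float:
    # hypothetical helper for a scalar t; the trainers below instead pass alpha/sigma
    # tensors through snr/minsnr/maxsnr/constant via the loss_weight_fn argument
    alpha, sigma = t, 1.0 - t
    return min(alpha / max(sigma, 1e-8), threshold)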
+def constant(alpha, sigma): + return 1 + +class COSTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + out = net(x_t, t, y) + + weight = self.loss_weight_fn(alpha, sigma) + + fm_loss = weight*(out - v_t)**2 + cos_sim = torch.nn.functional.cosine_similarity(out, v_t, dim=1) + cos_loss = 1 - cos_sim + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + cos_loss.mean(), + ) + return out \ No newline at end of file diff --git a/src/diffusion/flow_matching/training_pyramid.py b/src/diffusion/flow_matching/training_pyramid.py new file mode 100644 index 0000000..be2bd94 --- /dev/null +++ b/src/diffusion/flow_matching/training_pyramid.py @@ -0,0 +1,68 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class PyramidTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + + output_pyramid = [] + def feature_hook(module, input, output): + output_pyramid.extend(output) + handle = net.decoder.register_forward_hook(feature_hook) + net(x_t, t, y) + handle.remove() + + loss = 0.0 + out_dict = dict() + + cur_v_t = v_t + for i in range(len(output_pyramid)): + cur_out = output_pyramid[i] + loss_i = (cur_v_t - cur_out) ** 2 + loss += loss_i.mean() + out_dict["loss_{}".format(i)] = loss_i.mean() + cur_v_t = torch.nn.functional.interpolate(cur_v_t, scale_factor=0.5, mode='bilinear', align_corners=False) + out_dict["loss"] = loss + return out_dict + diff --git a/src/diffusion/flow_matching/training_repa.py b/src/diffusion/flow_matching/training_repa.py new file mode 100644 index 0000000..e9a6788 --- /dev/null +++ b/src/diffusion/flow_matching/training_repa.py @@ -0,0 +1,142 @@ +import torch +import copy +import timm +from 
torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load( + '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', + weight_path, + source="local", + skip_validation=True + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class REPATrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + encoder_weight_path=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.encoder = DINOv2(encoder_weight_path) + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + src_feature = [] + def forward_hook(net, input, output): + src_feature.append(output) + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + + out = net(x_t, t, y) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + cos_sim = 
torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = weight*(out - v_t)**2 + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/diffusion/flow_matching/training_repa_mask.py b/src/diffusion/flow_matching/training_repa_mask.py new file mode 100644 index 0000000..f8c4edb --- /dev/null +++ b/src/diffusion/flow_matching/training_repa_mask.py @@ -0,0 +1,152 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load( + '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', + weight_path, + source="local", + skip_validation=True + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class REPATrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + mask_ratio=0.0, + mask_patch_size=2, + encoder_weight_path=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.mask_ratio = mask_ratio + self.mask_patch_size = mask_patch_size + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.encoder = DINOv2(encoder_weight_path) + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + 
nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + + patch_mask = torch.rand((batch_size, 1, height//self.mask_patch_size, width//self.mask_patch_size), device=x.device) + patch_mask = (patch_mask < self.mask_ratio).float() + mask = torch.nn.functional.interpolate(patch_mask, size=(height, width), mode='nearest') + masked_x = x*(1-mask)# + torch.randn_like(x)*(mask) + + x_t = alpha*masked_x + sigma*noise + v_t = dalpha*x + dsigma*noise + + src_feature = [] + def forward_hook(net, input, output): + src_feature.append(output) + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + + v_t_out, x0_out = net(x_t, t, y) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = (1-mask)*weight*(v_t_out - v_t)**2/(1-mask.mean()) + mask_loss = mask*weight*(x0_out - x)**2/(mask.mean()) + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + mask_loss=mask_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + mask_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/diffusion/pre_integral.py b/src/diffusion/pre_integral.py new file mode 100644 index 0000000..848533a --- /dev/null +++ b/src/diffusion/pre_integral.py @@ -0,0 +1,143 @@ +import torch + +# lagrange interpolation +def lagrange_preint_o1(t1, v1, int_t_start, int_t_end): + ''' + lagrange interpolation of order 1 + Args: + t1: timestepx + v1: value field at t1 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1 = (int_t_end-int_t_start) + return int1*v1, (int1/int1, ) + +def lagrange_preint_o2(t1, t2, v1, v2, int_t_start, int_t_end): + ''' + lagrange interpolation of order 2 + Args: + t1: timestepx + t2: timestepy + v1: value field at t1 + v2: value field at t2 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1 = 0.5/(t1-t2)*((int_t_end-t2)**2 - (int_t_start-t2)**2) + int2 = 0.5/(t2-t1)*((int_t_end-t1)**2 - (int_t_start-t1)**2) + int_sum = int1+int2 + return int1*v1 + int2*v2, (int1/int_sum, int2/int_sum) + +def lagrange_preint_o3(t1, t2, t3, v1, v2, v3, int_t_start, int_t_end): + ''' + lagrange interpolation of order 3 + Args: + t1: timestepx + t2: timestepy + t3: timestepz + v1: value field at t1 + v2: value field at t2 + v3: value field at t3 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1_denom = (t1-t2)*(t1-t3) + int1_end = 1/3*(int_t_end)**3 - 1/2*(t2+t3)*(int_t_end)**2 + (t2*t3)*int_t_end + int1_start = 1/3*(int_t_start)**3 - 1/2*(t2+t3)*(int_t_start)**2 + (t2*t3)*int_t_start + int1 = (int1_end - 
int1_start)/int1_denom + int2_denom = (t2-t1)*(t2-t3) + int2_end = 1/3*(int_t_end)**3 - 1/2*(t1+t3)*(int_t_end)**2 + (t1*t3)*int_t_end + int2_start = 1/3*(int_t_start)**3 - 1/2*(t1+t3)*(int_t_start)**2 + (t1*t3)*int_t_start + int2 = (int2_end - int2_start)/int2_denom + int3_denom = (t3-t1)*(t3-t2) + int3_end = 1/3*(int_t_end)**3 - 1/2*(t1+t2)*(int_t_end)**2 + (t1*t2)*int_t_end + int3_start = 1/3*(int_t_start)**3 - 1/2*(t1+t2)*(int_t_start)**2 + (t1*t2)*int_t_start + int3 = (int3_end - int3_start)/int3_denom + int_sum = int1+int2+int3 + return int1*v1 + int2*v2 + int3*v3, (int1/int_sum, int2/int_sum, int3/int_sum) + +def larange_preint_o4(t1, t2, t3, t4, v1, v2, v3, v4, int_t_start, int_t_end): + ''' + lagrange interpolation of order 4 + Args: + t1: timestepx + t2: timestepy + t3: timestepz + t4: timestepw + v1: value field at t1 + v2: value field at t2 + v3: value field at t3 + v4: value field at t4 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1_denom = (t1-t2)*(t1-t3)*(t1-t4) + int1_end = 1/4*(int_t_end)**4 - 1/3*(t2+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_end**2 - t2*t3*t4*int_t_end + int1_start = 1/4*(int_t_start)**4 - 1/3*(t2+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_start**2 - t2*t3*t4*int_t_start + int1 = (int1_end - int1_start)/int1_denom + int2_denom = (t2-t1)*(t2-t3)*(t2-t4) + int2_end = 1/4*(int_t_end)**4 - 1/3*(t1+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_end**2 - t1*t3*t4*int_t_end + int2_start = 1/4*(int_t_start)**4 - 1/3*(t1+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_start**2 - t1*t3*t4*int_t_start + int2 = (int2_end - int2_start)/int2_denom + int3_denom = (t3-t1)*(t3-t2)*(t3-t4) + int3_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t4)*(int_t_end)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_end**2 - t1*t2*t4*int_t_end + int3_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t4)*(int_t_start)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_start**2 - t1*t2*t4*int_t_start + int3 = (int3_end - int3_start)/int3_denom + int4_denom = (t4-t1)*(t4-t2)*(t4-t3) + int4_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t3)*(int_t_end)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_end**2 - t1*t2*t3*int_t_end + int4_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t3)*(int_t_start)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_start**2 - t1*t2*t3*int_t_start + int4 = (int4_end - int4_start)/int4_denom + int_sum = int1+int2+int3+int4 + return int1*v1 + int2*v2 + int3*v3 + int4*v4, (int1/int_sum, int2/int_sum, int3/int_sum, int4/int_sum) + + +def lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end): + ''' + lagrange interpolation + Args: + order: order of interpolation + pre_vs: value field at pre_ts + pre_ts: timesteps + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + order = min(order, len(pre_vs), len(pre_ts)) + if order == 1: + return lagrange_preint_o1(pre_ts[-1], pre_vs[-1], int_t_start, int_t_end) + elif order == 2: + return lagrange_preint_o2(pre_ts[-2], pre_ts[-1], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end) + elif order == 3: + return lagrange_preint_o3(pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end) + elif order == 4: + return larange_preint_o4(pre_ts[-4], pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-4], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end) + else: + raise ValueError('Invalid order') + + +def polynomial_integral(coeffs, int_t_start, int_t_end): + ''' + polynomial 
integral + Args: + coeffs: coefficients of the polynomial + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + orders = len(coeffs) + int_val = 0 + for o in range(orders): + int_val += coeffs[o]/(o+1)*(int_t_end**(o+1)-int_t_start**(o+1)) + return int_val + diff --git a/src/diffusion/stateful_flow_matching/adam_sampling.py b/src/diffusion/stateful_flow_matching/adam_sampling.py new file mode 100644 index 0000000..fb2e95b --- /dev/null +++ b/src/diffusion/stateful_flow_matching/adam_sampling.py @@ -0,0 +1,112 @@ +import math +from src.diffusion.base.sampling import * +from src.diffusion.base.scheduling import * +from src.diffusion.pre_integral import * + +from typing import Callable, List, Tuple + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + +def t2snr(t): + if isinstance(t, torch.Tensor): + return (t.clip(min=1e-8)/(1-t + 1e-8)) + if isinstance(t, List) or isinstance(t, Tuple): + return [t2snr(t) for t in t] + t = max(t, 1e-8) + return (t/(1-t + 1e-8)) + +def t2logsnr(t): + if isinstance(t, torch.Tensor): + return torch.log(t.clip(min=1e-3)/(1-t + 1e-3)) + if isinstance(t, List) or isinstance(t, Tuple): + return [t2logsnr(t) for t in t] + t = max(t, 1e-3) + return math.log(t/(1-t + 1e-3)) + +def t2isnr(t): + return 1/t2snr(t) + +def nop(t): + return t + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +import logging +logger = logging.getLogger(__name__) + +class AdamLMSampler(BaseSampler): + def __init__( + self, + order: int = 2, + timeshift: float = 1.0, + state_refresh_rate: int = 1, + lms_transform_fn: Callable = nop, + w_scheduler: BaseScheduler = None, + step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.w_scheduler = w_scheduler + self.state_refresh_rate = state_refresh_rate + + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + self.order = order + self.lms_transform_fn = lms_transform_fn + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, timeshift) + self.timedeltas = timesteps[1:] - self.timesteps[:-1] + self._reparameterize_coeffs() + + def _reparameterize_coeffs(self): + solver_coeffs = [[] for _ in range(self.num_steps)] + for i in range(0, self.num_steps): + pre_vs = [1.0, ]*(i+1) + pre_ts = self.lms_transform_fn(self.timesteps[:i+1]) + int_t_start = self.lms_transform_fn(self.timesteps[i]) + int_t_end = self.lms_transform_fn(self.timesteps[i+1]) + + order_annealing = self.order #self.num_steps - i + order = min(self.order, i + 1, order_annealing) + + _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end) + solver_coeffs[i] = coeffs + self.solver_coeffs = solver_coeffs + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = x0 = noise + state = None + pred_trajectory = [] + t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype) + timedeltas = self.timedeltas + solver_coeffs = self.solver_coeffs + for i in range(self.num_steps): + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + if i % self.state_refresh_rate == 0: + state = None + out, state = net(cfg_x, cfg_t, cfg_condition, state) + out = self.guidance_fn(out, 
self.guidances[i]) + pred_trajectory.append(out) + out = torch.zeros_like(out) + order = len(self.solver_coeffs[i]) + for j in range(order): + out += solver_coeffs[i][j] * pred_trajectory[-order:][j] + v = out + dt = timedeltas[i] + x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0) + x = self.step_fn(x, v, dt, s=0, w=0) + t_cur += dt + return x \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/bak/training_adv.py b/src/diffusion/stateful_flow_matching/bak/training_adv.py new file mode 100644 index 0000000..4792950 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/bak/training_adv.py @@ -0,0 +1,122 @@ +import torch +import math +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class Discriminator(nn.Module): + def __init__(self, in_channels, hidden_size): + super().__init__() + self.head = nn.Sequential( + nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1 + ) + + def forward(self, feature): + B, L, C = feature.shape + H = W = int(math.sqrt(L)) + feature = feature.permute(0, 2, 1) + feature = feature.view(B, C, H, W) + out = self.head(feature).sigmoid().clamp(0.01, 0.99) + return out + +class AdvTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + adv_weight=1.0, + adv_encoder_layer=4, + adv_in_channels=768, + adv_hidden_size=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.adv_weight = adv_weight + self.adv_encoder_layer = adv_encoder_layer + + self.dis_head = Discriminator( + in_channels=adv_in_channels, + hidden_size=adv_hidden_size, + ) + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + adv_feature = [] + def forward_hook(net, input, output): + adv_feature.append(output) + handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook) + + out, _ = net(x_t, t, y) + weight = 
self.loss_weight_fn(alpha, sigma) + loss = weight*(out - v_t)**2 + + pred_x0 = (x_t + out * sigma) + pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma + real_feature = adv_feature.pop() + net(pred_xt, t, y, classify_layer=self.adv_encoder_layer) + fake_feature = adv_feature.pop() + handle.remove() + + + real_score_gan = self.dis_head(real_feature.detach()) + fake_score_gan = self.dis_head(fake_feature.detach()) + fake_score = self.dis_head(fake_feature) + + loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan) + acc_real = (real_score_gan > 0.5).float() + acc_fake = (fake_score_gan < 0.5).float() + loss_adv = -torch.log(fake_score) + loss_adv_hack = torch.log(fake_score_gan) + + out = dict( + adv_loss=loss_adv.mean(), + gan_loss=loss_gan.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(), + acc_real=acc_real.mean(), + acc_fake=acc_fake.mean(), + ) + return out diff --git a/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py b/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py new file mode 100644 index 0000000..2843c04 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py @@ -0,0 +1,127 @@ +import torch +import math +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class Discriminator(nn.Module): + def __init__(self, in_channels, hidden_size): + super().__init__() + self.head = nn.Sequential( + nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1 + ) + + def forward(self, feature): + B, L, C = feature.shape + H = W = int(math.sqrt(L)) + feature = feature.permute(0, 2, 1) + feature = feature.view(B, C, H, W) + out = self.head(feature).sigmoid().clamp(0.01, 0.99) + return out + +class AdvTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + adv_weight=1.0, + lpips_weight=1.0, + adv_encoder_layer=4, + adv_in_channels=768, + adv_hidden_size=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.adv_weight = adv_weight + self.lpips_weight = lpips_weight + self.adv_encoder_layer = adv_encoder_layer + + self.dis_head = Discriminator( + in_channels=adv_in_channels, + hidden_size=adv_hidden_size, + ) + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + 
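        # Overview of the adversarial-x0 step implemented below:
        #  1. sample t (sigmoid(randn) when lognorm_t, else uniform) plus a clean timestep
        #     clean_t = 1.0 used to extract reference encoder features;
        #  2. build x_t = alpha*x + sigma*noise and the velocity target v_t = dalpha*x + dsigma*noise;
        #  3. reconstruct pred_x0 = x_t + sigma*out; the identity x = x_t + sigma*v_t holds for the
        #     linear schedule (alpha=t, sigma=1-t, v_t = x - noise), which this step implicitly assumes;
        #  4. score the classify_layer features of x and pred_x0 (both evaluated at clean_t) with the
        #     Discriminator head for the GAN/adversarial terms, and add an LPIPS-style distance over
        #     L2-normalized features.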
batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype) + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + out, _ = net(x_t, t, y) + pred_x0 = (x_t + out * sigma) + + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - v_t)**2 + with torch.no_grad(): + _, real_features = net(x, clean_t, y, classify_layer=self.adv_encoder_layer) + _, fake_features = net(pred_x0, clean_t, y, classify_layer=self.adv_encoder_layer) + + real_score_gan = self.dis_head(real_features[-1].detach()) + fake_score_gan = self.dis_head(fake_features[-1].detach()) + fake_score = self.dis_head(fake_features[-1]) + + loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan) + acc_real = (real_score_gan > 0.5).float() + acc_fake = (fake_score_gan < 0.5).float() + loss_adv = -torch.log(fake_score) + loss_adv_hack = torch.log(fake_score_gan) + + lpips_loss = [] + for r, f in zip(real_features, fake_features): + r = torch.nn.functional.normalize(r, dim=-1) + f = torch.nn.functional.normalize(f, dim=-1) + lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean()) + lpips_loss = sum(lpips_loss) + + + out = dict( + adv_loss=loss_adv.mean(), + gan_loss=loss_gan.mean(), + lpips_loss=lpips_loss.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean() + self.lpips_weight*lpips_loss.mean(), + acc_real=acc_real.mean(), + acc_fake=acc_fake.mean(), + ) + return out diff --git a/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py b/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py new file mode 100644 index 0000000..849ee4b --- /dev/null +++ b/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py @@ -0,0 +1,159 @@ +import random + +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load( + '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', + weight_path, + source="local", + skip_validation=True + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + 
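    # forward() below normalizes inputs with ImageNet statistics, rescales them by a
    # 224/256 factor (bicubic), swaps in positional embeddings resampled by fetch_pos()
    # to match the resulting patch grid (e.g. a 16x16 grid -- 256 tokens -- for a
    # 224x224 input with 14-pixel patches), and returns the 'x_norm_patchtokens'
    # features that serve as the REPA alignment target in the trainer below.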
def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class MaskREPATrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + encoder_weight_path=None, + mask_groups=4, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.mask_groups = mask_groups + self.encoder = DINOv2(encoder_weight_path) + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def fetch_mask(self, length=256, groups=4, device=torch.device('cuda')): + mask = torch.zeros(1, length, length, device=device, dtype=torch.bool) + random_seq = torch.randperm(length, device=device) + for i in range(groups): + group_start = (length+groups-1)//groups*i + group_end = (length+groups-1)//groups*(i+1) + group_random_seq = random_seq[group_start:group_end] + y, x = torch.meshgrid(group_random_seq, group_random_seq) + mask[:, y, x] = True + return mask + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + src_feature = [] + def forward_hook(net, input, output): + src_feature.append(output) + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + mask_groups = random.randint(1, self.mask_groups) + mask = self.fetch_mask(length=256, groups=mask_groups, device=x.device) + out, _ = net(x_t, t, y, mask=mask) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = weight*(out - v_t)**2 + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py b/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py new file mode 100644 index 0000000..229680c --- 
/dev/null +++ b/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py @@ -0,0 +1,179 @@ +import torch +import math +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, mul=1000): + t_freq = self.timestep_embedding(t * mul, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class BatchNormWithTimeEmbedding(nn.Module): + def __init__(self, num_features): + super().__init__() + # self.bn = nn.BatchNorm2d(num_features, affine=False) + self.bn = nn.GroupNorm(16, num_features, affine=False) + # self.bn = nn.SyncBatchNorm(num_features, affine=False) + self.embedder = TimestepEmbedder(num_features * 2) + # nn.init.zeros_(self.embedder.mlp[-1].weight) + nn.init.trunc_normal_(self.embedder.mlp[-1].weight, std=0.01) + nn.init.zeros_(self.embedder.mlp[-1].bias) + + def forward(self, x, t): + embed = self.embedder(t) + embed = embed[:, :, None, None] + gamma, beta = embed.chunk(2, dim=1) + gamma = 1.0 + gamma + normed = self.bn(x) + out = normed * gamma + beta + return out + +class DisBlock(nn.Module): + def __init__(self, in_channels, hidden_size): + super().__init__() + self.conv = nn.Conv2d( + kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=4, padding=0 + ) + self.norm = BatchNormWithTimeEmbedding(hidden_size) + self.act = nn.SiLU() + def forward(self, x, t): + x = self.conv(x) + x = self.norm(x, t) + x = self.act(x) + return x + + +class Discriminator(nn.Module): + def __init__(self, num_blocks, in_channels, hidden_size): + super().__init__() + self.blocks = nn.ModuleList() + for i in range(num_blocks): + self.blocks.append( + DisBlock( + in_channels=in_channels, + hidden_size=hidden_size, + ) + ) + in_channels = hidden_size + self.classifier = nn.Conv2d( + kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=1 + ) + def forward(self, feature, t): + B, C, H, W = feature.shape + for block in self.blocks: + feature = block(feature, t) + out = self.classifier(feature).view(B, -1) + out = out.sigmoid().clamp(0.01, 0.99) + return out + +class AdvTrainer(BaseTrainer): + def __init__( + self, + scheduler: 
BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + adv_weight=1.0, + adv_blocks=3, + adv_in_channels=3, + adv_hidden_size=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.adv_weight = adv_weight + + self.discriminator = Discriminator( + num_blocks=adv_blocks, + in_channels=adv_in_channels*2, + hidden_size=adv_hidden_size, + ) + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + out, _ = net(x_t, t, y) + pred_x0 = x_t + sigma * out + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - v_t)**2 + + real_feature = torch.cat([x_t, x], dim=1) + fake_feature = torch.cat([x_t, pred_x0], dim=1) + + real_score_gan = self.discriminator(real_feature.detach(), t) + fake_score_gan = self.discriminator(fake_feature.detach(), t) + fake_score = self.discriminator(fake_feature, t) + + loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan) + acc_real = (real_score_gan > 0.5).float() + acc_fake = (fake_score_gan < 0.5).float() + loss_adv = -torch.log(fake_score) + loss_adv_hack = torch.log(fake_score_gan) + + out = dict( + adv_loss=loss_adv.mean(), + gan_loss=loss_gan.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(), + acc_real=acc_real.mean(), + acc_fake=acc_fake.mean(), + ) + return out diff --git a/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py b/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py new file mode 100644 index 0000000..e84e81f --- /dev/null +++ b/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py @@ -0,0 +1,154 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load( + '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', + weight_path, + source="local", + skip_validation=True + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + 
self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class REPAJiTTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + jit_deltas=0.01, + encoder_weight_path=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.jit_deltas = jit_deltas + self.encoder = DINOv2(encoder_weight_path) + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + t2 = base_t + (torch.rand_like(base_t)-0.5) * self.jit_deltas + t2 = torch.clip(t2, 0, 1) + alpha = self.scheduler.alpha(t2) + dalpha = self.scheduler.dalpha(t2) + sigma = self.scheduler.sigma(t2) + dsigma = self.scheduler.dsigma(t2) + x_t2 = alpha * x + noise * sigma + v_t2 = dalpha * x + dsigma * noise + + src_feature = [] + def forward_hook(net, input, output): + src_feature.append(output) + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + + _, s = net(x_t, t, y, only_s=True) + out, _ = net(x_t2, t2, y, s) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = weight*(out - v_t2)**2 + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py b/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py new file mode 100644 index 0000000..d7e741d --- /dev/null +++ b/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py @@ -0,0 +1,90 @@ +import torch +import math +from typing import Callable +from src.diffusion.base.training 
import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class SelfConsistentTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + lpips_weight=1.0, + lpips_encoder_layer=4, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.lpips_encoder_layer = lpips_encoder_layer + self.lpips_weight = lpips_weight + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + real_features = [] + def forward_hook(net, input, output): + real_features.append(output) + handles = [] + for i in range(self.lpips_encoder_layer): + handle = net.encoder.blocks[i].register_forward_hook(forward_hook) + handles.append(handle) + + out, _ = net(x_t, t, y) + + for handle in handles: + handle.remove() + + pred_x0 = (x_t + out * sigma) + pred_xt = alpha * pred_x0 + noise * sigma + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - v_t)**2 + + _, fake_features = net(pred_xt, t, y, classify_layer=self.lpips_encoder_layer) + + lpips_loss = [] + for r, f in zip(real_features, fake_features): + r = torch.nn.functional.normalize(r, dim=-1) + f = torch.nn.functional.normalize(f, dim=-1) + lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean()) + lpips_loss = sum(lpips_loss) + + + out = dict( + lpips_loss=lpips_loss.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + self.lpips_weight*lpips_loss.mean(), + ) + return out diff --git a/src/diffusion/stateful_flow_matching/bak/training_selflpips.py b/src/diffusion/stateful_flow_matching/bak/training_selflpips.py new file mode 100644 index 0000000..580775b --- /dev/null +++ b/src/diffusion/stateful_flow_matching/bak/training_selflpips.py @@ -0,0 +1,81 @@ +import torch +import math +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class SelfLPIPSTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + lpips_weight=1.0, + lpips_encoder_layer=4, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + 
self.loss_weight_fn = loss_weight_fn + self.lpips_encoder_layer = lpips_encoder_layer + self.lpips_weight = lpips_weight + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype) + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + out, _ = net(x_t, t, y) + pred_x0 = (x_t + out * sigma) + pred_xt = alpha * pred_x0 + noise * sigma + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - v_t)**2 + with torch.no_grad(): + _, real_features = net(x, clean_t, y, classify_layer=self.lpips_encoder_layer) + _, fake_features = net(pred_x0, clean_t, y, classify_layer=self.lpips_encoder_layer) + + + lpips_loss = [] + for r, f in zip(real_features, fake_features): + r = torch.nn.functional.normalize(r, dim=-1) + f = torch.nn.functional.normalize(f, dim=-1) + lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean()) + lpips_loss = sum(lpips_loss) + + + out = dict( + lpips_loss=lpips_loss.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + self.lpips_weight*lpips_loss.mean(), + ) + return out diff --git a/src/diffusion/stateful_flow_matching/cm_sampling.py b/src/diffusion/stateful_flow_matching/cm_sampling.py new file mode 100644 index 0000000..5254db5 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/cm_sampling.py @@ -0,0 +1,78 @@ +import torch + +from src.diffusion.base.guidance import * +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + + +import logging +logger = logging.getLogger(__name__) + +class CMSampler(BaseSampler): + def __init__( + self, + w_scheduler: BaseScheduler = None, + timeshift=1.0, + guidance_interval_min: float = 0.0, + guidance_interval_max: float = 1.0, + state_refresh_rate=1, + last_step=None, + step_fn=None, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.last_step = last_step + self.timeshift = timeshift + self.state_refresh_rate = state_refresh_rate + self.guidance_interval_min = guidance_interval_min + self.guidance_interval_max = guidance_interval_max + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + assert self.last_step > 0.0 + assert self.scheduler is not None + + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + state = None + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + cfg_t = t_cur.repeat(batch_size*2) + cfg_x = torch.cat([x, x], dim=0) + if i % self.state_refresh_rate == 0: + state = None + out, state = net(cfg_x, cfg_t, cfg_condition, state) + if t_cur > self.guidance_interval_min and t_cur < 
self.guidance_interval_max: + out = self.guidance_fn(out, self.guidance) + else: + out = self.guidance_fn(out, 1.0) + v = out + + x0 = x + v * (1-t_cur) + alpha_next = self.scheduler.alpha(t_next) + sigma_next = self.scheduler.sigma(t_next) + x = alpha_next * x0 + sigma_next * torch.randn_like(x) + # print(alpha_next, sigma_next) + return x \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/sampling.py b/src/diffusion/stateful_flow_matching/sampling.py new file mode 100644 index 0000000..5fdfdb2 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/sampling.py @@ -0,0 +1,103 @@ +import torch + +from src.diffusion.base.guidance import * +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + +def sde_mean_step_fn(x, v, dt, s, w): + return x + v * dt + s * w * dt + +def sde_step_fn(x, v, dt, s, w): + return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x) + +def sde_preserve_step_fn(x, v, dt, s, w): + return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x) + + +import logging +logger = logging.getLogger(__name__) + +class EulerSampler(BaseSampler): + def __init__( + self, + w_scheduler: BaseScheduler = None, + timeshift=1.0, + guidance_interval_min: float = 0.0, + guidance_interval_max: float = 1.0, + state_refresh_rate=1, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.w_scheduler = w_scheduler + self.timeshift = timeshift + self.state_refresh_rate = state_refresh_rate + self.guidance_interval_min = guidance_interval_min + self.guidance_interval_max = guidance_interval_max + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + assert self.last_step > 0.0 + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + if self.w_scheduler is not None: + if self.step_fn == ode_step_fn: + logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + state = None + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur) + dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur) + if self.w_scheduler: + w = self.w_scheduler.w(t_cur) + else: + w = 0.0 + + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + if i % self.state_refresh_rate == 0: + state = None + out, state = net(cfg_x, cfg_t, cfg_condition, state) + if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max: + out = self.guidance_fn(out, self.guidance) + else: + out = self.guidance_fn(out, 1.0) + v = out + s = 
((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma) + if i < self.num_steps -1 : + x = self.step_fn(x, v, dt, s=s, w=w) + else: + x = self.last_step_fn(x, v, dt, s=s, w=w) + return x \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/scheduling.py b/src/diffusion/stateful_flow_matching/scheduling.py new file mode 100644 index 0000000..a82cd3a --- /dev/null +++ b/src/diffusion/stateful_flow_matching/scheduling.py @@ -0,0 +1,39 @@ +import math +import torch +from src.diffusion.base.scheduling import * + + +class LinearScheduler(BaseScheduler): + def alpha(self, t) -> Tensor: + return (t).view(-1, 1, 1, 1) + def sigma(self, t) -> Tensor: + return (1-t).view(-1, 1, 1, 1) + def dalpha(self, t) -> Tensor: + return torch.full_like(t, 1.0).view(-1, 1, 1, 1) + def dsigma(self, t) -> Tensor: + return torch.full_like(t, -1.0).view(-1, 1, 1, 1) + +# SoTA for ImageNet! +class GVPScheduler(BaseScheduler): + def alpha(self, t) -> Tensor: + return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) + def sigma(self, t) -> Tensor: + return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) + def dalpha(self, t) -> Tensor: + return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) + def dsigma(self, t) -> Tensor: + return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) + def w(self, t): + return torch.sin(t)**2 + +class ConstScheduler(BaseScheduler): + def w(self, t): + return torch.ones(1, 1, 1, 1).to(t.device, t.dtype) + +from src.diffusion.ddpm.scheduling import VPScheduler +class VPBetaScheduler(VPScheduler): + def w(self, t): + return self.beta(t).view(-1, 1, 1, 1) + + + diff --git a/src/diffusion/stateful_flow_matching/sharing_sampling.py b/src/diffusion/stateful_flow_matching/sharing_sampling.py new file mode 100644 index 0000000..f372028 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/sharing_sampling.py @@ -0,0 +1,149 @@ +import torch + +from src.diffusion.base.guidance import * +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + + +import logging +logger = logging.getLogger(__name__) + +class EulerSampler(BaseSampler): + def __init__( + self, + w_scheduler: BaseScheduler = None, + timeshift=1.0, + guidance_interval_min: float = 0.0, + guidance_interval_max: float = 1.0, + state_refresh_rate=1, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.w_scheduler = w_scheduler + self.timeshift = timeshift + self.state_refresh_rate = state_refresh_rate + self.guidance_interval_min = guidance_interval_min + self.guidance_interval_max = guidance_interval_max + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + assert self.last_step > 0.0 + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + if self.w_scheduler is not None: + if self.step_fn == ode_step_fn: + logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") + + # init recompute + 
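# Only num_steps / state_refresh_rate of the Euler steps recompute the shared encoder state. + # sharing_dp() below picks which ones: on a fixed probe batch of 8 samples it measures the + # cosine dissimilarity between the states produced at every pair of steps, accumulates the + # error over steps, and solves C[k][i] = min_j ( C[k-1][j] + cum_error[i-1][j] ) by dynamic + # programming to place the k state-refresh points. +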
self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate) + self.recompute_timesteps = list(range(self.num_steps)) + + def sharing_dp(self, net, noise, condition, uncondition): + _, C, H, W = noise.shape + B = 8 + template_noise = torch.randn((B, C, H, W), generator=torch.Generator("cuda").manual_seed(0), device=noise.device) + template_condition = torch.randint(0, 1000, (B,), generator=torch.Generator("cuda").manual_seed(0), device=condition.device) + template_uncondition = torch.full((B, ), 1000, device=condition.device) + _, state_list = self._impl_sampling(net, template_noise, template_condition, template_uncondition) + states = torch.stack(state_list) + N, B, L, C = states.shape + states = states.view(N, B*L, C ) + states = states.permute(1, 0, 2) + states = torch.nn.functional.normalize(states, dim=-1) + with torch.autocast(device_type="cuda", dtype=torch.float64): + sim = torch.bmm(states, states.transpose(1, 2)) + sim = torch.mean(sim, dim=0).cpu() + error_map = (1-sim).tolist() + + # init cum-error + for i in range(1, self.num_steps): + for j in range(0, i): + error_map[i][j] = error_map[i-1][j] + error_map[i][j] + + # init dp and force 0 start + C = [[0.0, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)] + P = [[-1, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)] + for i in range(1, self.num_steps+1): + C[1][i] = error_map[i - 1][0] + P[1][i] = 0 + + # dp state + for step in range(2, self.num_recompute_timesteps+1): + for i in range(step, self.num_steps+1): + min_value = 99999 + min_index = -1 + for j in range(step-1, i): + value = C[step-1][j] + error_map[i-1][j] + if value < min_value: + min_value = value + min_index = j + C[step][i] = min_value + P[step][i] = min_index + + # trace back + timesteps = [self.num_steps,] + for i in range(self.num_recompute_timesteps, 0, -1): + idx = timesteps[-1] + timesteps.append(P[i][idx]) + timesteps.reverse() + + print("recompute timesteps solved by DP: ", timesteps) + return timesteps[:-1] + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + state = None + pooled_state_list = [] + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + if i in self.recompute_timesteps: + state = None + out, state = net(cfg_x, cfg_t, cfg_condition, state) + if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max: + out = self.guidance_fn(out, self.guidance) + else: + out = self.guidance_fn(out, 1.0) + v = out + if i < self.num_steps -1 : + x = self.step_fn(x, v, dt, s=0.0, w=0.0) + else: + x = self.last_step_fn(x, v, dt, s=0.0, w=0.0) + pooled_state_list.append(state) + return x, pooled_state_list + + def __call__(self, net, noise, condition, uncondition): + if len(self.recompute_timesteps) != self.num_recompute_timesteps: + self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition) + denoised, _ = self._impl_sampling(net, noise, condition, uncondition) + return denoised \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/training.py b/src/diffusion/stateful_flow_matching/training.py new file mode 100644 index 0000000..4c49e1e --- /dev/null +++ 
b/src/diffusion/stateful_flow_matching/training.py @@ -0,0 +1,55 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class FlowMatchingTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + out, _ = net(x_t, t, y) + + weight = self.loss_weight_fn(alpha, sigma) + + loss = weight*(out - v_t)**2 + + out = dict( + loss=loss.mean(), + ) + return out \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/training_adv.py b/src/diffusion/stateful_flow_matching/training_adv.py new file mode 100644 index 0000000..4792950 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training_adv.py @@ -0,0 +1,122 @@ +import torch +import math +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class Discriminator(nn.Module): + def __init__(self, in_channels, hidden_size): + super().__init__() + self.head = nn.Sequential( + nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1 + ) + + def forward(self, feature): + B, L, C = feature.shape + H = W = int(math.sqrt(L)) + feature = feature.permute(0, 2, 1) + feature = feature.view(B, C, H, W) + out = self.head(feature).sigmoid().clamp(0.01, 0.99) + return out + +class AdvTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + adv_weight=1.0, + 
adv_encoder_layer=4, + adv_in_channels=768, + adv_hidden_size=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.adv_weight = adv_weight + self.adv_encoder_layer = adv_encoder_layer + + self.dis_head = Discriminator( + in_channels=adv_in_channels, + hidden_size=adv_hidden_size, + ) + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + adv_feature = [] + def forward_hook(net, input, output): + adv_feature.append(output) + handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook) + + out, _ = net(x_t, t, y) + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - v_t)**2 + + pred_x0 = (x_t + out * sigma) + pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma + real_feature = adv_feature.pop() + net(pred_xt, t, y, classify_layer=self.adv_encoder_layer) + fake_feature = adv_feature.pop() + handle.remove() + + + real_score_gan = self.dis_head(real_feature.detach()) + fake_score_gan = self.dis_head(fake_feature.detach()) + fake_score = self.dis_head(fake_feature) + + loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan) + acc_real = (real_score_gan > 0.5).float() + acc_fake = (fake_score_gan < 0.5).float() + loss_adv = -torch.log(fake_score) + loss_adv_hack = torch.log(fake_score_gan) + + out = dict( + adv_loss=loss_adv.mean(), + gan_loss=loss_gan.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(), + acc_real=acc_real.mean(), + acc_fake=acc_fake.mean(), + ) + return out diff --git a/src/diffusion/stateful_flow_matching/training_distill_dino.py b/src/diffusion/stateful_flow_matching/training_distill_dino.py new file mode 100644 index 0000000..c6a2937 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training_distill_dino.py @@ -0,0 +1,141 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load( + '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', + weight_path, + source="local", + skip_validation=True + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, 
w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bilinear') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class DistillDINOTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + encoder_weight_path=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.encoder = DINOv2(encoder_weight_path) + self.proj_encoder_dim = proj_encoder_dim + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + sigma = self.scheduler.sigma(t) + + x_t = alpha * x + noise * sigma + + _, s = net(x_t, t, y) + src_feature = self.proj(s) + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + if dst_feature.shape[1] != src_feature.shape[1]: + dst_length = dst_feature.shape[1] + rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5 + dst_height = (dst_length)**0.5 * (height/width)**0.5 + dst_width = (dst_length)**0.5 * (width/height)**0.5 + dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim) + dst_feature = dst_feature.permute(0, 3, 1, 2) + dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False) + dst_feature = dst_feature.permute(0, 2, 3, 1) + dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + out = dict( + cos_loss=cos_loss.mean(), + loss=cos_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/diffusion/stateful_flow_matching/training_lpips.py b/src/diffusion/stateful_flow_matching/training_lpips.py new file mode 100644 index 0000000..a3cd2a2 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training_lpips.py @@ -0,0 +1,71 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * 
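+# LPIPSTrainer adds a perceptual term to the flow-matching objective: the one-step estimate +# pred_x0 = x_t + v*sigma (exact under the linear alpha=t, sigma=1-t schedule defined in +# scheduling.py) is compared against the clean latent with a frozen bf16 VGG LPIPS metric; +# both inputs are scaled by 0.5, presumably to move the latents toward the [-1, 1] range +# that LPIPS expects.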
+from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class LPIPSTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + lpips_weight=1.0, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.lpips_weight = lpips_weight + self.lpips = _NoTrainLpips(net="vgg") + self.lpips = self.lpips.to(torch.bfloat16) + # self.lpips = torch.compile(self.lpips) + no_grad(self.lpips) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + out, _ = net(x_t, t, y) + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - v_t)**2 + + pred_x0 = (x_t + out*sigma) + target_x0 = x + # fixbug lpips std + lpips = self.lpips(pred_x0*0.5, target_x0*0.5) + + out = dict( + lpips_loss=lpips.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + lpips.mean()*self.lpips_weight, + ) + return out + + def state_dict(self, *args, destination=None, prefix='', keep_vars=False): + return \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py b/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py new file mode 100644 index 0000000..e0233ea --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py @@ -0,0 +1,74 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class LPIPSTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + lpips_weight=1.0, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = False + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.lpips_weight = lpips_weight + self.lpips = _NoTrainLpips(net="vgg") + self.lpips = self.lpips.to(torch.bfloat16) + # self.lpips = torch.compile(self.lpips) + no_grad(self.lpips) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + 
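# This variant re-weights both terms as functions of t: the flow-matching loss uses + # fm_weight = t*(1-t)**2 / 0.25, which peaks around t = 1/3, while the LPIPS term uses + # lpips_weight = t, so the perceptual loss dominates near the clean end of the trajectory (t -> 1). +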
alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + out, _ = net(x_t, t, y) + + fm_weight = t*(1-t)**2/0.25 + lpips_weight = t + + loss = (out - v_t)**2 * fm_weight[:, None, None, None] + + pred_x0 = (x_t + out*sigma) + target_x0 = x + # fixbug lpips std + lpips = self.lpips(pred_x0*0.5, target_x0*0.5)*lpips_weight[:, None, None, None] + + out = dict( + lpips_loss=lpips.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + lpips.mean()*self.lpips_weight, + ) + return out + + def state_dict(self, *args, destination=None, prefix='', keep_vars=False): + return \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/training_repa.py b/src/diffusion/stateful_flow_matching/training_repa.py new file mode 100644 index 0000000..a5a28db --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training_repa.py @@ -0,0 +1,157 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load( + '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', + weight_path, + source="local", + skip_validation=True + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class REPATrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + encoder_weight_path=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.encoder = DINOv2(encoder_weight_path) + self.proj_encoder_dim = 
proj_encoder_dim + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + src_feature = [] + def forward_hook(net, input, output): + src_feature.append(output) + + if getattr(net, "blocks", None) is not None: + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + else: + handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + + out, _ = net(x_t, t, y) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + if dst_feature.shape[1] != src_feature.shape[1]: + dst_length = dst_feature.shape[1] + rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5 + dst_height = (dst_length)**0.5 * (height/width)**0.5 + dst_width = (dst_length)**0.5 * (width/height)**0.5 + dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim) + dst_feature = dst_feature.permute(0, 3, 1, 2) + dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False) + dst_feature = dst_feature.permute(0, 2, 3, 1) + dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = weight*(out - v_t)**2 + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/diffusion/stateful_flow_matching/training_repa_lpips.py b/src/diffusion/stateful_flow_matching/training_repa_lpips.py new file mode 100644 index 0000000..5a11207 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training_repa_lpips.py @@ -0,0 +1,170 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + 
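# DINOv2 ViT used (frozen via no_grad in the trainer) as the REPA alignment target: loaded + # from a local torch.hub mirror, classification head removed, and the absolute position + # embedding resampled and cached per input resolution in fetch_pos(). +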
super(DINOv2, self).__init__() + self.encoder = torch.hub.load( + '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', + weight_path, + source="local", + skip_validation=True + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class REPALPIPSTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + lpips_weight=1.0, + encoder_weight_path=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.encoder = DINOv2(encoder_weight_path) + self.proj_encoder_dim = proj_encoder_dim + no_grad(self.encoder) + + self.lpips_weight = lpips_weight + self.lpips = _NoTrainLpips(net="vgg") + self.lpips = self.lpips.to(torch.bfloat16) + no_grad(self.lpips) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + src_feature = [] + def forward_hook(net, input, output): + src_feature.append(output) + if getattr(net, "blocks", None) is not None: + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + else: + handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + + out, _ = net(x_t, t, y) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + if dst_feature.shape[1] != src_feature.shape[1]: + dst_length = dst_feature.shape[1] + rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5 + dst_height = (dst_length)**0.5 * (height/width)**0.5 + dst_width = (dst_length)**0.5 * (width/height)**0.5 + dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim) + 
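# Bring the DINOv2 patch tokens onto the denoiser's token grid: view them as a 2D feature + # map, bilinearly resize by rescale_ratio, then flatten back so the cosine alignment below + # is computed token-to-token. +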
dst_feature = dst_feature.permute(0, 3, 1, 2) + dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False) + dst_feature = dst_feature.permute(0, 2, 3, 1) + dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = weight*(out - v_t)**2 + + pred_x0 = (x_t + out * sigma) + target_x0 = x + # fixbug lpips std + lpips = self.lpips(pred_x0 * 0.5, target_x0 * 0.5) + + out = dict( + lpips_loss=lpips.mean(), + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + self.lpips_weight*lpips.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/lightning_data.py b/src/lightning_data.py new file mode 100644 index 0000000..9f75a42 --- /dev/null +++ b/src/lightning_data.py @@ -0,0 +1,162 @@ +from typing import Any +import torch +import copy +import lightning.pytorch as pl +from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS +from torch.utils.data import DataLoader +from src.data.dataset.randn import RandomNDataset +from src.data.var_training import VARTransformEngine + +def collate_fn(batch): + new_batch = copy.deepcopy(batch) + new_batch = list(zip(*new_batch)) + for i in range(len(new_batch)): + if isinstance(new_batch[i][0], torch.Tensor): + try: + new_batch[i] = torch.stack(new_batch[i], dim=0) + except: + print("Warning: could not stack tensors") + return new_batch + +class DataModule(pl.LightningDataModule): + def __init__(self, + train_root, + test_nature_root, + test_gen_root, + train_image_size=64, + train_batch_size=64, + train_num_workers=8, + var_transform_engine: VARTransformEngine = None, + train_prefetch_factor=2, + train_dataset: str = None, + eval_batch_size=32, + eval_num_workers=4, + eval_max_num_instances=50000, + pred_batch_size=32, + pred_num_workers=4, + pred_seeds:str=None, + pred_selected_classes=None, + num_classes=1000, + latent_shape=(4,64,64), + ): + super().__init__() + pred_seeds = list(map(lambda x: int(x), pred_seeds.strip().split(","))) if pred_seeds is not None else None + + self.train_root = train_root + self.train_image_size = train_image_size + self.train_dataset = train_dataset + # stupid data_convert override, just to make nebular happy + self.train_batch_size = train_batch_size + self.train_num_workers = train_num_workers + self.train_prefetch_factor = train_prefetch_factor + + self.test_nature_root = test_nature_root + self.test_gen_root = test_gen_root + self.eval_max_num_instances = eval_max_num_instances + self.pred_seeds = pred_seeds + self.num_classes = num_classes + self.latent_shape = latent_shape + + self.eval_batch_size = eval_batch_size + self.pred_batch_size = pred_batch_size + + self.pred_num_workers = pred_num_workers + self.eval_num_workers = eval_num_workers + + self.pred_selected_classes = pred_selected_classes + + self._train_dataloader = None + self.var_transform_engine = var_transform_engine + + def setup(self, stage: str) -> None: + if stage == "fit": + assert self.train_dataset is not None + if self.train_dataset == "pix_imagenet64": + from src.data.dataset.imagenet import PixImageNet64 + self.train_dataset = PixImageNet64( + 
root=self.train_root, + ) + elif self.train_dataset == "pix_imagenet128": + from src.data.dataset.imagenet import PixImageNet128 + self.train_dataset = PixImageNet128( + root=self.train_root, + ) + elif self.train_dataset == "imagenet256": + from src.data.dataset.imagenet import ImageNet256 + self.train_dataset = ImageNet256( + root=self.train_root, + ) + elif self.train_dataset == "pix_imagenet256": + from src.data.dataset.imagenet import PixImageNet256 + self.train_dataset = PixImageNet256( + root=self.train_root, + ) + elif self.train_dataset == "imagenet512": + from src.data.dataset.imagenet import ImageNet512 + self.train_dataset = ImageNet512( + root=self.train_root, + ) + elif self.train_dataset == "pix_imagenet512": + from src.data.dataset.imagenet import PixImageNet512 + self.train_dataset = PixImageNet512( + root=self.train_root, + ) + else: + raise NotImplementedError("no such dataset") + + def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: + if self.var_transform_engine and self.trainer.training: + batch = self.var_transform_engine(batch) + return batch + + def train_dataloader(self) -> TRAIN_DATALOADERS: + global_rank = self.trainer.global_rank + world_size = self.trainer.world_size + from torch.utils.data import DistributedSampler + sampler = DistributedSampler(self.train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True) + self._train_dataloader = DataLoader( + self.train_dataset, + self.train_batch_size, + timeout=6000, + num_workers=self.train_num_workers, + prefetch_factor=self.train_prefetch_factor, + sampler=sampler, + collate_fn=collate_fn, + ) + return self._train_dataloader + + def val_dataloader(self) -> EVAL_DATALOADERS: + global_rank = self.trainer.global_rank + world_size = self.trainer.world_size + self.eval_dataset = RandomNDataset( + latent_shape=self.latent_shape, + num_classes=self.num_classes, + max_num_instances=self.eval_max_num_instances, + ) + from torch.utils.data import DistributedSampler + sampler = DistributedSampler(self.eval_dataset, num_replicas=world_size, rank=global_rank, shuffle=False) + return DataLoader(self.eval_dataset, self.eval_batch_size, + num_workers=self.eval_num_workers, + prefetch_factor=2, + collate_fn=collate_fn, + sampler=sampler + ) + + def predict_dataloader(self) -> EVAL_DATALOADERS: + global_rank = self.trainer.global_rank + world_size = self.trainer.world_size + self.pred_dataset = RandomNDataset( + seeds= self.pred_seeds, + max_num_instances=50000, + num_classes=self.num_classes, + selected_classes=self.pred_selected_classes, + latent_shape=self.latent_shape, + ) + from torch.utils.data import DistributedSampler + sampler = DistributedSampler(self.pred_dataset, num_replicas=world_size, rank=global_rank, shuffle=False) + return DataLoader(self.pred_dataset, batch_size=self.pred_batch_size, + num_workers=self.pred_num_workers, + prefetch_factor=4, + collate_fn=collate_fn, + sampler=sampler + ) diff --git a/src/lightning_model.py b/src/lightning_model.py new file mode 100644 index 0000000..4602e82 --- /dev/null +++ b/src/lightning_model.py @@ -0,0 +1,123 @@ +from typing import Callable, Iterable, Any, Optional, Union, Sequence, Mapping, Dict +import os.path +import copy +import torch +import torch.nn as nn +import lightning.pytorch as pl +from lightning.pytorch.utilities.types import OptimizerLRScheduler, STEP_OUTPUT +from torch.optim.lr_scheduler import LRScheduler +from torch.optim import Optimizer +from lightning.pytorch.callbacks import Callback + + +from src.models.vae import 
BaseVAE, fp2uint8 +from src.models.conditioner import BaseConditioner +from src.utils.model_loader import ModelLoader +from src.callbacks.simple_ema import SimpleEMA +from src.diffusion.base.sampling import BaseSampler +from src.diffusion.base.training import BaseTrainer +from src.utils.no_grad import no_grad, filter_nograd_tensors +from src.utils.copy import copy_params + +EMACallable = Callable[[nn.Module, nn.Module], SimpleEMA] +OptimizerCallable = Callable[[Iterable], Optimizer] +LRSchedulerCallable = Callable[[Optimizer], LRScheduler] + + +class LightningModel(pl.LightningModule): + def __init__(self, + vae: BaseVAE, + conditioner: BaseConditioner, + denoiser: nn.Module, + diffusion_trainer: BaseTrainer, + diffusion_sampler: BaseSampler, + ema_tracker: Optional[EMACallable] = None, + optimizer: OptimizerCallable = None, + lr_scheduler: LRSchedulerCallable = None, + ): + super().__init__() + self.vae = vae + self.conditioner = conditioner + self.denoiser = denoiser + self.ema_denoiser = copy.deepcopy(self.denoiser) + self.diffusion_sampler = diffusion_sampler + self.diffusion_trainer = diffusion_trainer + self.ema_tracker = ema_tracker + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + # self.model_loader = ModelLoader() + + self._strict_loading = False + + def configure_model(self) -> None: + self.trainer.strategy.barrier() + # self.denoiser = self.model_loader.load(self.denoiser) + copy_params(src_model=self.denoiser, dst_model=self.ema_denoiser) + + # self.denoiser = torch.compile(self.denoiser) + # disable grad for conditioner and vae + no_grad(self.conditioner) + no_grad(self.vae) + no_grad(self.diffusion_sampler) + no_grad(self.ema_denoiser) + + def configure_callbacks(self) -> Union[Sequence[Callback], Callback]: + ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser) + return [ema_tracker] + + def configure_optimizers(self) -> OptimizerLRScheduler: + params_denoiser = filter_nograd_tensors(self.denoiser.parameters()) + params_trainer = filter_nograd_tensors(self.diffusion_trainer.parameters()) + optimizer: torch.optim.Optimizer = self.optimizer([*params_trainer, *params_denoiser]) + if self.lr_scheduler is None: + return dict( + optimizer=optimizer + ) + else: + lr_scheduler = self.lr_scheduler(optimizer) + return dict( + optimizer=optimizer, + lr_scheduler=lr_scheduler + ) + + def training_step(self, batch, batch_idx): + raw_images, x, y = batch + with torch.no_grad(): + x = self.vae.encode(x) + condition, uncondition = self.conditioner(y) + loss = self.diffusion_trainer(self.denoiser, self.ema_denoiser, raw_images, x, condition, uncondition) + self.log_dict(loss, prog_bar=True, on_step=True, sync_dist=False) + return loss["loss"] + + def predict_step(self, batch, batch_idx): + xT, y, metadata = batch + with torch.no_grad(): + condition, uncondition = self.conditioner(y) + # Sample images: + samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition) + samples = self.vae.decode(samples) + # fp32 -1,1 -> uint8 0,255 + samples = fp2uint8(samples) + return samples + + def validation_step(self, batch, batch_idx): + samples = self.predict_step(batch, batch_idx) + return samples + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + if destination is None: + destination = {} + self._save_to_state_dict(destination, prefix, keep_vars) + self.denoiser.state_dict( + destination=destination, + prefix=prefix+"denoiser.", + keep_vars=keep_vars) + self.ema_denoiser.state_dict( + destination=destination, + 
prefix=prefix+"ema_denoiser.", + keep_vars=keep_vars) + self.diffusion_trainer.state_dict( + destination=destination, + prefix=prefix+"diffusion_trainer.", + keep_vars=keep_vars) + return destination \ No newline at end of file diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/conditioner.py b/src/models/conditioner.py new file mode 100644 index 0000000..a68fad3 --- /dev/null +++ b/src/models/conditioner.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + +class BaseConditioner(nn.Module): + def __init__(self): + super(BaseConditioner, self).__init__() + + def _impl_condition(self, y): + ... + def _impl_uncondition(self, y): + ... + def __call__(self, y): + condition = self._impl_condition(y) + uncondition = self._impl_uncondition(y) + return condition, uncondition + +class LabelConditioner(BaseConditioner): + def __init__(self, null_class): + super().__init__() + self.null_condition = null_class + + def _impl_condition(self, y): + return torch.tensor(y).long().cuda() + + def _impl_uncondition(self, y): + return torch.full((len(y),), self.null_condition, dtype=torch.long).cuda() \ No newline at end of file diff --git a/src/models/denoiser/__init__.py b/src/models/denoiser/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py new file mode 100644 index 0000000..3581446 --- /dev/null +++ b/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py @@ -0,0 +1,383 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention +from src.utils.model_loader import ModelLoader +from src.utils.no_grad import no_grad + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
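+ # Sinusoidal timestep embedding as in DiT, except max_period defaults to 10 rather than + # 10000, so for t in [0, 1] the geometric frequency ladder only spans [1/10, 1].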
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenDiTEncoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) + self.precompute_pos[(height, 
width)] = (pos_rope, pos_ape) + return (pos_rope, pos_ape) + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + def forward(self, x, t, y, mask=None): + B, _, H, W = x.shape + pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + s = self.s_embedder(x) + # s = s + pos_ape.to(s.dtype) + for i in range(self.num_blocks): + s = self.blocks[i](s, c, pos_rope, mask) + return None, s + + +class FlattenDiTDecoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size) + self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y, s, mask=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) + c = 
torch.nn.functional.silu(t + y) + x = torch.cat([x, s], dim=-1) + x = self.x_embedder(x) + for i in range(self.num_blocks): + x = self.blocks[i](x, c, pos, None) + x = self.final_layer(x, c) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, + stride=self.patch_size) + return x + + + +class FlattenDiT(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + ModelLoader().load(encoder) + self.encoder = self.encoder.to(torch.bfloat16) + no_grad(self.encoder) + + def forward(self, x, t, y, s=None): + if s is None: + with torch.no_grad(): + _, s = self.encoder(x, t, y) + x = self.decoder(x, t, y, s) + return x, s + diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py new file mode 100644 index 0000000..733ce4a --- /dev/null +++ b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py @@ -0,0 +1,447 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention +from src.utils.model_loader import ModelLoader +from src.utils.no_grad import no_grad + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
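# The surrounding lines implement the standard sinusoidal timestep embedding:
# freqs_i = exp(-ln(max_period) * i / half) for i in [0, half), and the embedding is the
# concatenation of cos(t * freqs) and sin(t * freqs). A minimal standalone sketch of the
# same computation (illustrative only; max_period=10 matches this file's default, while
# condit_dit.py further down uses the more common 10000):
#
#     def sinusoidal_embedding(t: torch.Tensor, dim: int, max_period: float = 10.0) -> torch.Tensor:
#         half = dim // 2
#         freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32, device=t.device) / half)
#         args = t[..., None].float() * freqs                              # (..., half)
#         return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)     # (..., dim)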
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +class ResBlock(nn.Module): + def __init__(self, dim:int, groups:int=8, hidden_dim:int=256): + super().__init__() + self.conv1 = nn.Conv2d(dim, dim, 3, padding=1) + self.conv2 = nn.Conv2d(dim, dim, 3, padding=1) + self.norm1 = nn.GroupNorm(groups, dim) + self.norm2 = nn.GroupNorm(groups, dim) + self.embed_proj = nn.Linear(hidden_dim, dim) + + def forward(self, x, c): + c = self.embed_proj(c)[:, :, None, None] + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = torch.nn.functional.silu(x) + x = x * c + x = self.conv2(x) + x = self.norm2(x) + x = torch.nn.functional.silu(x) + return residual + x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return 
freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenDiTEncoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, 
hidden_size) + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) + self.precompute_pos[(height, width)] = (pos_rope, pos_ape) + return (pos_rope, pos_ape) + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + def forward(self, x, t, y, mask=None, classify_layer=None): + B, _, H, W = x.shape + pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + s = self.s_embedder(x) + # s = s + pos_ape.to(s.dtype) + classify_feats = [] + for i in range(self.num_blocks): + s = self.blocks[i](s, c, pos_rope, mask) + if classify_layer is not None and i < classify_layer: + classify_feats.append(s) + if i == classify_layer - 1: + return _, classify_feats + return None, s + + +class FlattenDiTDecoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_mid_blocks=18, + num_res_blocks=[1, 1, 1], + num_res_channels=[64, 384, 768], + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_mid_blocks = num_mid_blocks + self.num_res_blocks = num_res_blocks + self.num_res_channels = num_res_channels + self.patch_size = 2**(len(num_res_blocks)) + + self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + + self.down_res_blocks = nn.ModuleList() + previous_channel = self.in_channels + for num, channels in zip(num_res_blocks, num_res_channels): + self.down_res_blocks.append( + nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0), + ) + self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) + previous_channel = channels + self.up_res_blocks = [] + previous_channel = self.in_channels + for num, channels in zip(num_res_blocks, num_res_channels): + self.up_res_blocks.append( + nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0) + ) + self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) + previous_channel = channels + self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1]) + + self.blocks = nn.ModuleList([ + 
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks) + ]) + + self.initialize_weights() + self.precompute_pos = dict() + self.weight_path = weight_path + self.load_ema = load_ema + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + def forward(self, x, t, y, s, mask=None): + B, _, H, W = x.shape + t = self.t_embedder(t.view(-1)).view(B, self.hidden_size) + y = self.y_embedder(y).view(B, self.hidden_size) + s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) + c = torch.nn.functional.silu(t + y) + + residual = [] + for i, block in enumerate(self.down_res_blocks): + if isinstance(block, nn.Conv2d): + residual.append(x) + x = block(x) + else: + x = block(x, c) + + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = x.view(B, self.hidden_size, -1).transpose(1, 2) + mid_c = torch.nn.functional.silu(t[:, None, :] + s) + for i in range(self.num_mid_blocks): + x = self.blocks[i](x, mid_c, pos, None) + x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size) + + residual[0] = 0.0 + for i, block in enumerate(self.up_res_blocks): + if isinstance(block, nn.ConvTranspose2d): + x = block(x) + residual.pop() + else: + x = block(x, c) + return x + + +class FlattenDiT(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + ModelLoader().load(encoder) + self.encoder = self.encoder.to(torch.bfloat16) + no_grad(self.encoder) + + def forward(self, x, t, y, s=None, classify_layer=None): + if s is None: + _, s = self.encoder(x, t, y, classify_layer=classify_layer) + if classify_layer is not None: + return None, s + x = self.decoder(x, t, y, s) + return x, s + +class FlattenDiT_jointtraining(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def forward(self, x, t, y, s=None): + if s is None: + _, s = self.encoder(x, t, y) + x = self.decoder(x, t, y, s) + return x, s + diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py new file mode 100644 index 0000000..6e9adbc --- /dev/null +++ b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py @@ -0,0 +1,448 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention +from src.utils.model_loader import ModelLoader +from src.utils.no_grad import no_grad + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = 
nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +class ResBlock(nn.Module): + def __init__(self, dim:int, groups:int=8, hidden_dim:int=256): + super().__init__() + self.conv1 = nn.Conv2d(dim, dim, 3, padding=1) + self.conv2 = nn.Conv2d(dim, dim, 3, padding=1) + self.norm1 = nn.GroupNorm(groups, dim) + self.norm2 = nn.GroupNorm(groups, dim) + self.embed_proj = nn.Linear(hidden_dim, dim) + + def forward(self, x, c): + c = self.embed_proj(c)[:, :, None, None] + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = torch.nn.functional.silu(x) + x = x * c + x = self.conv2(x) + x = self.norm2(x) + x = torch.nn.functional.silu(x) + return residual + x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, 
theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), 
shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenDiTEncoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) + self.precompute_pos[(height, width)] = (pos_rope, pos_ape) + return (pos_rope, pos_ape) + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + def forward(self, x, t, y, mask=None, classify_layer=None): + B, _, H, W = x.shape + pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + s = self.s_embedder(x) + # s = s + pos_ape.to(s.dtype) + classify_feats = [] + for i in range(self.num_blocks): + s = self.blocks[i](s, c, pos_rope, mask) + if classify_layer is not None and i < classify_layer: + classify_feats.append(s) + if i == classify_layer - 1: + return _, classify_feats + return None, s + + +class FlattenDiTDecoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_mid_blocks=18, + num_res_blocks=[1, 1, 1], + num_res_channels=[64, 384, 768], + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_mid_blocks = num_mid_blocks + self.num_res_blocks = num_res_blocks + self.num_res_channels = num_res_channels + self.patch_size = 2**(len(num_res_blocks)) + + self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + + self.down_res_blocks = nn.ModuleList() + previous_channel = self.in_channels + for 
num, channels in zip(num_res_blocks, num_res_channels): + self.down_res_blocks.append( + nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0), + ) + self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) + previous_channel = channels + self.up_res_blocks = [] + previous_channel = self.in_channels + for num, channels in zip(num_res_blocks, num_res_channels): + self.up_res_blocks.append( + nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0) + ) + self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) + previous_channel = channels + self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1]) + + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks) + ]) + + self.initialize_weights() + self.precompute_pos = dict() + self.weight_path = weight_path + self.load_ema = load_ema + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + def forward(self, x, t, y, s, mask=None): + B, _, H, W = x.shape + t = self.t_embedder(t.view(-1)).view(B, self.hidden_size) + y = self.y_embedder(y).view(B, self.hidden_size) + s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) + c = torch.nn.functional.silu(t + y) + + residual = [] + for i, block in enumerate(self.down_res_blocks): + if isinstance(block, nn.Conv2d): + residual.append(x) + x = block(x) + else: + x = block(x, c) + + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = x.view(B, self.hidden_size, -1).transpose(1, 2) + mid_c = torch.nn.functional.silu(t[:, None, :] + s) + for i in range(self.num_mid_blocks): + x = self.blocks[i](x, mid_c, pos, None) + x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size) + + residual[0] = 0.0 + for i, block in enumerate(self.up_res_blocks): + if isinstance(block, nn.ConvTranspose2d): + x = block(x) + residual.pop() + else: + x = block(x, c) + return x + + +class FlattenDiT(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + ModelLoader().load(encoder, "encoder.") + ModelLoader().load(decoder, "decoder.") + self.encoder = self.encoder.to(torch.bfloat16) + no_grad(self.encoder) + + def forward(self, x, t, y, s=None, classify_layer=None): + if s is None: + _, s = self.encoder(x, t, y, classify_layer=classify_layer) + if classify_layer is not None: + return None, s + x = self.decoder(x, t, y, s) + return x, s + +class FlattenDiT_jointtraining(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def forward(self, x, t, y, s=None): + if s is None: + _, s = self.encoder(x, t, y) + x = self.decoder(x, t, y, s) + return x, s + diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py new file mode 100644 index 
0000000..537078a --- /dev/null +++ b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py @@ -0,0 +1,464 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention +from src.utils.model_loader import ModelLoader +from src.utils.no_grad import no_grad + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + 
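# This FeedForward is a SwiGLU MLP: hidden_dim is rescaled by 2/3 so the three projections
# (w1, w2, w3) carry roughly the same parameter count as a conventional two-matrix MLP with
# the original hidden width, and the forward pass computes w2(silu(w1(x)) * w3(x)). A
# compressed sketch of that computation (illustrative names):
#
#     def swiglu(x, w1, w2, w3):
#         # x: (B, N, dim); w1, w3: dim -> hidden; w2: hidden -> dim
#         return w2(torch.nn.functional.silu(w1(x)) * w3(x))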
self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +class ResBlock(nn.Module): + def __init__(self, dim:int, groups:int=8, hidden_dim:int=256): + super().__init__() + self.conv1 = nn.Conv2d(dim, dim, 3, padding=1) + self.conv2 = nn.Conv2d(dim, dim, 3, padding=1) + self.norm1 = nn.GroupNorm(groups, dim) + self.norm2 = nn.GroupNorm(groups, dim) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = torch.nn.functional.silu(x) + x = self.conv2(x) + x = self.norm2(x) + x = torch.nn.functional.silu(x) + return residual + x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = 
self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenDiTEncoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) + self.precompute_pos[(height, width)] = (pos_rope, pos_ape) + return (pos_rope, pos_ape) + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + def forward(self, x, t, y, mask=None, classify_layer=None): + B, _, H, W = x.shape + pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + s = self.s_embedder(x) + # s = s + pos_ape.to(s.dtype) + classify_feats = [] + for i in range(self.num_blocks): + s = self.blocks[i](s, c, pos_rope, mask) + if classify_layer is not None and i < classify_layer: + classify_feats.append(s) + if i == classify_layer - 1: + return _, classify_feats + return None, s + + +class FlattenDiTDecoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, 
+ num_mid_blocks=18, + num_res_blocks=[1, 1, 1], + num_res_channels=[64, 384, 768], + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_mid_blocks = num_mid_blocks + self.num_res_blocks = num_res_blocks + self.num_res_channels = num_res_channels + self.patch_size = 2**(len(num_res_blocks)) + + self.t_embedder = TimestepEmbedder(hidden_size) + + self.down_res_blocks = nn.ModuleList() + previous_channel = self.in_channels + for num, channels in zip(num_res_blocks, num_res_channels): + self.down_res_blocks.append( + nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0), + ) + self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) + previous_channel = channels + self.up_res_blocks = [] + previous_channel = self.in_channels + for num, channels in zip(num_res_blocks, num_res_channels): + self.up_res_blocks.append( + nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0) + ) + self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) + previous_channel = channels + self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1]) + + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks) + ]) + + self.initialize_weights() + self.precompute_pos = dict() + self.weight_path = weight_path + self.load_ema = load_ema + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + # Zero-out adaLN modulation layers in SiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + for block in self.down_res_blocks: + if isinstance(block, ResBlock): + nn.init.constant_(block.conv1.weight, 0) + nn.init.constant_(block.conv1.bias, 0) + nn.init.constant_(block.norm1.weight, 0) + nn.init.constant_(block.norm2.weight, 0) + nn.init.constant_(block.conv2.weight, 0) + nn.init.constant_(block.conv2.bias, 0) + + for block in self.up_res_blocks: + if isinstance(block, ResBlock): + nn.init.constant_(block.conv1.weight, 0) + nn.init.constant_(block.conv1.bias, 0) + nn.init.constant_(block.norm1.weight, 0) + nn.init.constant_(block.norm2.weight, 0) + nn.init.constant_(block.conv2.weight, 0) + nn.init.constant_(block.conv2.bias, 0) + + def forward(self, x, t, y, s, mask=None): + B, _, H, W = x.shape + t = self.t_embedder(t.view(-1)).view(B, self.hidden_size) + s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) + + residual = [] + for i, block in enumerate(self.down_res_blocks): + if isinstance(block, nn.Conv2d): + residual.append(x) + x = block(x) + else: + x = block(x) + + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = x.view(B, self.hidden_size, -1).transpose(1, 2) + mid_c = torch.nn.functional.silu(t[:, None, :] + s) + for i in range(self.num_mid_blocks): + x = 
self.blocks[i](x, mid_c, pos, None) + x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size) + + residual[0] = 0.0 + for i, block in enumerate(self.up_res_blocks): + if isinstance(block, nn.ConvTranspose2d): + x = block(x) + residual.pop() + else: + x = block(x) + return x + + +class FlattenDiT(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + ModelLoader().load(encoder, "encoder.") + ModelLoader().load(decoder, "decoder.") + self.encoder = self.encoder.to(torch.bfloat16) + no_grad(self.encoder) + + def forward(self, x, t, y, s=None, classify_layer=None): + if s is None: + _, s = self.encoder(x, t, y, classify_layer=classify_layer) + if classify_layer is not None: + return None, s + x = self.decoder(x, t, y, s) + return x, s + +class FlattenDiT_jointtraining(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def forward(self, x, t, y, s=None): + if s is None: + _, s = self.encoder(x, t, y) + x = self.decoder(x, t, y, s) + return x, s + diff --git a/src/models/denoiser/condit_dit.py b/src/models/denoiser/condit_dit.py new file mode 100644 index 0000000..48d6b0e --- /dev/null +++ b/src/models/denoiser/condit_dit.py @@ -0,0 +1,274 @@ +import torch +import torch.nn as nn +import math + +from numba.cuda.cudadrv.devicearray import lru_cache +from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + +from torch.nn.attention import SDPBackend, sdpa_kernel + +flex_attention = torch.compile(flex_attention) + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = False, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
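# The FlattenDiT wrapper that closes the previous file (flatten_condit_encoder_unetdecoder_woy_fixt.py)
# loads the encoder and decoder through ModelLoader with the "encoder." / "decoder." prefixes,
# casts the encoder to bfloat16 and freezes it via no_grad, so only the UNet-style decoder
# receives gradients; FlattenDiT_jointtraining keeps both halves trainable. A shape-level
# sketch of how the semantic state s flows through one forward pass (illustrative; B is a
# batch size and encoder/decoder are properly configured instances of the classes above):
#
#     x = torch.randn(B, 4, 32, 32)        # noisy latent
#     t = torch.rand(B)                     # flow-matching time in [0, 1]
#     y = torch.randint(0, 1001, (B,))      # class id; 1000 is the null class
#     _, s = encoder(x, t, y)               # (B, N, hidden_size) semantic tokens, frozen path
#     pred = decoder(x, t, y, s)            # (B, 4, 32, 32) model output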
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm = nn.LayerNorm(hidden_size , elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = self.norm(x) + x = modulate(x, shift, scale) + x = self.linear(x) + return x + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + self.fc1 = nn.Linear(dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, dim) + self.act = nn.GELU(approximate="tanh") + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale: float=16): + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + x_pos = x_pos.reshape(-1) + y_pos = y_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + freqs_cis = torch.cat([x_freqs.sin(), x_freqs.cos(), y_freqs.sin(), y_freqs.cos()], dim=1) + return freqs_cis + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + # import pdb; pdb.set_trace() + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q).to(q.dtype) + k = self.k_norm(k).to(k.dtype) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + # x = flex_attention(q, k, v, block_mask=mask) + # with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + x = 
scaled_dot_product_attention(q, k, v, mask) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class DiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size , elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=groups, qkv_bias=True, qk_norm=False) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class ConDiT(nn.Module): + def __init__( + self, + in_channels=4, + out_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size) + self.final_layer = FinalLayer(hidden_size, out_channels * patch_size ** 2) + self.num_cond_blocks = num_cond_blocks + + + self.weight_path = weight_path + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + DiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + + @lru_cache + def fetch_pos(self, height, width, device): + pos = precompute_freqs_cis_2d(self.hidden_size, height//self.patch_size, width//self.patch_size).to(device)[None, ...] 
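# fetch_pos memoizes one positional table per (height, width, device) call; note that a
# functools-style lru_cache on a bound method also keys the cache on self, so each ConDiT
# instance keeps its own entries alive. The Flatten* denoisers elsewhere in this diff get
# the same effect with an explicit dict keyed by (height, width); a minimal sketch of that
# variant (illustrative):
#
#     def fetch_pos(self, height, width, device):
#         key = (height, width)
#         if key not in self.precompute_pos:
#             pos = precompute_freqs_cis_2d(self.hidden_size, height // self.patch_size,
#                                           width // self.patch_size).to(device)[None, ...]
#             self.precompute_pos[key] = pos
#         return self.precompute_pos[key]
#
# Here the resulting pos is added to the token embeddings (absolute sin/cos positions),
# whereas the Flatten* variants build a complex-valued table and rotate q/k inside
# attention through apply_rotary_emb.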
+ return pos + + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y, s=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H, W, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + + if s is None: + # semantic encoder + s = self.s_embedder(x) + pos + c = nn.functional.silu(t + y) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s diff --git a/src/models/denoiser/flatten_condit_catdit_fixt.py b/src/models/denoiser/flatten_condit_catdit_fixt.py new file mode 100644 index 0000000..22a0fd5 --- /dev/null +++ b/src/models/denoiser/flatten_condit_catdit_fixt.py @@ -0,0 +1,314 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
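# ConDiT.forward in condit_dit.py above splits a single stack of DiTBlocks into two phases:
# the first num_cond_blocks blocks act as a semantic encoder that refines s under the global
# class/timestep condition c = silu(t + y), and the remaining blocks denoise the patch tokens
# using the refined, per-token s as their condition. A compressed sketch of that control flow
# (shapes only; names follow the code above):
#
#     s = s_embedder(tokens) + pos                     # (B, N, D) semantic tokens
#     c = nn.functional.silu(t + y)                    # (B, 1, D) global condition
#     for blk in blocks[:num_cond_blocks]:
#         s = blk(s, c, pos)                           # refine semantics
#     s = nn.functional.silu(t + s)                    # (B, N, D) per-token condition
#     x = x_embedder(tokens)
#     for blk in blocks[num_cond_blocks:]:
#         x = blk(x, s, pos)                           # denoise, modulated token-wise by s
#     x = final_layer(x, s)                            # project back to patch pixels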
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) 
in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None, mask=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.s_embedder(x) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos, mask) + # s = nn.functional.silu(t + s) + s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6) + x = torch.cat((x, s), dim=-1) + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, c, pos, None) + x = self.final_layer(x, c) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_condit_conv_fixt.py b/src/models/denoiser/flatten_condit_conv_fixt.py new file mode 100644 index 0000000..219db4c --- /dev/null +++ b/src/models/denoiser/flatten_condit_conv_fixt.py @@ -0,0 +1,340 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + 
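+        # two-layer SiLU MLP that lifts the fixed sinusoidal features (frequency_embedding_size) to the model width (hidden_size)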
super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> 
Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + +class FlattenConvBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, kernel_size=3): + super().__init__() + self.hidden_size = hidden_size + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = nn.Conv2d(hidden_size, hidden_size, groups=groups, kernel_size=kernel_size, stride=1, padding=kernel_size//2) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + attn_x = modulate(self.norm1(x), 
shift_msa, scale_msa) + attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous() + attn_x = self.attn(attn_x) + attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2) + x = x + gate_msa * attn_x + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + kernel_size=3, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([]) + for i in range(self.num_cond_blocks): + self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups)) + for i in range(self.num_blocks-self.num_cond_blocks): + self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups, kernel_size)) + + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, 
self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.s_embedder(x) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos, None) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_condit_convnext_fixt.py b/src/models/denoiser/flatten_condit_convnext_fixt.py new file mode 100644 index 0000000..cf9c214 --- /dev/null +++ b/src/models/denoiser/flatten_condit_convnext_fixt.py @@ -0,0 +1,339 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + +class FlattenConvBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.hidden_size = hidden_size + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = nn.Conv2d(hidden_size, hidden_size, groups=hidden_size, kernel_size=7, stride=1, padding=3) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + attn_x = modulate(self.norm1(x), shift_msa, scale_msa) + attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous() + attn_x = self.attn(attn_x) + attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2) + x = x + gate_msa * attn_x + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + 
deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([]) + for i in range(self.num_cond_blocks): + self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups)) + for i in range(self.num_blocks-self.num_cond_blocks): + self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups)) + + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.s_embedder(x) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos, None) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git 
a/src/models/denoiser/flatten_condit_dit_fixt.py b/src/models/denoiser/flatten_condit_dit_fixt.py new file mode 100644 index 0000000..15557f3 --- /dev/null +++ b/src/models/denoiser/flatten_condit_dit_fixt.py @@ -0,0 +1,313 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, 
hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 
6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None, mask=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.s_embedder(x) + for i in range(self.num_cond_blocks): + s = 
self.blocks[i](s, c, pos, mask) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_condit_dit_norm_fixt.py b/src/models/denoiser/flatten_condit_dit_norm_fixt.py new file mode 100644 index 0000000..28034e3 --- /dev/null +++ b/src/models/denoiser/flatten_condit_dit_norm_fixt.py @@ -0,0 +1,314 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in 
self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None, mask=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.s_embedder(x) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos, mask) + s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py b/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py new file mode 100644 index 0000000..9a5e4fd --- /dev/null +++ b/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py @@ -0,0 +1,429 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention +from src.utils.model_loader import ModelLoader +from src.utils.no_grad import no_grad + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class 
TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def 
apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenDiTEncoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + 
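+        # NOTE: the encoder has no patch-output head; forward() returns (None, s), where s is the
+        # per-patch condition state that FlattenDiTDecoder consumes as its adaLN condition.
+        # Minimal usage sketch (assuming SD-VAE latents x of shape (B, 4, 32, 32), integer class
+        # labels y, and flow-matching timesteps t in [0, 1]; this mirrors FlattenDiT.forward below):
+        #   _, s = encoder(x, t, y)      # s: (B, 256, hidden_size) for patch_size=2
+        #   pred = decoder(x, t, y, s)   # pred: (B, 4, 32, 32)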
self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) + self.precompute_pos[(height, width)] = (pos_rope, pos_ape) + return (pos_rope, pos_ape) + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + def forward(self, x, t, y, mask=None): + B, _, H, W = x.shape + pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + s = self.s_embedder(x) + # s = s + pos_ape.to(s.dtype) + for i in range(self.num_blocks): + s = self.blocks[i](s, c, pos_rope, mask) + return None, s + + +class FlattenDiTDecoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + 
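+        # Together with the adaLN zeros above, zeroing the output linear below makes the decoder
+        # predict exactly zero at initialization.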
nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y, s, mask=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + # s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) + s = torch.nn.functional.silu(t + s) + x = self.x_embedder(x) + for i in range(self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, + stride=self.patch_size) + return x + + + +class FlattenDiT(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + joint_training=False, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + + ModelLoader().load(encoder) + if not joint_training: + self.encoder = self.encoder.to(torch.bfloat16) + no_grad(self.encoder) + + self.joint_training = joint_training + + def forward(self, x, t, y, s=None): + if s is None: + _, s = self.encoder(x, t, y) + x = self.decoder(x, t, y, s) + return x, s + +class FlattenDiTScalingEncoder(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + no_grad(self.decoder) + + if self.encoder.weight_path: + weight = torch.load(self.encoder.weight_path, map_location=torch.device('cpu')) + if self.encoder.load_ema: + prefix = "ema_denoiser." + else: + prefix = "denoiser." + for k, v in self.encoder.state_dict().items(): + try: + v.copy_(weight["state_dict"][prefix+k]) + except: + print(f"Failed to copy {prefix+k} to denoiser weight") + + if self.decoder.weight_path: + weight = torch.load(self.decoder.weight_path, map_location=torch.device('cpu')) + if self.decoder.load_ema: + prefix = "ema_denoiser." + else: + prefix = "denoiser." + for k, v in self.decoder.state_dict().items(): + if "blocks." 
in k: + blockid = int(k.split("blocks.")[-1][0]) + k = k.replace(f"blocks.{blockid}", f"blocks.{int(blockid)+8}") + try: + v.copy_(weight["state_dict"][prefix+k]) + except: + print(f"Failed to copy {prefix+k} to denoiser weight") + self.decoder = decoder.to(torch.bfloat16) + + def forward(self, x, t, y, s=None): + if s is None: + _, s = self.encoder(x, t, y) + x = self.decoder(x, t, y, s) + return x, s + diff --git a/src/models/denoiser/flatten_condit_mlp_fixt.py b/src/models/denoiser/flatten_condit_mlp_fixt.py new file mode 100644 index 0000000..40735e4 --- /dev/null +++ b/src/models/denoiser/flatten_condit_mlp_fixt.py @@ -0,0 +1,334 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
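+        # Sinusoidal embedding: cos/sin of t against geometrically spaced frequencies.
+        # max_period is 10 rather than the usual 10000, presumably because the flow-matching
+        # timesteps here are normalized to [0, 1] instead of [0, 1000].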
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + +class FlattenMLPBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = FeedForward(hidden_size, mlp_hidden_dim) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups 
= num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([]) + for i in range(self.num_cond_blocks): + self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups)) + for i in range(self.num_blocks-self.num_cond_blocks): + self.blocks.append(FlattenMLPBlock(self.hidden_size, self.num_groups)) + + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.s_embedder(x) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos, None) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py b/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py new file mode 100644 index 0000000..bcf3315 --- /dev/null +++ b/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py @@ -0,0 +1,321 @@ +import functools +from typing import Tuple +import torch +import 
torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + 
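+    # 2D RoPE: half of each head dim rotates with x-position phases and half with y-position
+    # phases (dim//4 frequencies per axis); positions are normalized onto [0, scale] regardless
+    # of grid resolution, and the complex phases are interleaved per position into (H*W, dim//2).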
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * 
self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**4, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None, mask=None): + B, _, H, W = x.shape + pos_x = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + pos_s = self.fetch_pos(H//self.patch_size//2, W//self.patch_size//2, x.device) + s = torch.nn.functional.unfold(x, kernel_size=self.patch_size*2, stride=self.patch_size*2).transpose(1, 2) + s = self.s_embedder(s) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos_s, mask) + s = s.view(B, H//self.patch_size//2, W//self.patch_size//2, self.hidden_size) + s = torch.permute(s, (0, 3, 1, 2)) + s = torch.nn.functional.interpolate(s, scale_factor=2, 
mode='bilinear', align_corners=False) + s = torch.permute(s, (0, 2, 3, 1)) + s = s.view(B, -1, self.hidden_size) + s = nn.functional.silu(t + s) + + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos_x, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_dit_fixt.py b/src/models/denoiser/flatten_dit_fixt.py new file mode 100644 index 0000000..9412d6e --- /dev/null +++ b/src/models/denoiser/flatten_dit_fixt.py @@ -0,0 +1,306 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
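+        # args has shape (*t.shape, dim//2); the cos/sin concat below doubles it back to dim.
+        # In practice the caller flattens t with t.view(-1) first, so args is (B, dim//2).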
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device, dtype): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device, dtype) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // 
self.num_groups, height, width).to(device, dtype) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y, masks=None): + if masks is None: + masks = [None, ]*self.num_blocks + if isinstance(masks, torch.Tensor): + masks = masks.unbind(0) + if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks: + masks = masks + [None]*(self.num_blocks-len(masks)) + + B, _, H, W = x.shape + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + x = self.x_embedder(x) + pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype) + B, L, C = x.shape + t = self.t_embedder(t.view(-1)).view(B, -1, C) + y = self.y_embedder(y).view(B, 1, C) + condition = nn.functional.silu(t + y) + for i, block in enumerate(self.blocks): + x = block(x, condition, pos, masks[i]) + x = self.final_layer(x, condition) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x \ No newline at end of file diff --git a/src/models/denoiser/flatten_dit_fixt_xvout.py b/src/models/denoiser/flatten_dit_fixt_xvout.py new file mode 100644 index 0000000..4df3393 --- /dev/null +++ b/src/models/denoiser/flatten_dit_fixt_xvout.py @@ -0,0 +1,311 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + 
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = 
torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, 2*in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device, dtype): + 
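+        # RoPE tables are cached per (height, width); on cache hits only the .to(device, dtype)
+        # copy is paid per forward pass.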
if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device, dtype) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y, masks=None): + if masks is None: + masks = [None, ]*self.num_blocks + if isinstance(masks, torch.Tensor): + masks = masks.unbind(0) + if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks: + masks = masks + [None]*(self.num_blocks-len(masks)) + + B, _, H, W = x.shape + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + x = self.x_embedder(x) + pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype) + B, L, C = x.shape + t = self.t_embedder(t.view(-1)).view(B, -1, C) + y = self.y_embedder(y).view(B, 1, C) + condition = nn.functional.silu(t + y) + for i, block in enumerate(self.blocks): + x = block(x, condition, pos, masks[i]) + x = self.final_layer(x, condition) + x0, v = x.chunk(2, dim=-1) + x0 = torch.nn.functional.fold(x0.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + v = torch.nn.functional.fold(v.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + if self.training: + return v, x0 + else: + return v \ No newline at end of file diff --git a/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py b/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py new file mode 100644 index 0000000..4e570b0 --- /dev/null +++ b/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py @@ -0,0 +1,308 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, 
hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: 
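For orientation, precompute_freqs_cis_2d builds an axial 2D rotary table: a quarter of the per-head dimension carries x-frequencies and a quarter carries y-frequencies, each turned into unit complex numbers and interleaved, giving one complex rotation per channel pair. apply_rotary_emb then broadcasts this table as freqs_cis[None, :, None, :] against complex views of q and k. A small shape check under those assumptions (head_dim and the grid size are illustrative):

import torch

head_dim, height, width, theta, scale = 64, 4, 4, 10000.0, 16.0
x_pos = torch.linspace(0, scale, width)
y_pos = torch.linspace(0, scale, height)
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 4)[: head_dim // 4].float() / head_dim))
x_cis = torch.polar(torch.ones(height * width, head_dim // 4), torch.outer(x_pos.reshape(-1), freqs))
y_cis = torch.polar(torch.ones(height * width, head_dim // 4), torch.outer(y_pos.reshape(-1), freqs))
freqs_cis = torch.stack([x_cis, y_cis], dim=-1).reshape(height * width, head_dim // 2)
print(freqs_cis.shape, freqs_cis.dtype)                         # torch.Size([16, 32]) torch.complex64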
torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = 
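The block structure itself is the standard DiT adaLN recipe: a linear layer maps the conditioning vector to six chunks (shift, scale and gate for the attention branch and for the MLP branch), the normalized activations are affinely modulated, and each branch is gated before its residual add. A stripped-down sketch with made-up sizes (plain LayerNorm stands in for the RMSNorm used in FlattenDiTBlock):

import torch
import torch.nn as nn

hidden = 8
ada = nn.Linear(hidden, 6 * hidden)
norm = nn.LayerNorm(hidden, elementwise_affine=False)

x = torch.randn(2, 5, hidden)                       # (B, tokens, hidden)
c = torch.randn(2, 1, hidden)                       # condition, broadcast over tokens
shift, scale, gate, _, _, _ = ada(c).chunk(6, dim=-1)
branch = norm(x) * (1 + scale) + shift              # modulate(norm(x), shift, scale)
x = x + gate * branch                               # gated residual, as in FlattenDiTBlock.forward
print(x.shape)                                      # torch.Size([2, 5, 8])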
TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None, mask=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.x_embedder(x) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos, mask) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/flowdcn.py b/src/models/denoiser/flowdcn.py new file mode 100644 index 0000000..92e2237 --- /dev/null +++ b/src/models/denoiser/flowdcn.py @@ -0,0 +1,160 @@ +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from src.models.denoiser.base_model import BaseModel +from src.ops.triton_kernels.function import DCNFunction + +def modulate(x, shift, scale): + return x * (1 + scale[:, None, None]) + shift[:, None, None] + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.patch_size = patch_size + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = 
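Control flow worth calling out in the sharepatch variant's forward: the patch tokens are first pushed through the leading num_cond_blocks blocks under the global t + y condition to produce a per-token state s, the input is then re-embedded, and the remaining blocks use s itself as their condition; both the output and s are returned, and a caller may pass s back in to skip the first stage on later calls. A toy trace of that flow (toy_block stands in for FlattenDiTBlock):

import torch

def toy_block(tokens, cond):
    return tokens + cond                      # stand-in for a (tokens, condition) -> tokens block

num_blocks, num_cond_blocks = 4, 2
x_tokens = torch.randn(2, 16, 8)              # embedded patches
t_plus_y = torch.randn(2, 1, 8)               # global condition

s = x_tokens
for i in range(num_cond_blocks):              # stage 1: build the shared state under t + y
    s = toy_block(s, t_plus_y)
s = torch.nn.functional.silu(t_plus_y + s)

h = x_tokens                                  # stage 2: fresh embedding, conditioned on s
for i in range(num_cond_blocks, num_blocks):
    h = toy_block(h, s)                       # the real model ends with final_layer(h, s) and returns (out, s)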
self.proj(x) + x = self.norm(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + b, h, w, c = x.shape + x = x.view(b, h*w, c) + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + x = x.view(b, h, w, c) + return x + + +class MultiScaleDCN(nn.Module): + def __init__(self, in_channels, groups, channels, kernels, deformable_biass=True): + super().__init__() + self.in_channels = in_channels + self.groups = groups + self.channels = channels + self.kernels = kernels + self.v = nn.Linear(in_channels, groups * channels, bias=True) + self.qk_deformables = nn.Linear(in_channels, groups * kernels * 2, bias=True) + self.qk_scales = nn.Linear(in_channels, groups * kernels, bias=False) + self.qk_weights = nn.Linear(in_channels, groups*kernels, bias=True) + self.out = nn.Linear(groups * channels, in_channels) + self.deformables_prior = nn.Parameter(torch.randn((1, 1, 1, 1, kernels, 2)), requires_grad=False) + self.deformables_scale = nn.Parameter(torch.ones((1, 1, 1, groups, 1, 1)), requires_grad=True) + self.max_scale = 6 + self._init_weights() + def _init_weights(self): + zeros_(self.qk_deformables.weight.data) + zeros_(self.qk_scales.weight.data) + zeros_(self.qk_deformables.bias.data) + zeros_(self.qk_weights.weight.data) + zeros_(self.v.bias.data) + zeros_(self.out.bias.data) + num_prior = int(self.kernels ** 0.5) + dx = torch.linspace(-1, 1, num_prior, device="cuda") + dy = torch.linspace(-1, 1, num_prior, device="cuda") + dxy = torch.meshgrid([dx, dy], indexing="xy") + dxy = torch.stack(dxy, dim=-1) + dxy = dxy.view(-1, 2) + self.deformables_prior.data[..., :num_prior*num_prior, :] = dxy + for i in range(self.groups): + scale = (i+1)/self.groups - 0.0001 + inv_scale = math.log((scale)/(1-scale)) + self.deformables_scale.data[..., i, :, :] = inv_scale + def forward(self, x): + B, H, W, _ = x.shape + v = self.v(x).view(B, H, W, self.groups, self.channels) + deformables = self.qk_deformables(x).view(B, H, W, self.groups, self.kernels, 2) + scale = self.qk_scales(x).view(B, H, W, self.groups, self.kernels, 1) + self.deformables_scale + deformables = (deformables + self.deformables_prior ) * scale.sigmoid()*self.max_scale + weights = self.qk_weights(x).view(B, H, W, self.groups, self.kernels) + out = DCNFunction.apply(v, deformables, weights) + out = out.view(B, H, W, -1) + out = self.out(out) + return out + +class FlowDCNBlock(nn.Module): + def __init__(self, hidden_size, groups, kernels=9, mlp_ratio=4.0, deformable_biass=True): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = MultiScaleDCN(hidden_size, groups=groups, channels=hidden_size//groups, kernels=kernels, deformable_biass=deformable_biass) + self.norm2 = 
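One detail in MultiScaleDCN that is easy to miss: the predicted offsets are added to a fixed sqrt(K) x sqrt(K) prior grid and then shrunk by a per-group sigmoid scale, and the scale logits are initialized so that group i starts with an effective sampling radius of roughly (i + 1) / groups * max_scale pixels, i.e. each group begins at a different scale. A small check of that initialization (group and kernel counts are illustrative):

import math
import torch

groups, kernels, max_scale = 4, 9, 6
num_prior = int(kernels ** 0.5)
dx = torch.linspace(-1, 1, num_prior)
dy = torch.linspace(-1, 1, num_prior)
prior = torch.stack(torch.meshgrid(dx, dy, indexing="xy"), dim=-1).view(-1, 2)   # 3x3 grid in [-1, 1]^2

for i in range(groups):
    s = (i + 1) / groups - 1e-4
    logit = math.log(s / (1 - s))                                  # the value written into deformables_scale
    radius = torch.sigmoid(torch.tensor(logit)) * max_scale
    print(f"group {i}: initial radius ~ {radius.item():.2f} px")   # ~1.5, 3.0, 4.5, 6.0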
RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa[:, None, None] * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp[:, None, None] * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + + + +class FlowDCN(BaseModel): + def __init__(self, deformable_biass=True, *args, **kwargs): + super().__init__(*args, **kwargs) + self.blocks = nn.ModuleList([ + FlowDCNBlock(self.hidden_size, self.num_groups, kernels=9, deformable_biass=deformable_biass) for _ in range(self.num_blocks) + ]) + self.x_embedder = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, bias=True) + self.initialize_weights() + + def forward(self, x, t, y): + batch_size, _, height, width = x.shape[0] + x = self.x_embedder(x) # (N, D, h, w) + x = x.permute(0, 2, 3, 1).reshape(batch_size, height*width//self.patch_size**2, -1) + t = self.t_embedder(t) # (N, D) + y = self.y_embedder(y, self.training) # (N, D) + c = t + y # (N, D) + B, L, C = x.shape + x = x.view(B, height//self.patch_size, width//self.patch_size, C) + for block in self.blocks: + x = block(x, c) # (N, T, D) + x = x.view(B, L, C) + x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) + x = torch.nn.functional.fold(x.transpose(1, 2), (height, width), kernel_size=self.patch_size, stride=self.patch_size) + if self.learn_sigma: + x, _ = torch.split(x, self.out_channels // 2, dim=1) + return x \ No newline at end of file diff --git a/src/models/encoder.py b/src/models/encoder.py new file mode 100644 index 0000000..8b7f96a --- /dev/null +++ b/src/models/encoder.py @@ -0,0 +1,132 @@ +import torch +import copy +import os +import timm +import transformers +import torch.nn as nn +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from torchvision.transforms import Normalize + +class RandViT(nn.Module): + def __init__(self, model_id, weight_path:str=None): + super(RandViT, self).__init__() + self.encoder = timm.create_model( + model_id, + num_classes=0, + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.shifts = nn.Parameter(torch.tensor([0.0 + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([1.0 + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + +class MAE(nn.Module): + def __init__(self, model_id, weight_path:str): + super(MAE, self).__init__() + if os.path.isdir(weight_path): + weight_path = os.path.join(weight_path, "pytorch_model.bin") + self.encoder = timm.create_model( + model_id, + checkpoint_path=weight_path, + num_classes=0, + ) + 
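All of the encoder wrappers in src/models/encoder.py that follow use the same recipe to produce alignment targets for REPA: normalize with the backbone's input statistics, bicubically resize to 224, run the ViT, drop the prefix (CLS/register) tokens, and reshape the remaining patch tokens into a (B, C, h, w) grid. A hedged sketch of that pipeline with a timm backbone (the model name is illustrative, and a recent timm exposing num_prefix_tokens is assumed):

import timm
import torch
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize

encoder = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=0)
x = torch.rand(2, 3, 256, 256)                                          # images in [0, 1]
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (224, 224), mode="bicubic")
tokens = encoder.forward_features(x)[:, encoder.num_prefix_tokens:]     # drop CLS / register tokens
b, n, c = tokens.shape
h = w = 224 // encoder.patch_embed.patch_size[0]
feature_map = tokens.transpose(1, 2).reshape(b, c, h, w)                # (B, 768, 14, 14) alignment target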
self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.shifts = nn.Parameter(torch.tensor([0.0 + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([1.0 + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + +class DINO(nn.Module): + def __init__(self, model_id, weight_path:str): + super(DINO, self).__init__() + if os.path.isdir(weight_path): + weight_path = os.path.join(weight_path, "pytorch_model.bin") + self.encoder = timm.create_model( + model_id, + checkpoint_path=weight_path, + num_classes=0, + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.shifts = nn.Parameter(torch.tensor([ 0.0, + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([ 1.0, + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + +class CLIP(nn.Module): + def __init__(self, model_id, weight_path:str): + super(CLIP, self).__init__() + self.encoder = transformers.CLIPVisionModel.from_pretrained(weight_path) + self.patch_size = self.encoder.vision_model.embeddings.patch_embedding.kernel_size + self.shifts = nn.Parameter(torch.tensor([0.0, + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([1.0, + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder(x)['last_hidden_state'][:, 1:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + + + +class DINOv2(nn.Module): + def __init__(self, model_id, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = transformers.Dinov2Model.from_pretrained(weight_path) + self.patch_size = self.encoder.embeddings.patch_embeddings.projection.kernel_size + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward(x)['last_hidden_state'][:, 1:] + feature = feature.transpose(1, 
2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + return feature \ No newline at end of file diff --git a/src/models/vae.py b/src/models/vae.py new file mode 100644 index 0000000..c47b087 --- /dev/null +++ b/src/models/vae.py @@ -0,0 +1,81 @@ +import torch +import subprocess +import lightning.pytorch as pl + +import logging + + +logger = logging.getLogger(__name__) +def class_fn_from_str(class_str): + class_module, from_class = class_str.rsplit(".", 1) + class_module = __import__(class_module, fromlist=[from_class]) + return getattr(class_module, from_class) + + +class BaseVAE(torch.nn.Module): + def __init__(self, scale=1.0, shift=0.0): + super().__init__() + self.model = torch.nn.Identity() + self.scale = scale + self.shift = shift + + def encode(self, x): + return x/self.scale+self.shift + + def decode(self, x): + return (x-self.shift)*self.scale + + +# very bad bugs with nearest sampling +class DownSampleVAE(BaseVAE): + def __init__(self, down_ratio, scale=1.0, shift=0.0): + super().__init__() + self.model = torch.nn.Identity() + self.scale = scale + self.shift = shift + self.down_ratio = down_ratio + + def encode(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=1/self.down_ratio, mode='bicubic', align_corners=False) + return x/self.scale+self.shift + + def decode(self, x): + x = (x-self.shift)*self.scale + x = torch.nn.functional.interpolate(x, scale_factor=self.down_ratio, mode='bicubic', align_corners=False) + return x + + + +class LatentVAE(BaseVAE): + def __init__(self, precompute=False, weight_path:str=None): + super().__init__() + self.precompute = precompute + self.model = None + self.weight_path = weight_path + + from diffusers.models import AutoencoderKL + setattr(self, "model", AutoencoderKL.from_pretrained(self.weight_path)) + self.scaling_factor = self.model.config.scaling_factor + + @torch.no_grad() + def encode(self, x): + assert self.model is not None + if self.precompute: + return x.mul_(self.scaling_factor) + return self.model.encode(x).latent_dist.sample().mul_(self.scaling_factor) + + @torch.no_grad() + def decode(self, x): + assert self.model is not None + return self.model.decode(x.div_(self.scaling_factor)).sample + + +def uint82fp(x): + x = x.to(torch.float32) + x = (x - 127.5) / 127.5 + return x + +def fp2uint8(x): + x = torch.clip_((x + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8) + return x + diff --git a/src/ops/cuda_kernels/backward.cu b/src/ops/cuda_kernels/backward.cu new file mode 100644 index 0000000..2e85d86 --- /dev/null +++ b/src/ops/cuda_kernels/backward.cu @@ -0,0 +1,346 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace cg = cooperative_groups; + +template +__device__ __always_inline int toInt(scalar_t val); + +template<> +__device__ __always_inline int toInt(float val){ + return static_cast(val); +} +template<> +__device__ __always_inline int toInt(half val){ + return __half2int_rz(val); +} + +template +__device__ __always_inline scalar_t fromInt(int val); + +template<> +__device__ __always_inline float fromInt(int val){ + return static_cast(val); +} + +template<> +__device__ __always_inline half fromInt(int val){ + return __int2half_rz(val); +} + +template +__device__ __always_inline scalar_t constVal(float val); + +template<> +__device__ __always_inline float constVal(float val) { + return (float)val; +} + +template<> +__device__ __always_inline half constVal(float val) { + return 
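LatentVAE folds the diffusers scaling_factor into encode/decode so the denoiser always works on scaled latents; with precompute=True, encode assumes it is handed already-encoded raw latents and only applies the scale. A minimal usage sketch (the checkpoint id is a stand-in for the local sd-vae-ft-ema weights):

import torch
from diffusers.models import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
scale = vae.config.scaling_factor                                  # 0.18215 for the SD VAE

with torch.no_grad():
    img = torch.rand(1, 3, 256, 256) * 2 - 1                       # image in [-1, 1]
    z = vae.encode(img).latent_dist.sample() * scale               # (1, 4, 32, 32) scaled latent
    rec = vae.decode(z / scale).sample                             # back to image space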
__float2half(val); // Using float to half conversion +} +template<> +__device__ __always_inline nv_bfloat16 constVal(float val){ + return __float2bfloat16(val); +} + + + + + +// B, H, W, C, BLOCK_DIM must be multiple of C +template +__global__ void dcn_backward_pipeline_kernel( + const int H, + const int W, + const int G, + const int K, + const int C, + scalar_t* ptr_values, + scalar_t* ptr_deformables, + scalar_t* ptr_weights, + scalar_t* ptr_grad_out, + scalar_t* ptr_grad_values, + scalar_t* ptr_grad_deformables, + scalar_t* ptr_grad_weights +) { + auto block = cg::this_thread_block(); + auto self_thread = cg::this_thread(); + auto tile_threads = cg::tiled_partition(block); + int local_thread_id = block.thread_rank(); + int local_tile_id = tile_threads.meta_group_rank(); + int num_local_tiles = tile_threads.meta_group_size(); + int global_tile_id = block.group_index().x*num_local_tiles + local_tile_id; + + extern __shared__ int shm[]; + auto GradBuffer = reinterpret_cast(shm); + scalar_t* Buffer = reinterpret_cast(shm) + num_local_tiles*C; + if(global_tile_id >= H*W*G) return; + + int bid = block.group_index().y; + int gid = global_tile_id % G; + int wid = global_tile_id / G % W; + int hid = global_tile_id / G / W; + int globale_offset = bid*H*W*G*C + global_tile_id*C; + cg::memcpy_async(tile_threads, GradBuffer+local_tile_id*C, ptr_grad_out+globale_offset, sizeof(scalar_t)*C); + + int shared_offset[pipeline_stages]; + for (int s = 0; s < pipeline_stages; ++s) { + shared_offset[s] = (s+pipeline_stages*local_thread_id)*(TILE_C*4); + } + + auto pipeline = cuda::make_pipeline(); + const int num_tiles_per_thread = C/TILE_C/TILE_THREADS; + + for(int k=0; k(wid); + y = ptr_deformables[offset*2 + 1] + fromInt(hid); +// x = fromInt(wid); +// y = fromInt(hid); + weight = ptr_weights[offset]; + } + tile_threads.sync(); + x = tile_threads.shfl(x, 0); + y = tile_threads.shfl(y, 0); + weight = tile_threads.shfl(weight, 0); + + int floor_x = toInt(x); + int floor_y = toInt(y); + int ceil_x = floor_x + 1; + int ceil_y = floor_y + 1; + + + scalar_t dodx = constVal(0.0f); + scalar_t dody = constVal(0.0f); + scalar_t dodw = constVal(0.0f); + + int start_c = tile_threads.thread_rank() * (C / TILE_THREADS); + + bool tl_flag = (floor_x >=0) and (floor_x =0) and (floor_y=0) and (ceil_x =0) and (floor_y=0) and (floor_x =0) and (ceil_y=0) and (ceil_x =0) and (ceil_y(ceil_x) - x) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j]; + dodx = dodx + -weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j]; + dody = dody + -weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j]; + dodw = dodw + (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1]; + dodx = dodx + -weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j+ 1] * GradBuffer[gbuffer_offset + j + 1]; + dody = dody + -weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1]; + { + vec2_t vtl_di; + vtl_di.x = weight* (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j]; + vtl_di.y = weight* (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j + 1]; + atomicAdd((vec2_t*)(ptr_grad_values + tl_global_base + compute_n * TILE_C + j), vtl_di); + } + } + + + if(tr_flag){ + // tr + dodw = dodw + (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j]; + dodx = 
dodx + weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j]; + dody = dody + -weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j]; + dodw = dodw + (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j+1] * GradBuffer[gbuffer_offset + j+1]; + dodx = dodx + weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+ 1]; + dody = dody + -weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+1]; + { + vec2_t vtr_di; + vtr_di.x = weight* (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j]; + vtr_di.y = weight* (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j+1]; + atomicAdd((vec2_t*)(ptr_grad_values + tr_global_base + compute_n * TILE_C + j), vtr_di); + } + } + + if(bl_flag){ + // bl + dodw = dodw + (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j]; + dodx = dodx + -weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j]; + dody = dody + weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j]; + dodw = dodw + (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1]; + dodx = dodx + -weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1]; + dody = dody + weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1]; + { + vec2_t vbl_di; + vbl_di.x = weight* (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j]; + vbl_di.y = weight* (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j+1]; + atomicAdd((vec2_t*)(ptr_grad_values + bl_global_base + compute_n * TILE_C + j), vbl_di); + } + } + + + if(br_flag){ + // tr + dodw = dodw + (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j]; + dodx = dodx + weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j]; + dody = dody + weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j]; + dodw = dodw + (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1]; + dodx = dodx + weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1]; + dody = dody + weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1]; + { + vec2_t vbr_di; + vbr_di.x = weight* (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j]; + vbr_di.y = weight* (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j+1]; + atomicAdd((vec2_t*)(ptr_grad_values + br_global_base + compute_n * TILE_C + j), vbr_di); + } + } + } + pipeline.consumer_release(); + } + for (int i = TILE_THREADS>>1; i > 0; i/=2) { + dodx = dodx + tile_threads.shfl_down(dodx, i); + dody = dody + tile_threads.shfl_down(dody, i); + dodw = dodw + tile_threads.shfl_down(dodw, i); + } + if (tile_threads.thread_rank() == 0) { + cuda::memcpy_async(ptr_grad_deformables + offset * 2, &dodx, sizeof(scalar_t), pipeline); + cuda::memcpy_async(ptr_grad_deformables + offset * 2 + 1, &dody, 
sizeof(scalar_t), pipeline); + cuda::memcpy_async(ptr_grad_weights + offset, &dodw, sizeof(scalar_t), pipeline); + } + } +} + + +using namespace torch; +template +void backward(const int B, + const int H, + const int W, + const int G, + const int K, + const int C, + torch::Tensor values, + torch::Tensor deformables, + torch::Tensor weights, + torch::Tensor grad_out, + torch::Tensor grad_values, + torch::Tensor grad_deformables, + torch::Tensor grad_weights +) { + int num_local_tiles =(THREADS/TILE_THREADS); + int num_global_tiles = (H*W*G+num_local_tiles-1)/num_local_tiles; + dim3 launch_threads_per_block(THREADS); + dim3 launch_blocks(num_global_tiles, B); + + int deformable_shm_size = 0; + int grad_out_shm_size = num_local_tiles*C; + int pipeline_shm_size = (pipeline_stages*TILE_C*4*THREADS); + + int shm_size = deformable_shm_size+grad_out_shm_size+pipeline_shm_size; +// printf("shm_size: %d\n", shm_size/512); +// printf("pipeline_size: %d\n", pipeline_shm_size/512); +// printf("grad_out_size: %d\n", grad_out_shm_size/512); + + + switch (values.type().scalarType()) { + case at::ScalarType::Half: + return dcn_backward_pipeline_kernel<<>>( + H, W, G, K, C, + reinterpret_cast(values.data_ptr()), + reinterpret_cast(deformables.data_ptr()), + reinterpret_cast(weights.data_ptr()), + reinterpret_cast(grad_out.data_ptr()), + reinterpret_cast(grad_values.data_ptr()), + reinterpret_cast(grad_deformables.data_ptr()), + reinterpret_cast(grad_weights.data_ptr()) + ); +// case at::ScalarType::BFloat16: +// return dcn_backward_pipeline_kernel<<>>( +// H, W, G, K, C, +// reinterpret_cast(values.data_ptr()), +// reinterpret_cast(deformables.data_ptr()), +// reinterpret_cast(weights.data_ptr()), +// reinterpret_cast(grad_out.data_ptr()), +// reinterpret_cast(grad_values.data_ptr()), +// reinterpret_cast(grad_deformables.data_ptr()), +// reinterpret_cast(grad_weights.data_ptr()) +// ); + default: + printf("running error"); + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("backward_p1_c2_tile16_thread128", &backward<1, 2, 16, 128>, ""); + m.def("backward_p2_c2_tile16_thread128", &backward<2, 2, 16, 128>, ""); + m.def("backward_p1_c4_tile16_thread128", &backward<1, 4, 16, 128>, ""); + m.def("backward_p1_c2_tile16_thread256", &backward<1, 2, 16, 256>, ""); + m.def("backward_p2_c2_tile16_thread256", &backward<2, 2, 16, 256>, ""); + m.def("backward_p1_c4_tile16_thread256", &backward<1, 4, 16, 256>, ""); + m.def("backward_p1_c2_tile16_thread384", &backward<1, 2, 16, 384>, ""); + m.def("backward_p2_c2_tile16_thread384", &backward<2, 2, 16, 384>, ""); + m.def("backward_p1_c4_tile16_thread384", &backward<1, 4, 16, 384>, ""); + m.def("backward_p1_c2_tile16_thread512", &backward<1, 2, 16, 512>, ""); + m.def("backward_p2_c2_tile16_thread512", &backward<2, 2, 16, 512>, ""); + m.def("backward_p1_c4_tile16_thread512", &backward<1, 4, 16, 512>, ""); + m.def("backward_p1_c2_tile16_thread768", &backward<1, 2, 16, 768>, ""); + m.def("backward_p2_c2_tile16_thread768", &backward<2, 2, 16, 768>, ""); + m.def("backward_p1_c4_tile16_thread768", &backward<1, 4, 16, 768>, ""); +// m.def("backward_p1_c2_tile16_thread1024", &backward<1, 2, 16, 1024>, ""); +// m.def("backward_p2_c2_tile16_thread1024", &backward<2, 2, 16, 1024>, ""); +// m.def("backward_p1_c4_tile16_thread1024", &backward<1, 4, 16, 1024>, ""); + + m.def("backward_p1_c2_tile32_thread128", &backward<1, 2, 32, 128>, ""); + m.def("backward_p1_c2_tile32_thread256", &backward<1, 2, 32, 256>, ""); + m.def("backward_p1_c2_tile32_thread384", &backward<1, 2, 32, 
384>, ""); + m.def("backward_p1_c2_tile32_thread512", &backward<1, 2, 32, 512>, ""); +} diff --git a/src/ops/cuda_kernels/bak_forward.cu b/src/ops/cuda_kernels/bak_forward.cu new file mode 100644 index 0000000..00569f8 --- /dev/null +++ b/src/ops/cuda_kernels/bak_forward.cu @@ -0,0 +1,289 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +template +__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){ + #pragma unroll + for(int i=0; i +__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){ + #pragma unroll + for(int i=0; i +__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){ +#pragma unroll + for(int i=0; i +__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){ +#pragma unroll + for(int i=0; i +__global__ void dcn_forward_kernel(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){ + int work_id = threadIdx.x; + int bid = blockIdx.z; + int gid = blockIdx.y; + int G = gridDim.y; + int c_blockid = blockIdx.x; + int work_load = (H*W/blockDim.x); + + __shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM + // __shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM + math_t register_bufferA[BLOCK_DIM] = {0}; + int base_c = c_blockid*BLOCK_DIM; + + int num_transfers = BLOCK_DIM; +#pragma unroll + for(int i=0; i(register_bufferA, 1, BLOCK_DIM); + // loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM); +#pragma unroll + for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM); + } + // load top right + math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight; + if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM); + } + + // load bottom left + math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight; + if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM); + } + // load bottom right + math_t br_weight = (x - floor_x)*(y - floor_y)*weight; + if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM); + } + + } + // loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM); + int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c); +#pragma unroll + for(int j=0; j +__global__ void dcn_forward_kernel_16(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){ + int work_id = threadIdx.x; + int bid = blockIdx.z; + int gid = blockIdx.y; + int G = gridDim.y; + int c_blockid = blockIdx.x; + int work_load = (H*W/blockDim.x); + + __shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM + __shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM + math_t register_bufferA[BLOCK_DIM] = {0}; + int base_c = c_blockid*BLOCK_DIM; + + int num_transfers = BLOCK_DIM/transfer_length; +#pragma unroll + for(int i=0; i(register_bufferA, 1, BLOCK_DIM); + loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, 
BLOCK_DIM); +#pragma unroll + for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM); + } + // load top right + math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight; + if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM); + } + + // load bottom left + math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight; + if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM); + } + // load bottom right + math_t br_weight = (x - floor_x)*(y - floor_y)*weight; + if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM); + } + + } + loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM); + + } + + __syncthreads(); + +#pragma unroll + for(int i=0; i +void dcn_forward(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) { + + int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM; + dim3 launch_threads_per_block(THREADS); + dim3 launch_blocks(NUM_C_BLOCK, G, B); + + switch (value.type().scalarType()) { + case at::ScalarType::Half: + return dcn_forward_kernel_16<<>>( + H, W, C, + value.data_ptr(), + deformables.data_ptr(), + weights.data_ptr(), + out.data_ptr()); + case at::ScalarType::BFloat16: + return dcn_forward_kernel_16<<>>( + H, W, C, + value.data_ptr(), + deformables.data_ptr(), + weights.data_ptr(), + out.data_ptr()); + case at::ScalarType::Float: + return dcn_forward_kernel<<>>( + H, W, C, + value.data_ptr(), + deformables.data_ptr(), + weights.data_ptr(), + out.data_ptr()); + default: + printf("running error"); + } +} + + +// PyBind11 bindings +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward"); +//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward"); +m.def("dcn_forward_l256_c4", &dcn_forward<256, 4, 256>, "CUDA dcn forward"); +m.def("dcn_forward_l256_c8", &dcn_forward<256, 8, 256>, "CUDA dcn forward"); +m.def("dcn_forward_l256_c16", &dcn_forward<256, 16, 256>, "CUDA dcn forward"); +// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward"); +m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward"); +m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward"); +m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward"); +// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward"); +} diff --git a/src/ops/cuda_kernels/forward.cu b/src/ops/cuda_kernels/forward.cu new file mode 100644 index 0000000..ac18308 --- /dev/null +++ b/src/ops/cuda_kernels/forward.cu @@ -0,0 +1,309 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +template +__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){ + #pragma unroll + for(int i=0; i +__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){ + #pragma unroll + for(int i=0; i +__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){ +#pragma unroll + 
for(int i=0; i +__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){ +#pragma unroll + for(int i=0; i +__global__ void dcn_forward_kernel_register(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){ + int work_id = threadIdx.x; + int bid = blockIdx.z; + int gid = blockIdx.y; + int G = gridDim.y; + int c_blockid = blockIdx.x; + int work_load = (H*W/blockDim.x); + + extern __shared__ int shm[]; + math_t* math_buffer = reinterpret_cast(shm); + + math_t register_bufferA[BLOCK_DIM] = {0}; + int base_c = c_blockid*BLOCK_DIM; + +#pragma unroll + for(int i=0; i(register_bufferA, 1, BLOCK_DIM); +#pragma unroll + for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM); + } + // load top right + math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight; + if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM); + } + + // load bottom left + math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight; + if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM); + } + // load bottom right + math_t br_weight = (x - floor_x)*(y - floor_y)*weight; + if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM); + } + + } + int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c); +#pragma unroll + for(int j=0; j +__global__ void dcn_forward_kernel_pipeline(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){ + int work_id = threadIdx.x; + int bid = blockIdx.z; + int gid = blockIdx.y; + int G = gridDim.y; + int c_blockid = blockIdx.x; + int work_load = (H*W/blockDim.x); + + extern __shared__ int shm[]; + math_t* math_buffer = reinterpret_cast(shm); + scalar_t* io_buffer = reinterpret_cast(shm) + H*W*BLOCK_DIM*sizeof(math_t)/sizeof(scalar_t); + math_t register_bufferA[BLOCK_DIM] = {0}; + int base_c = c_blockid*BLOCK_DIM; + + int num_transfers = BLOCK_DIM/transfer_length; +#pragma unroll + for(int i=0; i(register_bufferA, 1, BLOCK_DIM); + loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM); +#pragma unroll + for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM); + } + // load top right + math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight; + if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM); + } + + // load bottom left + math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight; + if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM); + } + // load bottom right + math_t br_weight = (x - floor_x)*(y - floor_y)*weight; + if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM); + } + + } + loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM); + + } + + __syncthreads(); + +#pragma unroll + for(int i=0; i +void dcn_forward(const int B, const 
int G, const int C, const int H, const int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) { + + int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM; + dim3 launch_threads_per_block(THREADS); + dim3 launch_blocks(NUM_C_BLOCK, G, B); + int shm_size = H*W*C_BLOCK_DIM*sizeof(at::Half); + switch (value.type().scalarType()) { + case at::ScalarType::Half: + return dcn_forward_kernel_register<<>>( + H, W, C, + value.data_ptr(), + deformables.data_ptr(), + weights.data_ptr(), + out.data_ptr()); + case at::ScalarType::Float: + return dcn_forward_kernel_register<<>>( + H, W, C, + value.data_ptr(), + deformables.data_ptr(), + weights.data_ptr(), + out.data_ptr()); + default: + printf("running error"); + } +} + +template +void dcn_forward_pipeline(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) { + + int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM; + dim3 launch_threads_per_block(THREADS); + dim3 launch_blocks(NUM_C_BLOCK, G, B); + int shm_size = 2*H*W*C_BLOCK_DIM*sizeof(at::Half); + switch (value.type().scalarType()) { + case at::ScalarType::Half: + return dcn_forward_kernel_pipeline<<>>( + H, W, C, + value.data_ptr(), + deformables.data_ptr(), + weights.data_ptr(), + out.data_ptr()); + case at::ScalarType::BFloat16: + return dcn_forward_kernel_pipeline<<>>( + H, W, C, + value.data_ptr(), + deformables.data_ptr(), + weights.data_ptr(), + out.data_ptr()); + default: + printf("running error"); + } +} + +// PyBind11 bindings +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward"); +//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward"); +m.def("dcn_forward_l256_c4", &dcn_forward<4, 256>, "CUDA dcn forward"); +m.def("dcn_forward_l256_c8", &dcn_forward<8, 256>, "CUDA dcn forward"); +m.def("dcn_forward_l256_c16", &dcn_forward<16, 256>, "CUDA dcn forward"); +m.def("dcn_forward_pipeline_l256_c4", &dcn_forward_pipeline<4, 256>, "CUDA dcn forward"); +m.def("dcn_forward_pipeline_l256_c8", &dcn_forward_pipeline<8, 256>, "CUDA dcn forward"); +m.def("dcn_forward_pipeline_l256_c16", &dcn_forward_pipeline<16, 256>, "CUDA dcn forward"); +// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward"); +// m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward"); +// m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward"); +// m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward"); +// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward"); +} diff --git a/src/ops/cuda_kernels/forward.py b/src/ops/cuda_kernels/forward.py new file mode 100644 index 0000000..4ea9c5e --- /dev/null +++ b/src/ops/cuda_kernels/forward.py @@ -0,0 +1,95 @@ +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE': 32,}, num_stages=1, num_warps=1), + triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1), + ], + key=['B', 'H', 'W', 'G', 'C', 'K'], +) +@triton.jit +def forward_kernel( + B: tl.constexpr, + H: tl.constexpr, # image_size_h + W: tl.constexpr, # image_size_w + G: tl.constexpr, # num_channels_per_group + C: tl.constexpr, # num_groups + K: tl.constexpr, # kernel size + input_ptr, # input features [B, H, W, G, C] + deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K] + weights_ptr, # weights [B, H, W, G, K] + 
out_ptr, # out [B, H, W, G, C] + BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group +): + pid = tl.program_id(0) + wid = pid % W + hid = pid // W % H + gid = pid // (W * H) % G + bid = pid // (W * H * G) + + id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B) + common_offset = bid*H*W*G + hid*W*G + wid*G + gid + batch_base = bid * H * W * G * C + + for block_base in tl.static_range(0, C, BLOCK_SIZE): + buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + block_offset = tl.arange(0, BLOCK_SIZE) + block_base + block_mask = (block_offset < C) & id_mask + for k in tl.static_range(K): + deformable_offset = (common_offset * K + k) * 2 + + x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid + y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid + + floor_x = x.to(tl.int32) + floor_y = y.to(tl.int32) + ceil_x = floor_x + 1 + ceil_y = floor_y + 1 + + # load top left + tl_weight = (ceil_x - x) * (ceil_y - y) + tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE + tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H) + + # load top right + tr_weight = (x - floor_x) * (ceil_y - y) + tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE + tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0) + # load bottom left + bl_weight = (ceil_x - x) * (y - floor_y) + bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE + bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0) + # load bottom right + br_weight = (x - floor_x) * (y - floor_y) + br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE + br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0) + + # load dynamic weight and mask + weights_offset = common_offset*K + k + weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) + + + + tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0) + tl_block_input = tl_block_input * tl_weight + + # load top right + tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0) + tr_block_input = tr_block_input * tr_weight + # load bottom left + bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0) + bl_block_input = bl_block_input * bl_weight + # load bottom right + br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0) + br_block_input = br_block_input * br_weight + + # sampled + sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input + + weighted_sampled_input = sampled_input * weight + buffer = buffer + weighted_sampled_input + # store to out_ptr + tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask) + diff --git a/src/ops/cuda_kernels/function.py b/src/ops/cuda_kernels/function.py new file mode 100644 index 0000000..9d4bfad --- /dev/null +++ b/src/ops/cuda_kernels/function.py @@ -0,0 +1,126 @@ +import time +import dcn_cuda_backward +import dcn_cuda_forward + +import math +import torch +from typing import Any +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_fwd, custom_bwd +from .forward import forward_kernel 
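Both the Triton kernel above and the CUDA kernels implement the same contract: for every (batch, pixel, group) location, gather the value tensor at K offset positions with bilinear interpolation (out-of-range corners contribute zero) and accumulate a weighted sum over K. A slow, self-contained PyTorch reference of that contract, handy for checking kernel outputs on tiny shapes (an accompanying test sketch, not part of the extension):

import torch

def dcn_forward_reference(values, deformables, weights):
    # values: (B, H, W, G, C); deformables: (B, H, W, G, K, 2) as (dx, dy); weights: (B, H, W, G, K)
    B, H, W, G, C = values.shape
    K = weights.shape[-1]
    out = torch.zeros_like(values)
    for b in range(B):
        for h in range(H):
            for w in range(W):
                for g in range(G):
                    acc = torch.zeros(C, dtype=values.dtype)
                    for k in range(K):
                        x = float(deformables[b, h, w, g, k, 0]) + w
                        y = float(deformables[b, h, w, g, k, 1]) + h
                        x0, y0 = int(x), int(y)        # truncation toward zero, as in the kernels
                        sample = torch.zeros(C, dtype=values.dtype)
                        for xi, yi, cw in [(x0, y0, (x0 + 1 - x) * (y0 + 1 - y)),
                                           (x0 + 1, y0, (x - x0) * (y0 + 1 - y)),
                                           (x0, y0 + 1, (x0 + 1 - x) * (y - y0)),
                                           (x0 + 1, y0 + 1, (x - x0) * (y - y0))]:
                            if 0 <= xi < W and 0 <= yi < H:
                                sample = sample + cw * values[b, yi, xi, g]
                        acc = acc + float(weights[b, h, w, g, k]) * sample
                    out[b, h, w, g] = acc
    return out

On small random inputs this can be compared against DCNFunction.apply once the extensions are built.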
+ + +class DCNFunction(Function): + BP_FUNCS = [ + dcn_cuda_backward.backward_p1_c2_tile16_thread128, + dcn_cuda_backward.backward_p1_c4_tile16_thread128, + dcn_cuda_backward.backward_p2_c2_tile16_thread128, + dcn_cuda_backward.backward_p1_c2_tile16_thread256, + dcn_cuda_backward.backward_p1_c4_tile16_thread256, + dcn_cuda_backward.backward_p2_c2_tile16_thread256, + dcn_cuda_backward.backward_p1_c2_tile16_thread384, + dcn_cuda_backward.backward_p1_c4_tile16_thread384, + dcn_cuda_backward.backward_p2_c2_tile16_thread384, + dcn_cuda_backward.backward_p1_c2_tile16_thread512, + dcn_cuda_backward.backward_p1_c4_tile16_thread512, + dcn_cuda_backward.backward_p2_c2_tile16_thread512, + dcn_cuda_backward.backward_p1_c2_tile16_thread768, + dcn_cuda_backward.backward_p1_c4_tile16_thread768, + dcn_cuda_backward.backward_p2_c2_tile16_thread768, + dcn_cuda_backward.backward_p1_c2_tile32_thread128, + dcn_cuda_backward.backward_p1_c2_tile32_thread256, + dcn_cuda_backward.backward_p1_c2_tile32_thread384, + dcn_cuda_backward.backward_p1_c2_tile32_thread512, + ] + FW_FUNCS = [ + dcn_cuda_forward.dcn_forward_l256_c4, + dcn_cuda_forward.dcn_forward_l256_c8, + dcn_cuda_forward.dcn_forward_l256_c16, + ] + BP_TABLES = dict() + FW_TABLES = dict() + + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, values, deformables, weights) -> Any: + B, H, W, G, C = values.shape + func = DCNFunction.find_fw_funcs(values, deformables, weights) + out = torch.zeros_like(values) + func(B, G, C, H, W, values, deformables, weights, out) + return out + + @staticmethod + def find_fw_funcs(values, deformables, weights): + B, H, W, G, C = values.shape + B, H, W, G, K = weights.shape + hash_value = 10000 * B + 100 * H + W + 1000 * G + if hash_value in DCNFunction.FW_TABLES.keys(): + return DCNFunction.FW_TABLES[hash_value] + print("missing") + candicate_func = None + min_t = 999.0 + outs = torch.zeros_like(values) + for func in DCNFunction.FW_FUNCS: + t = [] + for i in range(100): + torch.cuda.synchronize() + start_t = time.time() + func(B, G, C, H, W, values, deformables, weights, outs) + torch.cuda.synchronize() + t.append(time.time() - start_t) + t = t[-50:] + t = sum(t) / len(t) + if t < min_t: + min_t = t + DCNFunction.FW_TABLES[hash_value] = func + candicate_func = func + assert candicate_func is not None + print(candicate_func) + return candicate_func + @staticmethod + def find_bp_funcs(values, deformables, weights, grad_out): + B, H, W, G, C = values.shape + B, H, W, G, K = weights.shape + hash_value = 10000 * B + 100 * H + W + 1000 * G + if hash_value in DCNFunction.BP_TABLES.keys(): + return DCNFunction.BP_TABLES[hash_value] + print("missing") + candicate_func = None + min_t = 999.0 + grad_values = torch.zeros_like(values) + grad_deformables = torch.zeros_like(deformables) + grad_weights = torch.zeros_like(weights) + for func in DCNFunction.BP_FUNCS: + t = [] + for i in range(100): + torch.cuda.synchronize() + start_t = time.time() + func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights) + torch.cuda.synchronize() + t.append(time.time() - start_t) + t = t[-50:] + t = sum(t) / len(t) + if t < min_t: + min_t = t + DCNFunction.BP_TABLES[hash_value] = func + candicate_func = func + assert candicate_func is not None + print(candicate_func) + return candicate_func + + @staticmethod + @once_differentiable + @custom_bwd + def backward(ctx: Any, *grad_outputs: Any) -> Any: + grad_out = grad_outputs[0] + values, deformables, weights = ctx.saved_tensors + 
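Two reading notes on DCNFunction: find_fw_funcs and find_bp_funcs benchmark every exported kernel variant for a given (B, H, W, G) shape and cache the winner in FW_TABLES / BP_TABLES, so the first call at a new resolution is slow and later calls dispatch straight to the cached kernel. Also, as written, forward never calls ctx.save_for_backward even though backward reads ctx.saved_tensors; presumably a ctx.save_for_backward(values, deformables, weights) call is intended before returning out (an assumption on my part, not something present in this diff). The selection logic itself boils down to a timing loop like the following sketch:

import time
import torch

def pick_fastest(funcs, *args, trials=100, keep_last=50):
    # miniature of find_fw_funcs: time every candidate kernel and return the fastest one
    best_fn, best_t = None, float("inf")
    for fn in funcs:
        times = []
        for _ in range(trials):
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start = time.time()
            fn(*args)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            times.append(time.time() - start)
        avg = sum(times[-keep_last:]) / keep_last
        if avg < best_t:
            best_fn, best_t = fn, avg
    return best_fn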
B, H, W, G, C = values.shape + B, H, W, G, K = weights.shape + func = DCNFunction.find_bp_funcs(values, deformables, weights, grad_out) + grad_values = torch.zeros_like(values) + grad_deformables = torch.zeros_like(deformables) + grad_weights = torch.zeros_like(weights) + func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights) + return grad_values, grad_deformables, grad_weights \ No newline at end of file diff --git a/src/ops/cuda_kernels/setup.py b/src/ops/cuda_kernels/setup.py new file mode 100644 index 0000000..34079d4 --- /dev/null +++ b/src/ops/cuda_kernels/setup.py @@ -0,0 +1,59 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='dcn_cuda_forward', + ext_modules=[ + CUDAExtension('dcn_cuda_forward', ['./forward.cu',], + extra_compile_args={'cxx': [], 'nvcc': [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + "--use_fast_math", + "-O3", + ]} + ), + ], + cmdclass={ + 'build_ext': BuildExtension + } +) + +setup( + name='dcn_cuda_backward', + ext_modules=[ + CUDAExtension('dcn_cuda_backward', ['./backward.cu',], + extra_compile_args={'cxx': [], 'nvcc': [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + "--use_fast_math", + "-O3", + ]} + ), + ], + cmdclass={ + 'build_ext': BuildExtension + } +) + + +# setup( +# name='mycuda', +# ext_modules=[ +# CUDAExtension('mycuda', ['./backward.cu',], +# extra_compile_args={'cxx': [], 'nvcc': [ +# "-O3", +# "-DCUDA_HAS_FP16=1", +# "-D__CUDA_NO_HALF_OPERATORS__", +# "-D__CUDA_NO_HALF_CONVERSIONS__", +# "-D__CUDA_NO_HALF2_OPERATORS__", +# ]} +# ), +# ], +# cmdclass={ +# 'build_ext': BuildExtension +# } +# ) \ No newline at end of file diff --git a/src/ops/triton_kernels/__init__.py b/src/ops/triton_kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ops/triton_kernels/backward.py b/src/ops/triton_kernels/backward.py new file mode 100644 index 0000000..e886aa2 --- /dev/null +++ b/src/ops/triton_kernels/backward.py @@ -0,0 +1,124 @@ +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=1), + triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), + ], + key=['B', 'H', 'W', 'G', 'C', 'K'], +) +@triton.jit +def backward_kernel( + B: tl.constexpr, + H: tl.constexpr, # image_size_h + W: tl.constexpr, # image_size_w + G: tl.constexpr, # num_groups + C: tl.constexpr, # num_channels_per_group + K: tl.constexpr, # kernel size + input_ptr, # input features [B, H, W, G, C] + deformable_ptr, # deformable offsets [B, H, W, G, K, 2] + weights_ptr, # weights [B, H, W, G, K] + grad_ptr, # out [B, H, W, G, C] + grad_input_ptr, # input features [B, H, W, G, C] + grad_deformable_ptr, # deformable offsets [B, H, W, G, K, 2] + grad_weights_ptr, # weights [B, H, W, G, K] + BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group +): + + pid = tl.program_id(0) + wid = pid % W + hid = pid // W % H + gid = pid // (W * H) % G + bid = pid // (W * H * G) + + id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B) + + common_offset = bid*H*W*G + hid*W*G + wid*G + gid + batch_base = bid * H * W * G * C + for k in tl.static_range(K): + # load dynamic weight and mask + weights_offset = common_offset*K + k + weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) + 
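Because dcn_cuda_forward and dcn_cuda_backward are built as two separate extensions by the setup.py shown above (note it calls setup() twice, once per module), they must be compiled before function.py can be imported. One plausible invocation, exact flags depending on the local CUDA/PyTorch toolchain:

# from src/ops/cuda_kernels/
#   python setup.py build_ext --inplace
# after which `import dcn_cuda_forward, dcn_cuda_backward` should succeed from that directory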
dodx = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty) + dody = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty) + dodw = tl.zeros((1,), dtype=grad_weights_ptr.type.element_ty) + deformable_offset = (common_offset * K + k)*2 + x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid + y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid + for block_base in tl.static_range(0, C, BLOCK_SIZE): + block_offset = tl.arange(0, BLOCK_SIZE) + block_base + block_mask = (block_offset < C) & id_mask + grad = tl.load(grad_ptr+common_offset*C + block_offset, mask=block_mask, other=0.0) + dods = weight*grad + + floor_x = x.to(tl.int32) + floor_y = y.to(tl.int32) + ceil_x = floor_x + 1 + ceil_y = floor_y + 1 + + # load top left + tl_weight = (ceil_x - x) * (ceil_y - y) + tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) + block_offset + tl_block_mask = ((floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)) + tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0) + tl_block_input_dot_grad = tl.sum(tl_block_input*grad, axis=0) + dodx = dodx + -1 * tl_block_input_dot_grad * (ceil_y - y) + dody = dody + -1 * tl_block_input_dot_grad * (ceil_x - x) + dodw = dodw + tl_block_input_dot_grad * tl_weight + + dodtl = dods * tl_weight + tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dodtl) + + + # load top right + tr_weight = (x - floor_x) * (ceil_y - y) + tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) + block_offset + tr_block_mask = ((floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)) + tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0) + tr_block_input_dot_grad = tl.sum(tr_block_input*grad, axis=0) + dodx = dodx + 1 * tr_block_input_dot_grad * (ceil_y - y) + dody = dody + -1 * tr_block_input_dot_grad * (x - floor_x) + dodw = dodw + tr_block_input_dot_grad*tr_weight + + dodtr = dods * tr_weight + tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dodtr) + + + # load bottom left + bl_weight = (ceil_x - x) * (y - floor_y) + bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) + block_offset + bl_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)) + bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0) + bl_block_input_dot_grad = tl.sum(bl_block_input*grad, axis=0) + dodx = dodx + -1 * bl_block_input_dot_grad * (y - floor_y) + dody = dody + 1 * bl_block_input_dot_grad * (ceil_x - x) + dodw = dodw + bl_block_input_dot_grad*bl_weight + + dodbl = dods * bl_weight + tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dodbl) + + + # load bottom right + br_weight = (x - floor_x) * (y - floor_y) + br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) + block_offset + br_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)) + br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0) + br_block_input_dot_grad = tl.sum(br_block_input*grad, axis=0)*br_block_mask + + dodx = dodx + 1 * br_block_input_dot_grad * (y - floor_y) + dody = dody + 1 * br_block_input_dot_grad * (x - floor_x) + dodw = dodw + br_block_input_dot_grad*br_weight + + dodbr = dods * br_weight + tl.atomic_add(grad_input_ptr + 
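Sanity check (not part of the diff): the factors accumulated into dodx / dody above are the partial derivatives of the bilinear corner weights; for the top-left corner w_tl = (ceil_x - x) * (ceil_y - y), so dw_tl/dx = -(ceil_y - y) and dw_tl/dy = -(ceil_x - x), matching the -1 * ... terms in the kernel (the kernel multiplies dodx/dody by the per-key weight afterwards, per the chain rule). A quick autograd check of all four corner weights:

import torch

x = torch.tensor(0.3, requires_grad=True)
y = torch.tensor(0.7, requires_grad=True)
fx, fy = x.detach().floor(), y.detach().floor()
cx, cy = fx + 1, fy + 1
corners = {"tl": (cx - x) * (cy - y), "tr": (x - fx) * (cy - y),
           "bl": (cx - x) * (y - fy), "br": (x - fx) * (y - fy)}
for name, w in corners.items():
    gx, gy = torch.autograd.grad(w, (x, y), retain_graph=True)
    print(name, round(gx.item(), 3), round(gy.item(), 3))
# tl -0.3 -0.7, tr 0.3 -0.3, bl -0.7 0.7, br 0.7 0.3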
br_block_offset, mask=br_block_mask & block_mask, val=dodbr) + dodx = dodx * weight + dody = dody * weight + tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask) + tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask) + tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, mask=id_mask) + + + + + diff --git a/src/ops/triton_kernels/forward.py b/src/ops/triton_kernels/forward.py new file mode 100644 index 0000000..cf7c243 --- /dev/null +++ b/src/ops/triton_kernels/forward.py @@ -0,0 +1,94 @@ +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), + # triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1), + ], + key=['B', 'H', 'W', 'G', 'C', 'K'], +) +@triton.jit +def forward_kernel( + B: tl.constexpr, + H: tl.constexpr, # image_size_h + W: tl.constexpr, # image_size_w + G: tl.constexpr, # num_channels_per_group + C: tl.constexpr, # num_groups + K: tl.constexpr, # kernel size + input_ptr, # input features [B, H, W, G, C] + deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K] + weights_ptr, # weights [B, H, W, G, K] + out_ptr, # out [B, H, W, G, C] + BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group +): + pid = tl.program_id(0) + wid = pid % W + hid = pid // W % H + gid = pid // (W * H) % G + bid = pid // (W * H * G) + + id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B) + common_offset = bid*H*W*G + hid*W*G + wid*G + gid + batch_base = bid * H * W * G * C + + for block_base in tl.static_range(0, C, BLOCK_SIZE): + buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + block_offset = tl.arange(0, BLOCK_SIZE) + block_base + block_mask = (block_offset < C) & id_mask + for k in tl.static_range(K): + deformable_offset = (common_offset * K + k) * 2 + + x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid + y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid + + floor_x = x.to(tl.int32) + floor_y = y.to(tl.int32) + ceil_x = floor_x + 1 + ceil_y = floor_y + 1 + + # load top left + tl_weight = (ceil_x - x) * (ceil_y - y) + tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE + tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H) + + # load top right + tr_weight = (x - floor_x) * (ceil_y - y) + tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE + tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0) + # load bottom left + bl_weight = (ceil_x - x) * (y - floor_y) + bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE + bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0) + # load bottom right + br_weight = (x - floor_x) * (y - floor_y) + br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE + br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0) + + # load dynamic weight and mask + weights_offset = common_offset*K + k + weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) + + + + tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0) + tl_block_input = tl_block_input * tl_weight + + # load top right + tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & 
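Both kernels launch one program per output location (grid size B*H*W*G in function.py) and unpack the flat program id into (b, h, w, g) as above. A small host-side sketch (not part of the diff) verifying that unpacking:

def decompose(pid, H, W, G):
    wid = pid % W
    hid = pid // W % H
    gid = pid // (W * H) % G
    bid = pid // (W * H * G)
    return bid, hid, wid, gid

B, H, W, G = 2, 3, 4, 5
assert all(decompose(((b * G + g) * H + h) * W + w, H, W, G) == (b, h, w, g)
           for b in range(B) for g in range(G) for h in range(H) for w in range(W))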
block_mask, other=0.0) + tr_block_input = tr_block_input * tr_weight + # load bottom left + bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0) + bl_block_input = bl_block_input * bl_weight + # load bottom right + br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0) + br_block_input = br_block_input * br_weight + + # sampled + sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input + + weighted_sampled_input = sampled_input * weight + buffer = buffer + weighted_sampled_input + # store to out_ptr + tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask) + diff --git a/src/ops/triton_kernels/function.py b/src/ops/triton_kernels/function.py new file mode 100644 index 0000000..84987a1 --- /dev/null +++ b/src/ops/triton_kernels/function.py @@ -0,0 +1,48 @@ +import torch +import triton +from typing import Any +from torch.autograd import Function +from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd +from .forward import forward_kernel +from .backward import backward_kernel + + + +class DCNFunction(Function): + + @staticmethod + @custom_fwd + def forward(ctx: Any, inputs, deformables, weights) -> Any: + B, H, W, G, C = inputs.shape + _, _, _, _, K, _ = deformables.shape + out = torch.zeros_like(inputs) + grid = lambda META: (B * H * W * G,) + + forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out) + ctx.save_for_backward(inputs, deformables, weights) + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, *grad_outputs: Any) -> Any: + grad_output = grad_outputs[0].contiguous() + + inputs, deformables, weights = ctx.saved_tensors + B, H, W, G, C = inputs.shape + _, _, _, _, K, _ = deformables.shape + + grad_inputs = torch.zeros_like(inputs) + grad_deformables = torch.zeros_like(deformables) + grad_weights = torch.zeros_like(weights) + grid = lambda META: (B * H * W * G,) + backward_kernel[grid]( + B, H, W, G, C, K, + inputs, + deformables, + weights, + grad_output, + grad_inputs, + grad_deformables, + grad_weights, + ) + return (grad_inputs, grad_deformables, grad_weights) \ No newline at end of file diff --git a/src/plugins/__init__.py b/src/plugins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/plugins/bd_env.py b/src/plugins/bd_env.py new file mode 100644 index 0000000..c1900e9 --- /dev/null +++ b/src/plugins/bd_env.py @@ -0,0 +1,70 @@ +import torch +import os +import socket +from typing_extensions import override +from lightning.fabric.utilities.rank_zero import rank_zero_only +from lightning.fabric.plugins.environments.lightning import LightningEnvironment + + +class BDEnvironment(LightningEnvironment): + pass + # def __init__(self) -> None: + # super().__init__() + # self._global_rank: int = 0 + # self._world_size: int = 1 + # + # @property + # @override + # def creates_processes_externally(self) -> bool: + # """Returns whether the cluster creates the processes or not. + # + # If at least :code:`LOCAL_RANK` is available as environment variable, Lightning assumes the user acts as the + # process launcher/job scheduler and Lightning will not launch new processes. 
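A minimal smoke test for the autograd wrapper (not part of the diff): it assumes a CUDA device, the Triton kernels above, and the layouts documented in the backward kernel's comments (note the inline comments for G and C in forward.py are swapped relative to that layout, although the positional arguments are passed consistently either way).

import torch
from src.ops.triton_kernels.function import DCNFunction

B, H, W, G, C, K = 2, 16, 16, 4, 64, 9
values = torch.randn(B, H, W, G, C, device="cuda", requires_grad=True)
deformables = torch.randn(B, H, W, G, K, 2, device="cuda", requires_grad=True)
weights = torch.rand(B, H, W, G, K, device="cuda", requires_grad=True)

out = DCNFunction.apply(values, deformables, weights)   # [B, H, W, G, C]
out.mean().backward()
print(out.shape, values.grad.shape, deformables.grad.shape, weights.grad.shape)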
+ # + # """ + # return "LOCAL_RANK" in os.environ + # + # @staticmethod + # @override + # def detect() -> bool: + # assert "ARNOLD_WORKER_0_HOST" in os.environ.keys() + # assert "ARNOLD_WORKER_0_PORT" in os.environ.keys() + # return True + # + # @override + # def world_size(self) -> int: + # return self._world_size + # + # @override + # def set_world_size(self, size: int) -> None: + # self._world_size = size + # + # @override + # def global_rank(self) -> int: + # return self._global_rank + # + # @override + # def set_global_rank(self, rank: int) -> None: + # self._global_rank = rank + # rank_zero_only.rank = rank + # + # @override + # def local_rank(self) -> int: + # return int(os.environ.get("LOCAL_RANK", 0)) + # + # @override + # def node_rank(self) -> int: + # return int(os.environ.get("ARNOLD_ID")) + # + # @override + # def teardown(self) -> None: + # if "WORLD_SIZE" in os.environ: + # del os.environ["WORLD_SIZE"] + # + # @property + # def main_address(self) -> str: + # return os.environ.get("ARNOLD_WORKER_0_HOST") + # + # @property + # def main_port(self) -> int: + # return int(os.environ.get("ARNOLD_WORKER_0_PORT")) diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/copy.py b/src/utils/copy.py new file mode 100644 index 0000000..62cd89d --- /dev/null +++ b/src/utils/copy.py @@ -0,0 +1,13 @@ +import torch + +@torch.no_grad() +def copy_params(src_model, dst_model): + for src_param, dst_param in zip(src_model.parameters(), dst_model.parameters()): + dst_param.data.copy_(src_param.data) + +@torch.no_grad() +def swap_tensors(tensor1, tensor2): + tmp = torch.empty_like(tensor1) + tmp.copy_(tensor1) + tensor1.copy_(tensor2) + tensor2.copy_(tmp) \ No newline at end of file diff --git a/src/utils/model_loader.py b/src/utils/model_loader.py new file mode 100644 index 0000000..7d99166 --- /dev/null +++ b/src/utils/model_loader.py @@ -0,0 +1,29 @@ +from typing import Dict, Any, Optional + +import torch +import torch.nn as nn +from lightning.fabric.utilities.types import _PATH + + +import logging +logger = logging.getLogger(__name__) + +class ModelLoader: + def __init__(self,): + super().__init__() + + def load(self, denoiser, prefix=""): + if denoiser.weight_path: + weight = torch.load(denoiser.weight_path, map_location=torch.device('cpu')) + + if denoiser.load_ema: + prefix = "ema_denoiser." + prefix + else: + prefix = "denoiser." 
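Illustration (not part of the diff): copy_params and swap_tensors from src/utils/copy.py support the usual evaluate-with-EMA-weights-then-restore pattern; the two Linear modules and the swap_all helper below are hypothetical.

import torch.nn as nn
from src.utils.copy import copy_params, swap_tensors

online, ema = nn.Linear(8, 8), nn.Linear(8, 8)
copy_params(online, ema)              # initialise the EMA copy from the online weights

def swap_all(a, b):                   # exchange the two weight sets in place
    for p, q in zip(a.parameters(), b.parameters()):
        swap_tensors(p.data, q.data)

swap_all(online, ema)                 # validate with EMA weights ...
swap_all(online, ema)                 # ... then swap back to the online weights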
+ prefix + + for k, v in denoiser.state_dict().items(): + try: + v.copy_(weight["state_dict"][prefix+k]) + except: + logger.warning(f"Failed to copy {prefix+k} to denoiser weight") + return denoiser \ No newline at end of file diff --git a/src/utils/no_grad.py b/src/utils/no_grad.py new file mode 100644 index 0000000..2fd71de --- /dev/null +++ b/src/utils/no_grad.py @@ -0,0 +1,16 @@ +import torch + +@torch.no_grad() +def no_grad(net): + for param in net.parameters(): + param.requires_grad = False + net.eval() + return net + +@torch.no_grad() +def filter_nograd_tensors(params_list): + filtered_params_list = [] + for param in params_list: + if param.requires_grad: + filtered_params_list.append(param) + return filtered_params_list \ No newline at end of file diff --git a/src/utils/patch_bugs.py b/src/utils/patch_bugs.py new file mode 100644 index 0000000..db9a174 --- /dev/null +++ b/src/utils/patch_bugs.py @@ -0,0 +1,17 @@ +import torch +import lightning.pytorch.loggers.wandb as wandb + +setattr(wandb, '_WANDB_AVAILABLE', True) +torch.set_float32_matmul_precision('medium') + +import logging +logger = logging.getLogger("wandb") +logger.setLevel(logging.WARNING) + +import os +os.environ["NCCL_DEBUG"] = "WARN" +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) +warnings.simplefilter(action='ignore', category=UserWarning) diff --git a/tools/cache_imlatent3.py b/tools/cache_imlatent3.py new file mode 100644 index 0000000..640cdb0 --- /dev/null +++ b/tools/cache_imlatent3.py @@ -0,0 +1,117 @@ +from diffusers import AutoencoderKL + +import torch +from typing import Callable +from torchvision.datasets import ImageFolder, ImageNet + +import cv2 +import os +import torch +import numpy as np +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor + +from PIL import Image +import pathlib + +import torch +import random +from torchvision.io.image import read_image +import torchvision.transforms as tvtf +from torch.utils.data import Dataset +from torchvision.datasets import ImageNet + +IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +class NewImageFolder(ImageFolder): + def __getitem__(self, item): + path, target = self.samples[item] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + return sample, target, path + + +import time +class CenterCrop: + def __init__(self, size): + self.size = size + def __call__(self, image): + def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. 
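Illustration (not part of the diff): the helpers in src/utils/no_grad.py are meant for freezing auxiliary networks and building an optimizer only over the parameters that still require gradients; the toy model here is hypothetical.

import torch
import torch.nn as nn
from src.utils.no_grad import no_grad, filter_nograd_tensors

model = nn.ModuleDict({"encoder": nn.Linear(8, 8), "head": nn.Linear(8, 2)})
no_grad(model["encoder"])                                   # freeze and switch to eval()
trainable = filter_nograd_tensors(list(model.parameters()))
optimizer = torch.optim.AdamW(trainable, lr=1e-4)
print(len(trainable))                                       # 2: only the head's weight and bias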
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + return center_crop_arr(image, self.size) + +def save(obj, path): + dirname = os.path.dirname(path) + if not os.path.exists(dirname): + os.makedirs(dirname) + torch.save(obj, f'{path}') + +if __name__ == "__main__": + writer_pool = ThreadPoolExecutor(8) + for split in ['train']: + train = split == 'train' + transforms = tvtf.Compose([ + CenterCrop(256), + # tvtf.RandomHorizontalFlip(p=1), + tvtf.ToTensor(), + tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ]) + dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,) + # dataset = ImageNet(root='/tmp', split="train", transform=transforms, ) + B = 256 + dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=False, prefetch_factor=32, num_workers=16) + vae = AutoencoderKL.from_pretrained("/mnt/bn/wangshuai6/models/sd-vae-ft-ema")#.to('cuda') + + from accelerate import Accelerator + + accelerator = Accelerator() + + vae, dataloader = accelerator.prepare(vae, dataloader) + rank = accelerator.process_index + with torch.no_grad(): + for i, (image, label, path_list) in enumerate(dataloader): + # if i >= 128: break + new_path_list = [] + for p in path_list: + p = p + ".pt" + p = p.replace("/mnt/bn/wangshuai6/data/ImageNet/train", + "/mnt/bn/wangshuai6/data/ImageNet/train_256latent") + new_path_list.append(p) + + image = image.to("cuda") + distribution = vae.module.encode(image).latent_dist + mean = distribution.mean + logvar = distribution.logvar + for j in range(B): + out = dict( + mean=mean[j].cpu(), + logvar=logvar[j].cpu(), + ) + writer_pool.submit(save, out, new_path_list[j]) + writer_pool.shutdown(wait=True) + accelerator.wait_for_everyone() \ No newline at end of file diff --git a/tools/cache_imlatent4.py b/tools/cache_imlatent4.py new file mode 100644 index 0000000..fc33fa7 --- /dev/null +++ b/tools/cache_imlatent4.py @@ -0,0 +1,123 @@ +from diffusers import AutoencoderKL + +import torch +from typing import Callable +from torchvision.datasets import ImageFolder, ImageNet + +import cv2 +import os +import torch +import numpy as np +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor + +from PIL import Image +import pathlib + +import torch +import random +from torchvision.io.image import read_image +import torchvision.transforms as tvtf +from torch.utils.data import Dataset +from torchvision.datasets import ImageNet + +IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +class NewImageFolder(ImageFolder): + def __getitem__(self, item): + path, target = self.samples[item] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + return sample, target, path + + +import time +class CenterCrop: + def __init__(self, size): + self.size = size + def __call__(self, image): + def center_crop_arr(pil_image, 
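Sketch (not part of the diff): a cached file written by this script stores the VAE posterior mean/logvar rather than a fixed latent, so a latent is re-sampled whenever it is read back. The path below is hypothetical, and the 0.18215 scale is only the usual SD-VAE convention, not something asserted by this diff.

import torch

cached = torch.load("n01440764/n01440764_10026.JPEG.pt")    # hypothetical cached file
mean, logvar = cached["mean"], cached["logvar"]
latent = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)
latent = latent * 0.18215                                   # assumed scaling, if the trainer expects it
print(latent.shape)                                         # [4, 32, 32] for 256x256 inputs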
image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + return center_crop_arr(image, self.size) + +def save(obj, path): + dirname = os.path.dirname(path) + if not os.path.exists(dirname): + os.makedirs(dirname) + torch.save(obj, f'{path}') + +if __name__ == "__main__": + writer_pool = ThreadPoolExecutor(8) + for split in ['train']: + train = split == 'train' + transforms = tvtf.Compose([ + CenterCrop(512), + # tvtf.RandomHorizontalFlip(p=1), + tvtf.ToTensor(), + tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ]) + dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,) + B = 8 + dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=False, prefetch_factor=16, num_workers=16) + vae = AutoencoderKL.from_pretrained("/mnt/bn/wangshuai6/models/sd-vae-ft-ema")#.to('cuda') + vae = vae.to(torch.float16) + from accelerate import Accelerator + + accelerator = Accelerator() + + vae, dataloader = accelerator.prepare(vae, dataloader) + rank = accelerator.process_index + with torch.no_grad(): + for i, (image, label, path_list) in enumerate(dataloader): + print(i/len(dataloader)) + flag = False + new_path_list = [] + for p in path_list: + p = p + ".pt" + p = p.replace("/mnt/bn/wangshuai6/data/ImageNet/train", + "/mnt/bn/wangshuai6/data/ImageNet/train_512_latent") + new_path_list.append(p) + if not os.path.exists(p): + print(p) + flag = True + + if flag: + image = image.to("cuda") + image = image.to(torch.float16) + distribution = vae.module.encode(image).latent_dist + mean = distribution.mean + logvar = distribution.logvar + + for j in range(len(path_list)): + out = dict( + mean=mean[j].cpu(), + logvar=logvar[j].cpu(), + ) + writer_pool.submit(save, out, new_path_list[j]) + writer_pool.shutdown(wait=True) + accelerator.wait_for_everyone() \ No newline at end of file diff --git a/tools/cat_images.py b/tools/cat_images.py new file mode 100644 index 0000000..21734b0 --- /dev/null +++ b/tools/cat_images.py @@ -0,0 +1,43 @@ +import cv2 +import numpy as np +import os +import pathlib +import argparse + +def group_images(path_list): + sorted(path_list) + class_id_dict = {} + for path in path_list: + class_id = str(path.name).split('_')[0] + if class_id not in class_id_dict: + class_id_dict[class_id] = [] + class_id_dict[class_id].append(path) + return class_id_dict + +def cat_images(path_list): + imgs = [] + for path in path_list: + img = cv2.imread(str(path)) + os.remove(path) + imgs.append(img) + row_cat_images = [] + row_length = int(len(imgs)**0.5) + for i in range(len(imgs)//row_length): + row_cat_images.append(np.concatenate(imgs[i*row_length:(i+1)*row_length], axis=1)) + cat_image = np.concatenate(row_cat_images, axis=0) + return cat_image + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--src_dir', type=str, default=None) + + args = parser.parse_args() + src_dir = 
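Usage note (not part of the diff): tools/cat_images.py groups the PNGs in --src_dir by the class id encoded before the first underscore in each filename, tiles every group into a roughly square grid, deletes the originals, and writes cat_<class>.jpg back into the same directory, e.g. python tools/cat_images.py --src_dir <dir of generated samples>. Note that sorted(path_list) returns a new list that is discarded, so the tiles end up in glob order rather than sorted order.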
args.src_dir + path_list = list(pathlib.Path(src_dir).glob('*.png')) + class_id_dict = group_images(path_list) + for class_id, path_list in class_id_dict.items(): + cat_image = cat_images(path_list) + cat_path = os.path.join(src_dir, f'cat_{class_id}.jpg') + # cat_path = "cat_{}.png".format(class_id) + cv2.imwrite(cat_path, cat_image) + diff --git a/tools/classifer_training.py b/tools/classifer_training.py new file mode 100644 index 0000000..b00b435 --- /dev/null +++ b/tools/classifer_training.py @@ -0,0 +1,353 @@ +import torch +import torch.nn as nn +import timm +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +import copy + +# NORMALIZE_DATA = dict( +# dinov2_vits14a = dict( +# mean=[-0.28,0.72,-0.64,-2.31,0.54,-0.29,-1.09,0.83,0.86,1.11,0.34,0.29,-0.32,1.02,0.58,-1.27,-1.19,-0.89,0.79, +# -0.58,0.23,-0.19,1.31,-0.34,0.02,-0.18,-0.64,0.04,1.63,-0.58,-0.89,0.09,1.09,1.12,0.32,-0.41,0.04,0.49, +# 0.11,1.97,1.06,0.05,-1.15,0.30,0.58,-0.14,-0.26,1.32,2.04,0.50,-0.64,1.18,0.39,0.39,-1.80,0.39,-0.67, +# 0.55,-0.35,-0.41,2.23,-1.16,-0.57,0.58,-1.29,2.07,0.18,0.62,5.72,-0.55,-0.54,0.17,-0.64,-0.78,-0.25,0.12, +# -0.58,0.36,-2.03,-2.45,-0.22,0.36,-1.02,-0.19,-0.92,-0.26,-0.27,-0.77,-1.47,-0.64,1.76,-0.03,-0.44,1.43, +# 1.14,0.67,1.27,1.54,0.88,-1.42,-0.44,3.32,0.21,1.22,1.17,1.15,-0.53,0.04,0.87,-0.76,0.94,-0.11,0.69,-0.61, +# 0.64,-1.21,-0.82,0.22,-1.12,-0.03,0.68,1.05,0.57,1.13,0.03,0.05,0.42,-0.12,-0.37,-0.76,-0.56,-0.76,-0.23, +# 1.59,0.54,0.63,-0.43,0.38,1.07,0.04,-1.87,-1.92,-0.06,0.87,-0.69,-1.09,-0.30,0.33,-0.28,0.14,2.65,-0.57, +# -0.04,0.12,-0.49,-1.60,0.39,0.05,0.12,0.66,-0.70,-0.69,0.47,-0.67,-0.59,-1.30,-0.28,-0.52,-0.98,0.67,1.65, +# 0.72,0.55,0.05,-0.27,1.67,0.17,-0.31,-1.73,2.04,0.49,1.08,-0.37,1.75,1.31,1.03,0.65,0.43,-0.19,0.00,-1.13, +# -0.29,-0.38,0.09,-0.24,1.49,1.01,-0.25,-0.94,0.74,0.24,-1.06,1.58,1.08,0.76,0.64,1.34,-1.09,1.54,-0.27, +# 0.77,0.19,-0.97,0.46,0.20,-0.60,1.48,-2.33,0.43,2.32,1.89,-0.31,-0.48,-0.54,1.52,1.33,0.95,0.12,-0.33,-0.94, +# -0.67,0.16,1.49,-0.17,-0.42,-0.02,-0.32,0.49,-1.19,0.06,0.19,-0.79,-0.21,-0.38,-0.69,0.52,0.74,0.41,-2.07, +# -1.01,0.85,-1.41,-0.17,1.11,0.53,1.47,0.66,-0.22,0.93,-0.69,-0.42,0.06,0.11,-0.87,1.58,-0.27,-1.57,-0.56, +# 0.98,-0.50,0.27,0.38,-1.06,-1.77,0.20,-0.33,-0.95,-0.62,-3.44,-0.67,-0.62,-1.20,0.04,-0.02,-1.15,0.56,-0.50, +# 0.83,-1.69,0.01,-0.42,1.15,0.22,1.55,-3.02,1.24,0.28,0.40,0.69,-0.35,2.04,0.33,0.10,-1.09,0.50,0.59,1.29, +# 0.79,0.02,-0.02,-0.49,0.07,0.84,0.55,-0.79,-0.26,-0.06,-0.91,-1.28,0.65,0.30,1.00,-0.09,0.66,-2.51,-0.78, +# 2.94,0.18,0.24,-0.08,0.76,0.06,0.26,-0.74,0.16,-0.72,0.17,0.21,0.98,0.67,0.14,0.05,0.48,0.54,2.05,1.21,-0.03, +# -0.85,0.38,-0.11,-0.38,-0.86,0.49,-0.87,-0.29,0.23,0.79,0.05,-0.05,-0.07,0.22,0.03,0.85,-0.63,-0.44,0.02, +# -0.10,-0.01,0.51,-1.84,0.11,1.06,0.00,1.10,-0.56,0.21,0.44,-0.65,-0.97,-1.03,-0.50,-0.67,-0.27,-1.25, +# ], +# std=[1.78,3.78,1.92,2.28,1.97,2.82,1.92,2.55,1.87,1.95,1.90,1.83,1.89,2.00,1.85,1.88,1.78,1.81,3.02,1.94,1.92, +# 2.26,2.17,6.16,1.84,2.00,1.85,1.88,3.30,2.14,1.85,2.87,3.01,2.05,1.80,1.84,2.20,2.00,1.97,2.02,1.94,1.90, +# 1.98,2.25,1.97,2.01,2.01,1.95,2.26,2.47,1.95,1.75,1.84,3.02,2.65,2.15,2.01,1.80,2.65,2.37,2.04,2.09,2.03, +# 1.94,1.84,2.19,1.98,1.97,4.52,2.76,2.18,2.59,1.94,2.07,1.96,1.91,2.13,3.16,1.95,2.43,1.84,2.16,2.33,2.21, +# 2.10,1.98,1.90,1.90,1.88,1.89,2.15,1.75,1.83,2.36,2.40,2.42,1.89,2.03,1.89,2.00,1.91,2.88,2.10,2.63,2.04, +# 
1.88,1.93,1.74,2.02,1.84,1.96,1.98,1.90,1.80,1.86,2.05,2.21,1.97,1.99,1.77,2.04,2.59,1.85,2.14,1.91,1.68, +# 1.95,1.86,1.99,2.18,2.76,2.03,1.88,2.47,1.92,3.04,2.02,1.74,2.94,1.92,2.12,1.92,2.17,2.15,1.74,2.26,1.71, +# 2.03,2.05,1.85,3.43,1.77,1.96,1.88,1.99,2.14,2.30,2.00,1.90,2.01,1.78,1.72,2.42,1.66,1.86,2.08,2.04,1.88, +# 2.55,2.02,1.83,1.86,1.69,2.06,1.92,2.25,1.74,1.69,2.02,3.88,1.86,2.94,1.82,2.27,2.73,2.05,1.91,1.94,1.86, +# 1.77,2.16,2.16,1.86,1.88,2.08,2.19,1.94,1.90,2.09,2.57,1.75,1.90,2.05,2.13,1.74,1.99,1.83,2.35,4.48,2.44, +# 1.88,2.18,2.46,1.84,1.81,2.37,2.45,2.07,1.79,3.65,2.29,2.09,2.09,2.29,1.92,2.34,1.85,2.03,1.72,2.20,2.15, +# 2.04,2.13,2.07,1.82,1.72,2.06,1.87,2.43,1.94,1.93,1.97,1.83,1.96,2.01,1.89,1.73,2.04,2.63,2.10,2.05,2.49, +# 2.10,2.27,1.87,2.16,2.22,2.08,1.87,2.26,1.88,2.28,3.87,1.74,3.71,2.03,2.70,2.11,1.92,2.00,2.04,2.02,1.90, +# 2.61,2.10,2.37,1.96,2.50,1.17,1.95,1.88,2.06,2.22,1.87,1.93,1.88,3.59,1.89,3.66,1.87,1.95,3.13,1.84,2.87, +# 3.96,2.14,2.01,1.89,1.73,1.98,2.42,2.12,2.28,1.92,1.93,2.54,2.06,1.97,2.02,2.19,2.00,2.04,1.75,1.97,1.81, +# 1.93,1.83,2.22,2.52,1.83,1.86,2.16,2.08,2.87,3.21,2.78,2.84,2.85,1.88,1.79,1.95,1.98,1.78,1.78,2.21,1.89, +# 2.57,2.00,2.82,1.90,2.24,2.28,1.91,2.02,2.23,2.62,1.88,2.40,2.40,2.00,1.70,1.82,1.92,1.95,1.99,2.08,1.97, +# 2.12,1.87,3.65,2.26,1.83,1.96,1.83,1.64,2.07,2.04,2.57,1.85,2.21,1.83,1.90,1.97,2.16,2.12,1.80,1.73,1.96, +# 2.62,3.23,2.13,2.29,2.24,2.72 +# ] +# ), +# dinov2_vitb14a = dict( +# mean=[ +# 0.23, 0.44, 0.18, -0.26, -0.08, -0.80, -0.22, -0.09, -0.85, 0.44, 0.07, -0.49, 0.39, -0.12, -0.58, -0.82, +# -0.21, -0.28, -0.40, 0.36, -0.34, 0.08, 0.31, 0.39, -0.22, -1.23, 0.50, 0.81, -0.96, 0.60, -0.45, -0.17, +# -0.53, 0.08, 0.10, -0.32, -0.22, -0.86, 0.01, 0.19, -0.73, -0.44, -0.57, -0.45, -0.20, -0.34, -0.63, -0.31, +# -0.80, 0.43, -0.13, 0.18, -0.11, -0.28, -0.15, 0.11, -0.74, -0.01, -0.34, 0.18, 0.37, 0.07, -0.09, -0.42, 0.15, +# -0.24, 0.68, -0.31, -0.09, -0.62, -0.54, 0.41, -0.42, -0.08, 0.36, -0.14, 0.44, 0.12, 0.49, 0.69, 0.03, +# -0.24, -0.41, -0.36, -0.60, 0.86, -0.76, 0.54, -0.24, 0.57, -0.40, -0.82, 0.07, 0.05, -0.24, 0.07, 0.54, +# 1.04, -0.29, 0.67, -0.36, -0.79, 0.11, -0.12, -0.22, -0.20, -0.46, 0.17, -0.15, -0.38, -0.11, 0.24, -0.43, +# -0.91, 0.04, 0.32, 0.27, -0.58, -0.05, 0.50, -0.47, 0.31, -1.30, 0.07, -0.16, 0.77, 1.07, -0.44, -0.48, 0.26 +# , 0.06, -0.76, -0.27, -0.37, -1.43, -0.50, -0.38, -0.03, -0.43, 0.75, -0.01, -0.16, 0.67, 0.40, 0.33, -0.05, +# -0.94, -0.40, 0.78, 0.29, -0.60, -0.76, 0.08, -0.08, 0.58, -0.91, -1.09, -0.42, -0.42, 0.29, 0.06, -0.19, +# -0.75, -0.07, 0.48, -0.30, -0.44, 0.02, 0.11, 0.23, -0.76, -0.76, -0.51, 0.78, -0.58, 0.02, 0.17, -0.36, +# -0.63, 0.48, 0.09, -0.32, -0.48, -0.09, 0.09, -0.36, 0.11, -0.17, 0.11, -0.80, -0.34, -0.52, 0.10, -0.00, 0.00, +# -0.15, 0.91, -0.48, 0.64, -0.38, 0.28, 0.56, 0.04, -0.30, 0.14, -0.30, -0.82, 0.47, 0.57, -1.00, -0.14, +# 0.00, 0.10, 0.01, 0.57, -0.09, -3.56, -0.22, -0.24, -0.13, 0.36, 0.30, 0.20, 0.09, 0.08, 0.66, 0.62, 0.44, +# 0.38, 0.46, -0.27, 0.21, 0.07, -0.57, 0.93, 0.39, 0.06, -0.47, 0.34, 0.44, -0.00, -0.52, -0.35, 0.23, -0.24, +# -0.01, -0.15, 0.11, 0.53, -0.23, 0.28, -0.22, 0.57, -0.07, 0.49, 0.74, 0.85, -0.31, -0.44, 0.22, -0.02, 0.25, +# -0.01, -0.47, -0.23, 0.03, 0.48, -0.19, 1.55, -0.05, 0.24, 0.26, -0.25, 0.38, -0.44, -0.51, 0.34, -0.12, +# -0.76, -0.13, 0.57, 0.01, 0.63, 0.40, 0.20, -0.33, -0.31, -0.89, 0.65, -0.46, -0.88, -0.22, 0.34, 0.36, +# 0.95, 0.33, 0.62, -0.49, 0.40, -0.12, -0.07, -0.65, -0.05, -0.58, 0.65, 
0.18, -0.81, -0.64, 0.26, -0.10, +# -0.71, 0.47, -0.05, 0.12, -0.18, 0.77, 0.47, 0.50, 0.48, -0.45, 0.03, 0.16, 0.66, -0.42, -0.05, 0.23, -0.22, +# -0.46, 0.25, 0.28, 0.18, -0.20, -0.14, -0.93, -0.27, -0.23, 0.15, -0.10, -0.39, -0.20, -0.05, -0.09, 0.28, +# -0.58, -0.54, 0.09, -0.89, -0.09, 0.03, -0.86, -0.46, -0.70, 0.48, -0.59, -0.56, -0.55, -0.27, -0.50, 0.23, +# 0.63, -1.45, -0.27, -0.04, -0.17, 0.38, -0.02, 0.28, 0.53, -0.81, -0.60, -0.07, 0.22, 0.23, 0.33, -0.62, +# 0.09, -0.19, -0.09, -0.28, -0.13, 0.66, 0.37, -0.17, -0.52, -0.15, -0.60, 0.15, -0.25, 0.42, -0.06, 0.26, +# 0.55, 0.72, 0.48, 0.39, -0.41, -0.76, -0.62, 0.53, 0.18, 0.35, -0.27, -0.20, -0.71, -0.55, 0.16, -0.24, -0.12, +# 0.38, -0.53, -0.43, 0.21, -0.60, -0.24, -0.11, 1.29, 0.02, -0.05, 0.13, 0.48, 0.39, -0.43, -0.05, 0.07, +# -0.92, 0.89, -0.21, 0.30, -0.44, 0.04, -0.30, 0.11, -0.36, -0.46, -0.20, 0.10, 0.88, -0.15, 0.28, 0.57, +# -0.10, 0.48, 0.77, -0.12, 0.17, -0.43, -0.20, 0.22, 0.36, -0.49, -0.54, -0.07, 0.67, 0.40, -0.94, -0.62, +# 0.46, 0.75, -0.16, -0.32, 0.30, 0.41, 0.03, -0.31, -0.17, -0.47, 0.53, 0.24, -0.77, 0.32, 0.58, -0.08, -0.71, 0.10, +# -0.14, 0.39, 0.64, -0.08, -0.38, 0.60, 0.02, 0.61, 0.47, 0.32, 0.35, -0.01, -0.03, -0.15, -0.01, 0.51, +# -0.52, 0.51, -0.82, 0.58, -0.13, 0.07, 0.46, -2.86, 0.36, -0.27, 0.70, 0.54, 0.31, 0.08, -0.67, 0.58, 0.22, +# -0.40, 1.05, 0.02, 0.41, -0.66, -0.29, 0.68, 0.40, 0.53, 0.09, -0.31, -0.28, 0.20, 0.01, -0.07, -0.25, 0.36, +# 0.10, -0.79, 0.27, -0.18, 0.18, -1.13, 0.40, -1.07, 0.84, -0.26, -0.09, -0.99, -0.55, 0.20, -0.11, -0.10, +# 0.49, 0.49, -0.08, -0.13, 1.00, 0.48, -0.17, -0.37, -0.31, -0.24, 0.27, -0.11, 0.21, 0.01, -0.17, -0.02, +# -0.48, 0.25, -0.44, 0.64, 0.53, -1.02, -0.20, -0.13, -0.19, 0.07, -0.17, 0.66, 1.34, -0.40, -1.09, 0.42, +# 0.07, -0.02, 0.50, 0.32, -0.03, 0.30, -0.53, 0.19, 0.01, -0.26, -0.54, -0.04, -0.64, -0.31, 0.85, -0.12, +# -0.07, -0.08, -0.22, 0.27, -0.50, 0.25, 0.40, -0.60, -0.18, 0.36, 0.66, -0.16, 0.91, -0.61, 0.43, 0.31, 0.23, -0.60, +# -0.13, -0.07, -0.44, -0.03, 0.25, 0.41, 0.08, 0.89, -1.09, -0.12, -0.12, -0.09, 0.13, 0.01, -0.55, -0.35, +# -0.44, 0.07, -0.19, 0.35, 0.99, 0.01, 0.11, -0.04, 0.50, -0.10, 0.49, 0.61, 0.23, -0.41, 0.11, -0.36, 0.64, +# -0.97, 0.68, -0.27, 0.30, 0.85, 0.03, 1.84, -0.15, -0.05, 0.46, -0.41, -0.01, 0.03, -0.32, 0.33, 0.14, 0.31 +# , -0.18, -0.30, 0.07, 0.70, -0.64, -0.59, 0.36, 0.39, -0.33, 0.79, 0.47, 0.44, -0.05, -0.03, -0.29, -1.00, +# -0.04, 1.25, 0.74, 0.08, -0.53, -0.65, 0.17, -0.57, -0.39, 0.34, -0.12, -0.04, -0.63, 0.27, -0.25, -0.73, +# -4.08, -0.09, -0.64, 0.38, -0.47, -0.36, -0.34, 0.05, 0.12, 0.37, -0.43, -0.39, 0.11, -0.32, -0.81, -0.05, +# -0.40, -0.31, 2.64, 0.14, -2.08, 0.70, -0.52, -0.55, -0.40, -0.75, -0.20, 0.42, 0.99, -0.27, 0.35, -0.35, +# -0.46, 0.48, 0.03, 0.64, 0.56, -0.77, -0.37, 0.02, 0.02, -0.60, -0.47, -0.49, -0.19, 0.29, 0.05, 0.17, 0.05, +# 1.01, 0.05, 0.06, -0.00, -0.64, 0.72, 1.39, -0.45, -0.46, 0.49, -0.58, 0.36, 0.01, -0.14, -0.01, -0.54, +# -0.46, -1.21, 0.94, -1.31, 0.61, 0.63, -0.53, 0.05, 0.37, -0.18, 1.08, -0.10, -0.80, -0.38, -0.03, +# ], +# std=[ +# 1.48, 1.58, 1.56, 1.49, 1.57, 1.96, 1.50, 1.34, 1.46, 1.66, 1.63, 1.44, 1.48, 1.53, 1.49, 1.39, 1.45, 1.40, +# 1.47, 1.43, 1.65, 1.69, 1.72, 1.56, 1.50, 3.06, 1.48, 1.58, 1.63, 1.41, 1.78, 1.48, 1.64, 1.41, 1.46, 1.39, +# 1.57, 3.80, 0.16, 1.46, 1.49, 1.51, 1.55, 1.57, 1.43, 1.69, 1.50, 1.53, 1.51, 1.49, 1.42, 1.48, 1.62, 1.56, +# 1.52, 1.39, 1.95, 1.47, 1.33, 1.42, 1.96, 1.46, 1.54, 1.47, 1.41, 1.41, 1.50, 1.53, 1.55, 
2.24, 1.52, 1.73, +# 1.54, 1.46, 1.47, 1.55, 1.56, 1.46, 1.40, 1.49, 1.42, 1.54, 1.43, 1.48, 1.41, 1.49, 1.56, 1.59, 1.40, 1.49, +# 1.58, 2.29, 1.58, 1.35, 1.41, 1.45, 1.43, 1.51, 1.48, 1.52, 1.51, 1.52, 1.56, 1.42, 1.44, 1.45, 1.47, 1.42, +# 1.43, 1.49, 1.54, 1.45, 1.66, 1.48, 1.35, 1.53, 1.45, 2.38, 1.38, 1.32, 1.37, 1.49, 2.00, 1.47, 1.45, 1.47, +# 1.63, 1.49, 1.59, 2.58, 1.70, 1.52, 1.40, 1.41, 2.57, 1.61, 1.54, 1.47, 1.62, 1.54, 1.41, 1.45, 1.57, 1.49, +# 1.42, 1.50, 1.67, 1.45, 1.47, 1.43, 1.55, 1.47, 1.53, 1.49, 1.56, 1.58, 2.03, 2.03, 1.57, 1.44, 1.46, 1.05, +# 1.61, 1.39, 1.47, 1.41, 1.43, 1.38, 1.34, 1.42, 1.41, 1.47, 1.79, 1.44, 1.43, 1.38, 1.39, 1.44, 1.38, 1.46, +# 1.45, 1.51, 1.52, 1.49, 5.31, 1.41, 1.45, 1.49, 1.43, 1.94, 1.38, 1.35, 1.56, 1.45, 1.37, 1.47, 1.48, 1.67, +# 1.46, 1.50, 1.40, 1.50, 1.62, 1.48, 1.53, 1.45, 1.51, 1.50, 1.51, 1.52, 1.55, 1.42, 1.84, 1.39, 1.54, 1.42, 4.91, +# 1.42, 1.47, 1.51, 1.57, 1.37, 1.50, 1.39, 2.40, 1.51, 1.59, 1.44, 1.42, 1.59, 1.73, 1.44, 1.53, 1.61, 1.48, +# 1.29, 1.47, 1.39, 1.54, 1.44, 1.43, 1.55, 1.45, 1.31, 1.43, 1.44, 1.41, 1.35, 1.62, 1.49, 1.45, 1.50, 1.76, +# 1.44, 1.80, 1.60, 1.49, 1.43, 1.47, 1.40, 1.40, 1.50, 1.42, 1.51, 1.61, 1.47, 1.45, 1.70, 2.90, 1.51, 1.37, +# 1.50, 1.55, 1.32, 1.42, 1.76, 1.36, 1.41, 1.61, 1.44, 1.44, 1.44, 1.47, 1.48, 1.45, 1.48, 1.56, 1.58, 1.52, +# 1.33, 1.37, 1.64, 1.47, 2.49, 1.51, 1.60, 1.58, 1.45, 1.48, 1.81, 1.38, 1.37, 1.53, 1.72, 1.49, 1.47, 1.49, 1.42, +# 1.44, 1.43, 1.54, 1.59, 1.40, 1.57, 1.45, 1.45, 1.45, 1.55, 1.38, 1.41, 1.46, 2.13, 1.58, 1.46, 1.35, 1.56, +# 1.47, 1.33, 1.53, 1.62, 1.47, 1.44, 1.45, 1.49, 1.82, 1.51, 1.38, 1.54, 1.38, 1.38, 1.40, 1.40, 1.46, 1.43, +# 1.45, 1.42, 1.67, 1.37, 1.50, 1.60, 1.42, 1.46, 1.45, 3.29, 1.45, 1.50, 1.49, 1.38, 1.48, 1.52, 2.45, 1.47, +# 1.50, 1.47, 1.48, 1.44, 1.62, 1.48, 1.52, 1.52, 1.45, 1.51, 1.71, 1.54, 1.59, 1.40, 3.29, 1.45, 1.65, 1.37, 1.54, +# 1.49, 2.38, 1.62, 1.39, 1.38, 1.41, 1.46, 1.57, 1.38, 2.07, 1.54, 1.40, 1.64, 1.46, 1.45, 1.40, 1.57, 1.49, +# 1.39, 1.55, 1.67, 1.54, 1.57, 1.55, 1.41, 1.37, 1.44, 1.40, 1.46, 1.59, 1.56, 1.61, 1.44, 1.35, 1.62, 1.59, +# 1.52, 1.41, 1.44, 1.74, 1.40, 1.40, 1.89, 1.44, 1.46, 1.62, 1.43, 1.42, 1.39, 1.37, 1.43, 1.44, 1.60, 1.52, +# 1.44, 1.41, 1.43, 1.34, 1.54, 1.46, 1.57, 1.53, 1.40, 1.41, 1.36, 1.45, 1.42, 1.37, 1.47, 1.37, 1.40, 1.55, +# 1.48, 1.91, 1.44, 1.54, 1.49, 1.42, 1.48, 1.54, 1.49, 1.39, 1.47, 1.50, 1.43, 1.59, 1.58, 1.78, 1.49, 1.55, +# 1.56, 1.52, 1.56, 1.49, 1.61, 1.51, 1.35, 1.46, 1.69, 1.35, 1.38, 1.48, 1.39, 1.40, 1.35, 1.45, 1.34, 1.38, +# 1.44, 1.46, 1.45, 1.63, 1.52, 1.44, 1.39, 1.46, 1.70, 1.41, 1.49, 1.64, 1.54, 1.33, 1.45, 1.54, 1.49, 1.38, +# 1.42, 1.75, 1.28, 1.52, 1.62, 1.47, 1.66, 1.51, 1.50, 1.51, 1.42, 1.42, 1.60, 1.24, 1.54, 1.42, 1.44, 1.34, 1.53, +# 1.46, 1.46, 1.65, 1.56, 1.52, 2.12, 1.58, 1.44, 1.60, 1.48, 1.51, 1.41, 1.51, 1.68, 2.10, 1.50, 1.39, 1.49, +# 1.43, 1.53, 1.46, 1.53, 1.43, 1.78, 1.32, 1.54, 1.47, 1.55, 1.58, 1.41, 1.57, 1.39, 1.36, 1.74, 1.50, 4.41, +# 1.50, 1.45, 1.34, 1.44, 1.50, 1.50, 1.82, 1.28, 1.76, 1.38, 1.58, 1.56, 3.73, 1.48, 1.53, 1.48, 1.63, 1.43, +# 1.57, 3.43, 1.75, 1.45, 1.45, 1.48, 1.93, 1.47, 1.47, 1.38, 1.42, 1.56, 1.66, 1.39, 1.74, 4.76, 1.53, 1.68, +# 1.55, 1.47, 1.57, 1.53, 1.50, 1.40, 1.57, 1.48, 1.44, 1.36, 1.32, 1.71, 1.44, 1.46, 1.47, 1.54, 1.51, 1.47, +# 1.36, 1.29, 1.44, 1.43, 1.46, 1.40, 1.64, 1.48, 1.42, 1.32, 1.52, 1.49, 3.04, 1.52, 1.38, 1.43, 1.42, 1.43, +# 1.48, 1.49, 1.59, 1.55, 1.62, 2.04, 1.53, 1.42, 1.89, 1.43, 1.41, 
3.84, 1.48, 1.51, 1.48, 1.58, 1.54, 1.54, +# 1.54, 1.55, 1.45, 1.49, 1.46, 2.25, 1.43, 1.62, 1.66, 1.80, 1.37, 1.64, 1.49, 1.50, 1.39, 1.41, 1.41, 1.46, 1.44, +# 1.69, 1.47, 1.56, 1.65, 1.51, 1.52, 1.43, 1.53, 1.51, 1.46, 1.62, 1.46, 1.53, 1.68, 1.61, 1.56, 1.42, 4.69, +# 1.31, 1.48, 1.50, 1.82, 1.45, 1.54, 1.56, 1.53, 1.58, 1.59, 1.82, 1.45, 1.54, 1.58, 1.45, 1.40, 1.49, 2.50, +# 1.52, 2.54, 1.51, 1.41, 1.48, 1.46, 1.55, 1.63, 1.42, 1.53, 1.47, 1.47, 1.62, 1.49, 2.09, 1.42, 1.48, 1.33, +# 1.62, 1.41, 1.41, 1.45, 1.50, 1.78, 1.53, 1.56, 1.49, 1.51, 2.31, 1.40, 1.58, 1.39, 1.49, 1.51, 1.55, 1.58, +# 1.93, 1.47, 1.41, 1.47, 1.52, 1.52, 1.39, 1.48, 1.64, 1.49, 1.47, 1.53, 1.50, 3.58, 1.54, 1.70, 1.50, 1.47, +# 1.35, 1.51, 1.70, 1.59, 1.60, 1.56, 1.29 +# ] +# ) +# ) + +class DINOv2a(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2a, self).__init__() + self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + # self.shifts = nn.Parameter(torch.tensor(NORMALIZE_DATA[weight_path+'a']["mean"]), requires_grad=False) + # self.scales = nn.Parameter(torch.tensor(NORMALIZE_DATA[weight_path+'a']["std"]), requires_grad=False) + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + # feature = (feature - self.shifts.view(1, 1, -1)) / self.scales.view(1, 1, -1) + feature = feature.transpose(1, 2) + feature = torch.nn.functional.fold(feature, (patch_num_h*2, patch_num_w*2), kernel_size=2, stride=2) + return feature + + + +from torchvision.datasets import ImageFolder, ImageNet + +import os +import numpy as np + +from PIL import Image +import torch +import torchvision.transforms as tvtf + +IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +class NewImageFolder(ImageFolder): + def __getitem__(self, item): + path, target = self.samples[item] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + return sample, target, path + + +import time +class CenterCrop: + def __init__(self, size): + self.size = size + def __call__(self, image): + def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. 
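Shape walkthrough (not part of the diff) for DINOv2a.forward with a 256x256 input and the ViT-B/14 backbone: the image is resized to 224x224, giving 16x16 = 256 patch tokens of dimension 768, and nn.functional.fold with kernel 2 / stride 2 then redistributes each token over a 2x2 spatial block, producing a 192-channel 32x32 feature map:

import torch

tokens = torch.randn(1, 256, 768)                  # [B, patch_num_h * patch_num_w, 768]
folded = torch.nn.functional.fold(
    tokens.transpose(1, 2), (32, 32), kernel_size=2, stride=2)
print(folded.shape)                                # torch.Size([1, 192, 32, 32])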
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + return center_crop_arr(image, self.size) + +def save(obj, path): + dirname = os.path.dirname(path) + if not os.path.exists(dirname): + os.makedirs(dirname) + torch.save(obj, f'{path}') + +import math +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class Classifer(nn.Module): + def __init__(self, in_channels=192, hidden_size=256, num_classes=1000): + super(Classifer, self).__init__() + self.in_channels = in_channels + self.feature_x = nn.Sequential( + nn.Conv2d(kernel_size=2, in_channels=in_channels, out_channels=num_classes, stride=2, padding=0), + nn.AdaptiveAvgPool2d(1), + ) + def forward(self, xt): + xt = xt[:, :self.in_channels] + score = self.feature_x(xt).squeeze(-1).squeeze(-1) + # score = (feature_xt).clamp(-5, 5) + score = torch.softmax(score, dim=1) + return score + + + +if __name__ == "__main__": + torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub' + os.environ["TORCH_HOME"] = torch_hub_dir + torch.hub.set_dir(torch_hub_dir) + + transforms = tvtf.Compose([ + CenterCrop(256), + tvtf.ToTensor(), + ]) + dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,) + B = 64 + dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=4, num_workers=4, drop_last=True) + dino = DINOv2a("dinov2_vitb14") + from accelerate import Accelerator + + accelerator = Accelerator() + + rank = accelerator.process_index + + classifer = Classifer(in_channels=32) + classifer.train() + optimizer = torch.optim.Adam(classifer.parameters(), lr=0.0001) + + dino, dataloader, classifer, optimizer = accelerator.prepare(dino, dataloader, classifer, optimizer) + + # fake_file_dir = "/mnt/bn/wangshuai6/data/gan_guidance" + # fake_file_names = os.listdir(fake_file_dir) + + for epoch in range(100): + for i, (true_images, true_labels, path_list) in enumerate(dataloader): + batch_size = true_images.shape[0] + true_labels = true_labels.to(accelerator.device) + true_labels = torch.nn.functional.one_hot(true_labels, num_classes=1000) + with torch.no_grad(): + 
true_dino_feature = dino(true_images) + # t = torch.rand((batch_size, 1, 1, 1), device=accelerator.device) + # true_x_t = t * true_dino_feature + (1-t) * noise + + true_x_t = true_dino_feature + true_score = classifer(true_x_t) + + # ind = i % len(fake_file_names) + # fake_file = torch.load(os.path.join(fake_file_dir, fake_file_names[ind])) + # import pdb; pdb.set_trace() + # ind = torch.randint(0, 50, size=(4,)) + # fake_x_t = fake_file['trajs'][ind].view(-1, 196, 32, 32)[:, 4:, :, :] + # fake_labels = fake_file['condition'].repeat(4) + # fake_score = classifer(fake_x_t) + + loss_true = -torch.log(true_score)*true_labels + loss = loss_true.sum()/batch_size + loss.backward() + optimizer.step() + optimizer.zero_grad() + + acc = torch.sum(torch.argmax(true_score, dim=1) == torch.argmax(true_labels, dim=1))/batch_size + if accelerator.is_main_process: + print("epoch:{}".format(epoch), "iter:{}".format(i), "loss:{}".format(loss.item()), "acc:{}".format(acc.item())) + if accelerator.is_main_process: + torch.save(classifer.state_dict(), f'{epoch}.pth') + + accelerator.wait_for_everyone() \ No newline at end of file diff --git a/tools/debug_env.sh b/tools/debug_env.sh new file mode 100644 index 0000000..d29dc46 --- /dev/null +++ b/tools/debug_env.sh @@ -0,0 +1,4 @@ +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/compat +pip3 install -r requirements.txt +git branch --set-upstream-to=origin/master master +git pull \ No newline at end of file diff --git a/tools/dino_scale.py b/tools/dino_scale.py new file mode 100644 index 0000000..439caaf --- /dev/null +++ b/tools/dino_scale.py @@ -0,0 +1,173 @@ +import torch +import torch.nn as nn +import timm +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +import copy + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + return feature + +class MAE(nn.Module): + def __init__(self, model_id, weight_path:str): + super(MAE, self).__init__() + if os.path.isdir(weight_path): + weight_path = os.path.join(weight_path, "pytorch_model.bin") + self.encoder = timm.create_model( + model_id, + checkpoint_path=weight_path, + num_classes=0, + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.shifts = nn.Parameter(torch.tensor([0.0 + ]), 
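Note (not part of the diff): the classifier training loop above computes cross-entropy by hand, i.e. softmax inside Classifer.forward followed by -log(p) * one_hot outside; this snippet checks that it matches torch.nn.functional.cross_entropy on raw logits, which is the numerically safer formulation.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 1000)
labels = torch.randint(0, 1000, (4,))
p = torch.softmax(logits, dim=1)
manual = (-torch.log(p) * F.one_hot(labels, num_classes=1000)).sum() / 4
print(torch.allclose(manual, F.cross_entropy(logits, labels), atol=1e-5))   # True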
requires_grad=False) + self.scales = nn.Parameter(torch.tensor([1.0 + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + + +from diffusers import AutoencoderKL + +from torchvision.datasets import ImageFolder, ImageNet + +import cv2 +import os +import numpy as np +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor + +from PIL import Image +import torch +import torchvision.transforms as tvtf + +IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +class NewImageFolder(ImageFolder): + def __getitem__(self, item): + path, target = self.samples[item] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + return sample, target, path + + +import time +class CenterCrop: + def __init__(self, size): + self.size = size + def __call__(self, image): + def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + return center_crop_arr(image, self.size) + +def save(obj, path): + dirname = os.path.dirname(path) + if not os.path.exists(dirname): + os.makedirs(dirname) + torch.save(obj, f'{path}') + +if __name__ == "__main__": + torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub' + os.environ["TORCH_HOME"] = torch_hub_dir + torch.hub.set_dir(torch_hub_dir) + + for split in ['train']: + train = split == 'train' + transforms = tvtf.Compose([ + CenterCrop(256), + tvtf.ToTensor(), + # tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ]) + dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,) + B = 4096 + dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0) + dino = DINOv2("dinov2_vitb14") + # dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae") + # dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai") + from accelerate import Accelerator + + accelerator = Accelerator() + dino, dataloader = accelerator.prepare(dino, dataloader) + rank = accelerator.process_index + + acc_mean = torch.zeros((768, ), device=accelerator.device) + acc_num = 0 + with torch.no_grad(): + for i, (images, labels, path_list) in enumerate(dataloader): + acc_num += len(images) + feature = dino(images) + stds = torch.std(feature, dim=[0, 2, 
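Alternative sketch (not part of the diff): dino_scale.py estimates the per-channel mean/std from a single large batch and then breaks out of the loop. If statistics over the whole dataset were wanted, per-batch moments could be merged in a streaming fashion (Chan et al. parallel update); shapes below are illustrative.

import torch

count, mean, m2 = 0, torch.zeros(768), torch.zeros(768)
for _ in range(10):                                  # stands in for the dataloader loop
    feats = torch.randn(32, 768, 16, 16)             # [B, C, H, W] encoder features
    x = feats.permute(0, 2, 3, 1).reshape(-1, 768)   # [B*H*W, C]
    n = x.shape[0]
    delta = x.mean(0) - mean
    mean = mean + delta * n / (count + n)
    m2 = m2 + x.var(0, unbiased=False) * n + delta.pow(2) * count * n / (count + n)
    count += n
std = (m2 / count).sqrt()
print(mean.shape, std.shape)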
3]).tolist() + for std in stds: + print("{:.2f},".format(std), end='') + print() + means = torch.mean(feature, dim=[0, 2, 3]).tolist() + for mean in means: + print("{:.2f},".format(mean), end='') + break + + accelerator.wait_for_everyone() \ No newline at end of file diff --git a/tools/dino_scale2.py b/tools/dino_scale2.py new file mode 100644 index 0000000..336c196 --- /dev/null +++ b/tools/dino_scale2.py @@ -0,0 +1,168 @@ +import torch +import torch.nn as nn +import timm +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +import copy + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + return feature + +class MAE(nn.Module): + def __init__(self, model_id, weight_path:str): + super(MAE, self).__init__() + if os.path.isdir(weight_path): + weight_path = os.path.join(weight_path, "pytorch_model.bin") + self.encoder = timm.create_model( + model_id, + checkpoint_path=weight_path, + num_classes=0, + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.shifts = nn.Parameter(torch.tensor([0.0 + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([1.0 + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + + +from diffusers import AutoencoderKL + +from torchvision.datasets import ImageFolder, ImageNet + +import cv2 +import os +import numpy as np +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor + +from PIL import Image +import torch +import torchvision.transforms as tvtf + +IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +class NewImageFolder(ImageFolder): + def __getitem__(self, item): + path, target = self.samples[item] + sample = self.loader(path) + if self.transform is not None: + sample = 
self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + return sample, target, path + + +import time +class CenterCrop: + def __init__(self, size): + self.size = size + def __call__(self, image): + def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + return center_crop_arr(image, self.size) + +def save(obj, path): + dirname = os.path.dirname(path) + if not os.path.exists(dirname): + os.makedirs(dirname) + torch.save(obj, f'{path}') + +if __name__ == "__main__": + torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub' + os.environ["TORCH_HOME"] = torch_hub_dir + torch.hub.set_dir(torch_hub_dir) + + for split in ['train']: + train = split == 'train' + transforms = tvtf.Compose([ + CenterCrop(256), + tvtf.ToTensor(), + tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ]) + dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,) + B = 2048 + dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0) + dino = DINOv2("dinov2_vitb14") + # dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae") + # dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai") + from accelerate import Accelerator + + accelerator = Accelerator() + dino, dataloader = accelerator.prepare(dino, dataloader) + rank = accelerator.process_index + + with torch.no_grad(): + for i, (images, labels, path_list) in enumerate(dataloader): + feature = dino(images) + b, c, h, w = feature.shape + feature = feature.view(b, c, h*w).transpose(1, 2) + feature = feature.reshape(-1, c) + U, S, V = torch.pca_lowrank(feature, 64, ) + import pdb; pdb.set_trace() + feature = torch.matmul(feature, V) + break + accelerator.wait_for_everyone() \ No newline at end of file diff --git a/tools/dp.py b/tools/dp.py new file mode 100644 index 0000000..9f89d75 --- /dev/null +++ b/tools/dp.py @@ -0,0 +1,64 @@ +import matplotlib.pyplot as plt +print(len([0, 3, 6, 9, 12, 16, 20, 24, 28, 33, 38, 43, 48, 53, 57, 62, 67, 72, 78, 83, 87, 91, 95, 98, 102, 106, 110, 115, 120, 125, 130, 135, 141, 146, 152, 158, 164, 171, 179, 185, 191, 197, 203, 209, 216, 223, 229, 234, 240, 245, 250])) +print(len(list(range(0, 251, 5)))) +exit() +plt.plot() +plt.plot() +plt.show() +exit() + + + + +import torch + +num_steps = 10 +num_recompute_timesteps = 4 +sim = torch.randint(0, 100, (num_steps, num_steps)) +sim[:5, :5] = 100 +for i in range(num_steps): + sim[i, i] = 100 + +error_map = (100-sim).tolist() + + +# init +for i in range(1, num_steps): + for j in range(0, i): + error_map[i][j] = error_map[i-1][j] + error_map[i][j] + +C = [[0, ] * (num_steps + 1) for _ in range(num_recompute_timesteps+1)] +P = [[-1, ] * (num_steps + 1) for _ in range(num_recompute_timesteps+1)] + +for i in range(1, num_steps+1): + C[1][i] = 
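Sketch (not part of the diff): after the in-place accumulation above, error_map[i][j] holds the cost of reusing step j's features for all steps j..i, and the DP that follows picks num_recompute_timesteps start indices (always including step 0) so that each step t reuses the most recent chosen index j <= t at minimal total cost. A brute-force check of the same objective on the raw per-pair errors e[t][j] = 100 - sim[t][j]:

from itertools import combinations

def brute_force_timesteps(e, num_recompute):
    n = len(e)
    best_cost, best_starts = float("inf"), None
    for rest in combinations(range(1, n), num_recompute - 1):
        starts = (0,) + rest
        cost = sum(e[t][max(j for j in starts if j <= t)] for t in range(n))
        if cost < best_cost:
            best_cost, best_starts = cost, starts
    return best_cost, best_starts

# best_cost should equal C[num_recompute_timesteps][num_steps] computed by the DP below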
error_map[i-1][0] + P[1][i] = 0 + + +# dp +for step in range(2, num_recompute_timesteps+1): + for i in range(step, num_steps+1): + min_value = 99999 + min_index = -1 + for j in range(step-1, i): + value = C[step-1][j] + error_map[i-1][j] + if value < min_value: + min_value = value + min_index = j + C[step][i] = min_value + P[step][i] = min_index + +# trace back +tracback_end_index = num_steps +# min_value = 99999 +# for i in range(num_recompute_timesteps-1, num_steps): +# if C[-1][i] < min_value: +# min_value = C[-1][i] +# tracback_end_index = i + +timesteps = [tracback_end_index, ] +for i in range(num_recompute_timesteps, 0, -1): + idx = timesteps[-1] + timesteps.append(P[i][idx]) +timesteps.reverse() +print(timesteps) \ No newline at end of file diff --git a/tools/figures/base++.py b/tools/figures/base++.py new file mode 100644 index 0000000..9715321 --- /dev/null +++ b/tools/figures/base++.py @@ -0,0 +1,64 @@ +import numpy as np +import matplotlib.pyplot as plt + +is_data = { + "4encoder8decoder":[46.01, 61.47, 69.73, 74.26], + "6encoder6decoder":[53.11, 71.04, 79.83, 83.85], + "8encoder4decoder":[54.06, 72.96, 80.49, 85.94], + "10encoder2decoder": [49.25, 67.59, 76.00, 81.12], +} + +fid_data = { + "4encoder8decoder":[31.40, 22.80, 20.13, 18.61], + "6encoder6decoder":[27.61, 20.42, 17.95, 16.86], + "8encoder4decoder":[27.12, 19.90, 17.78, 16.32], + "10encoder2decoder": [29.70, 21.75, 18.95, 17.65], +} + +sfid_data = { + "4encoder8decoder":[6.88, 6.44, 6.56, 6.56], + "6encoder4decoder":[6.83, 6.50, 6.49, 6.63], + "8encoder4decoder":[6.76, 6.70, 6.83, 6.63], + "10encoder2decoder": [6.81, 6.61, 6.53, 6.60], +} + +pr_data = { + "4encoder8decoder":[0.55006, 0.59538, 0.6063, 0.60922], + "6encoder6decoder":[0.56436, 0.60246, 0.61668, 0.61702], + "8encoder4decoder":[0.56636, 0.6038, 0.61832, 0.62132], + "10encoder2decoder": [0.55612, 0.59846, 0.61092, 0.61686], +} + +recall_data = { + "4encoder8decoder":[0.6347, 0.6495, 0.6559, 0.662], + "6encoder6decoder":[0.6477, 0.6497, 0.6594, 0.6589], + "8encoder4decoder":[0.6403, 0.653, 0.6505, 0.6618], + "10encoder2decoder": [0.6342, 0.6492, 0.6536, 0.6569], +} + +x = [100, 200, 300, 400] +# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"] + +colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"] + +metric_data = { + "FID50K" : fid_data, + # "SFID" : sfid_data, + "InceptionScore" : is_data, + "Precision" : pr_data, + "Recall" : recall_data, +} + +for key, data in metric_data.items(): + # plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False}) + for i, (name, v) in enumerate(data.items()): + name = name.replace("encoder", "En") + name = name.replace("decoder", "De") + plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=10) + plt.legend(fontsize="14") + plt.xticks([100, 150, 200, 250, 300, 350, 400]) + plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5) + plt.ylabel(key, weight="bold") + plt.xlabel("Training iterations(K steps)", weight="bold") + plt.savefig("output/base++_{}.pdf".format(key), bbox_inches='tight',) + plt.close() \ No newline at end of file diff --git a/tools/figures/base.py b/tools/figures/base.py new file mode 100644 index 0000000..46acd3c --- /dev/null +++ b/tools/figures/base.py @@ -0,0 +1,57 @@ +import numpy as np +import matplotlib.pyplot as plt + + + + +fid_data = { + "4encoder8decoder":[64.16, 48.04, 39.88, 35.41], + "6encoder4decoder":[67.71, 48.26, 39.30, 34.91], + "8encoder4decoder":[69.4, 49.7, 41.56, 36.76], +} + +sfid_data = { + "4encoder8decoder":[7.86, 
7.48, 7.15, 7.07], + "6encoder4decoder":[8.54, 8.11, 7.40, 7.40], + "8encoder4decoder":[8.42, 8.27, 8.10, 7.69], +} + +is_data = { + "4encoder8decoder":[20.37, 29.41, 36.88, 41.32], + "6encoder4decoder":[20.04, 30.13, 38.17, 43.84], + "8encoder4decoder":[19.98, 29.54, 35.93, 42.025], +} + +pr_data = { + "4encoder8decoder":[0.3935, 0.4687, 0.5047, 0.5271], + "6encoder4decoder":[0.3767, 0.4686, 0.50876, 0.5266], + "8encoder4decoder":[0.37, 0.45676, 0.49602, 0.5162], +} + +recall_data = { + "4encoder8decoder":[0.5604, 0.5941, 0.6244, 0.6338], + "6encoder4decoder":[0.5295, 0.595, 0.6287, 0.6378], + "8encoder4decoder":[0.51, 0.596, 0.6242, 0.6333], +} + +x = [100, 200, 300, 400] +colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"] +metric_data = { + "FID" : fid_data, + # "SFID" : sfid_data, + "InceptionScore" : is_data, + "Precision" : pr_data, + "Recall" : recall_data, +} + +for key, data in metric_data.items(): + for i, (name, v) in enumerate(data.items()): + name = name.replace("encoder", "En") + name = name.replace("decoder", "De") + plt.plot(x, v, label=name, color=colors[i], linewidth=3, marker="o") + plt.legend() + plt.xticks(x) + plt.ylabel(key, weight="bold") + plt.xlabel("Training iterations(K steps)", weight="bold") + plt.savefig("output/base_{}.pdf".format(key), bbox_inches='tight') + plt.close() \ No newline at end of file diff --git a/tools/figures/cfg.py b/tools/figures/cfg.py new file mode 100644 index 0000000..4cd8855 --- /dev/null +++ b/tools/figures/cfg.py @@ -0,0 +1,32 @@ +import numpy as np +import matplotlib.pyplot as plt + +cfg_data = { + "[0, 1]":{ + "cfg":[1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0], + "FID":[9.23, 6.61, 5.08, 4.46, 4.32, 4.52, 4.86, 5.38, 5.97, 6.57, 7.13], + }, + "[0.2, 1]":{ + "cfg": [1.2, 1.4, 1.6, 1.8, 2.0], + "FID": [5.87, 4.44, 3.96, 4.01, 4.26] + }, + "[0.3, 1]":{ + "cfg": [1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4], + "FID": [4.31, 4.11, 3.98, 3.89, 3.87, 3.88, 3.91, 3.96, 4.03] + }, + "[0.35, 1]":{ + "cfg": [1.6, 1.8, 2.0, 2.1, 2.2, 2.3, 2.4, 2.6], + "FID": [4.68, 4.22, 3.98, 3.92, 3.90, 3.88, 3.88, 3.94] + } +} + +colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"] + +for i, (name, data) in enumerate(cfg_data.items()): + plt.plot(data["cfg"], data["FID"], label="Interval: " +name, color=colors[i], linewidth=3.5, marker="o") + +plt.title("Classifer-free guidance with intervals", weight="bold") +plt.ylabel("FID10K", weight="bold") +plt.xlabel("CFG values", weight="bold") +plt.legend() +plt.savefig("./output/cfg.pdf", bbox_inches="tight") \ No newline at end of file diff --git a/tools/figures/feat_vis.py b/tools/figures/feat_vis.py new file mode 100644 index 0000000..55f2045 --- /dev/null +++ b/tools/figures/feat_vis.py @@ -0,0 +1,42 @@ +import torch + +states = torch.load("./output/state.pt", map_location="cpu").to(dtype=torch.float32) +states = states.permute(1, 2, 0, 3) +print(states.shape) +states = states.view(-1, 49, 1152) +states = torch.nn.functional.normalize(states, dim=-1) +sim = torch.bmm(states, states.transpose(1, 2)) +mean_sim = torch.mean(sim, dim=0, keepdim=False) + +mean_sim = mean_sim.numpy() +import numpy as np +import matplotlib +import matplotlib.pyplot as plt +timesteps = np.linspace(0, 1, 5) +# plt.rc('axes.spines', **{'bottom':False, 'left':False, 'right':False, 'top':False}) +cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["#7400b8","#5e60ce","#4ea8de", "#64dfdf", "#80ffdb"]) +plt.imshow(mean_sim, cmap="inferno") +plt.xticks([]) +plt.yticks([]) +# plt.show() +plt.colorbar() 
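+# mean_sim is the 49x49 cosine-similarity matrix between the 49 patch tokens,
+# averaged over the flattened leading dimensions; the heatmap is saved below to
+# ./output, which is assumed to exist already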
+plt.savefig("./output/mean_sim.png", pad_inches=0, bbox_inches="tight") +# cos_sim = torch.nn.functional.cosine_similarity(states, states) + + +# for i in range(49): +# cos_sim = torch.nn.functional.cosine_similarity(states[i], states[i + 1]) +# cos_sim = cos_sim.min() +# print(cos_sim) +# state = torch.max(states, dim=-1)[1] +# # state = torch.softmax(state, dim=-1) +# state = state.view(-1, 16, 16) +# +# state = state.numpy() +# +# import numpy as np +# import matplotlib.pyplot as plt +# for i in range(0, 49): +# print(i) +# plt.imshow(state[i]) +# plt.savefig("./output2/{}.png".format(i)) \ No newline at end of file diff --git a/tools/figures/large++.py b/tools/figures/large++.py new file mode 100644 index 0000000..070d13f --- /dev/null +++ b/tools/figures/large++.py @@ -0,0 +1,63 @@ +import numpy as np +import matplotlib.pyplot as plt + +is_data = { + "10encoder14decoder":[80.48, 104.48, 113.01, 117.29], + "12encoder12decoder":[85.52, 109.91, 118.18, 121.77], + "16encoder8decoder":[92.72, 116.30, 124.32, 126.37], + "20encoder4decoder":[94.95, 117.84, 125.66, 128.30], +} + +fid_data = { + "10encoder14decoder":[15.17, 10.40, 9.32, 8.66], + "12encoder12decoder":[13.79, 9.67, 8.64, 8.21], + "16encoder8decoder":[12.41, 8.99, 8.18, 8.03], + "20encoder4decoder":[12.04, 8.94, 8.03, 7.98], +} + +sfid_data = { + "10encoder14decoder":[5.49, 5.00, 5.09, 5.14], + "12encoder12decoder":[5.37, 5.01, 5.07, 5.09], + "16encoder8decoder":[5.43, 5.11, 5.20, 5.31], + "20encoder4decoder":[5.36, 5.23, 5.21, 5.50], +} + +pr_data = { + "10encoder14decoder":[0.6517, 0.67914, 0.68274, 0.68104], + "12encoder12decoder":[0.66144, 0.68146, 0.68564, 0.6823], + "16encoder8decoder":[0.6659, 0.68342, 0.68338, 0.67912], + "20encoder4decoder":[0.6716, 0.68088, 0.68798, 0.68098], +} + +recall_data = { + "10encoder14decoder":[0.6427, 0.6512, 0.6572, 0.6679], + "12encoder12decoder":[0.6429, 0.6561, 0.6622, 0.6693], + "16encoder8decoder":[0.6457, 0.6547, 0.6665, 0.6773], + "20encoder4decoder":[0.6483, 0.6612, 0.6684, 0.6711], +} + +x = [100, 200, 300, 400] +# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"] +colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"] + +metric_data = { + "FID50K" : fid_data, + # "SFID" : sfid_data, + "InceptionScore" : is_data, + "Precision" : pr_data, + "Recall" : recall_data, +} + +for key, data in metric_data.items(): + # plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False}) + for i, (name, v) in enumerate(data.items()): + name = name.replace("encoder", "En") + name = name.replace("decoder", "De") + plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=8) + plt.legend(fontsize="14") + plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5) + plt.xticks([100, 150, 200, 250, 300, 350, 400]) + plt.ylabel(key, weight="bold") + plt.xlabel("Training iterations(K steps)", weight="bold") + plt.savefig("output/large++_{}.pdf".format(key), bbox_inches='tight') + plt.close() \ No newline at end of file diff --git a/tools/figures/log_snr.py b/tools/figures/log_snr.py new file mode 100644 index 0000000..f9d1b28 --- /dev/null +++ b/tools/figures/log_snr.py @@ -0,0 +1,18 @@ +import numpy as np +import matplotlib.pyplot as plt + +t = np.linspace(0.001, 0.999, 100) +def snr(t): + return np.log((1-t)/t) +def pds(t): + return np.clip(((1-t)/t)**2, a_max=0.5, a_min=0.0) +print(pds(t)) +plt.figure(figsize=(16, 4)) +plt.plot(t, snr(t), color="#ff70a6", linewidth=3, marker="o") +# plt.plot(t, pds(t), color="#ff9770", linewidth=3, marker="o") 
+plt.ylabel("log-SNR", weight="bold") +plt.xlabel("Timesteps", weight="bold") +plt.xticks([1.0, 0.8, 0.6, 0.4, 0.2, 0.0]) +plt.gca().invert_xaxis() +plt.show() +# plt.savefig("output/logsnr.pdf", bbox_inches='tight') \ No newline at end of file diff --git a/tools/figures/output/base++_FID.pdf b/tools/figures/output/base++_FID.pdf new file mode 100644 index 0000000..4aa3a98 Binary files /dev/null and b/tools/figures/output/base++_FID.pdf differ diff --git a/tools/figures/output/base++_FID50K.pdf b/tools/figures/output/base++_FID50K.pdf new file mode 100644 index 0000000..779138f Binary files /dev/null and b/tools/figures/output/base++_FID50K.pdf differ diff --git a/tools/figures/output/base++_InceptionScore.pdf b/tools/figures/output/base++_InceptionScore.pdf new file mode 100644 index 0000000..03eceb5 Binary files /dev/null and b/tools/figures/output/base++_InceptionScore.pdf differ diff --git a/tools/figures/output/base++_Precision.pdf b/tools/figures/output/base++_Precision.pdf new file mode 100644 index 0000000..38f410a Binary files /dev/null and b/tools/figures/output/base++_Precision.pdf differ diff --git a/tools/figures/output/base++_Recall.pdf b/tools/figures/output/base++_Recall.pdf new file mode 100644 index 0000000..e01ed34 Binary files /dev/null and b/tools/figures/output/base++_Recall.pdf differ diff --git a/tools/figures/output/base_FID.pdf b/tools/figures/output/base_FID.pdf new file mode 100644 index 0000000..5541a6b Binary files /dev/null and b/tools/figures/output/base_FID.pdf differ diff --git a/tools/figures/output/base_InceptionScore.pdf b/tools/figures/output/base_InceptionScore.pdf new file mode 100644 index 0000000..ad8ae23 Binary files /dev/null and b/tools/figures/output/base_InceptionScore.pdf differ diff --git a/tools/figures/output/base_Precision.pdf b/tools/figures/output/base_Precision.pdf new file mode 100644 index 0000000..dc16562 Binary files /dev/null and b/tools/figures/output/base_Precision.pdf differ diff --git a/tools/figures/output/base_Recall.pdf b/tools/figures/output/base_Recall.pdf new file mode 100644 index 0000000..7164c9e Binary files /dev/null and b/tools/figures/output/base_Recall.pdf differ diff --git a/tools/figures/output/cfg.pdf b/tools/figures/output/cfg.pdf new file mode 100644 index 0000000..90eec86 Binary files /dev/null and b/tools/figures/output/cfg.pdf differ diff --git a/tools/figures/output/large++_FID.pdf b/tools/figures/output/large++_FID.pdf new file mode 100644 index 0000000..ec04806 Binary files /dev/null and b/tools/figures/output/large++_FID.pdf differ diff --git a/tools/figures/output/large++_FID50K.pdf b/tools/figures/output/large++_FID50K.pdf new file mode 100644 index 0000000..b620e46 Binary files /dev/null and b/tools/figures/output/large++_FID50K.pdf differ diff --git a/tools/figures/output/large++_InceptionScore.pdf b/tools/figures/output/large++_InceptionScore.pdf new file mode 100644 index 0000000..0c44d4b Binary files /dev/null and b/tools/figures/output/large++_InceptionScore.pdf differ diff --git a/tools/figures/output/large++_Precision.pdf b/tools/figures/output/large++_Precision.pdf new file mode 100644 index 0000000..f5698ad Binary files /dev/null and b/tools/figures/output/large++_Precision.pdf differ diff --git a/tools/figures/output/large++_Recall.pdf b/tools/figures/output/large++_Recall.pdf new file mode 100644 index 0000000..2671f07 Binary files /dev/null and b/tools/figures/output/large++_Recall.pdf differ diff --git a/tools/figures/output/logsnr.pdf b/tools/figures/output/logsnr.pdf new file mode 
100644 index 0000000..3d46a63 Binary files /dev/null and b/tools/figures/output/logsnr.pdf differ diff --git a/tools/figures/output/mean_sim.png b/tools/figures/output/mean_sim.png new file mode 100644 index 0000000..f4b7967 Binary files /dev/null and b/tools/figures/output/mean_sim.png differ diff --git a/tools/figures/output/sota.pdf b/tools/figures/output/sota.pdf new file mode 100644 index 0000000..80b7263 Binary files /dev/null and b/tools/figures/output/sota.pdf differ diff --git a/tools/figures/output/timeshift.pdf b/tools/figures/output/timeshift.pdf new file mode 100644 index 0000000..b2cb2f9 Binary files /dev/null and b/tools/figures/output/timeshift.pdf differ diff --git a/tools/figures/output/timeshift_fid.pdf b/tools/figures/output/timeshift_fid.pdf new file mode 100644 index 0000000..1b3b7c7 Binary files /dev/null and b/tools/figures/output/timeshift_fid.pdf differ diff --git a/tools/figures/sota.py b/tools/figures/sota.py new file mode 100644 index 0000000..58b0ec9 --- /dev/null +++ b/tools/figures/sota.py @@ -0,0 +1,95 @@ +import numpy as np +import matplotlib.pyplot as plt + +data = { + "SiT-XL/2" : { + "size": 675, + "epochs": 1400, + "FID": 2.06, + "color": "#ff99c8" + }, + "DiT-XL/2" : { + "size": 675, + "epochs": 1400, + "FID": 2.27, + "color": "#fcf6bd" + }, + "REPA-XL/2" : { + "size": 675, + "epochs": 800, + "FID": 1.42, + "color": "#d0f4de" + }, + # "MAR-H" : { + # "size": 973, + # "epochs": 800, + # "FID": 1.55, + # }, + "MDTv2" : { + "size": 675, + "epochs": 920, + "FID": 1.58, + "color": "#e4c1f9" + }, + # "VAVAE+LightningDiT" : { + # "size": 675, + # "epochs": [64, 800], + # "FID": [2.11, 1.35], + # }, + "DDT-XL/2": { + "size": 675, + "epochs": [80, 256], + "FID": [1.52, 1.31], + "color": "#38a3a5" + }, + "DDT-L/2": { + "size": 400, + "epochs": 80, + "FID": 1.64, + "color": "#5bc0be" + }, +} + +fig = plt.figure() +ax = fig.add_subplot(1, 1, 1) +for k, spec in data.items(): + plt.scatter( + # spec["size"], + spec["epochs"], + spec["FID"], + label=k, + marker="o", + s=spec["size"], + color=spec["color"], + ) + x = spec["epochs"] + y = spec["FID"] + if isinstance(spec["FID"], list): + x = spec["epochs"][-1] + y = spec["FID"][-1] + plt.plot( + spec["epochs"], + spec["FID"], + color=spec["color"], + linestyle="dotted", + linewidth=4 + ) + # plt.annotate("", + # xytext=(spec["epochs"][0], spec["FID"][0]), + # xy=(spec["epochs"][1], spec["FID"][1]), arrowprops=dict(arrowstyle="--"), weight="bold") + plt.text(x+80, y-0.05, k, fontsize=13) + +plt.text(200, 1.45, "4x Training Acc", fontsize=12, color="#38a3a5", weight="bold") +# plt.arrow(200, 1.42, 520, 0, linewidth=2, fc='black', ec='black', hatch="x", head_width=0.05, head_length=0.05) + +plt.annotate("", + xy=(700, 1.42), xytext=(200, 1.42), + arrowprops=dict(arrowstyle='<->', color='black', linewidth=2), + ) +ax.grid(linestyle="-.", alpha=0.6, linewidth=0.5) +plt.gca().set_xlim(0, 1800) +plt.gca().set_ylim(1.15, 2.5) +plt.xticks([80, 256, 800, 1000, 1200, 1400, 1600, ]) +plt.xlabel("Training Epochs", weight="bold") +plt.ylabel("FID50K on ImageNet256x256", weight="bold") +plt.savefig("output/sota.pdf", bbox_inches="tight") \ No newline at end of file diff --git a/tools/figures/timeshift.py b/tools/figures/timeshift.py new file mode 100644 index 0000000..766fac7 --- /dev/null +++ b/tools/figures/timeshift.py @@ -0,0 +1,26 @@ +import scipy +import numpy as np +import matplotlib.pyplot as plt + +def timeshift(t, s=1.0): + return t/(t+(1-t)*s) + +# colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"] +colors = 
["#52b69a", "#34a0a4", "#168aad", "#1a759f"] +# plt.rc('axes.spines', **{'bottom':True, 'left':True, 'right':False, 'top':False}) +t = np.linspace(0, 1, 100) +shifts = [1.0, 1.5, 2, 3] +for i , shift in enumerate(shifts): + plt.plot(t, timeshift(t, shift), color=colors[i], label=f"shift {shift}", linewidth=4) + +# plt.annotate("", xytext=(0, 0), xy=(0.0, 1.05), arrowprops=dict(arrowstyle="->"), weight="bold") +# plt.annotate("", xytext=(0, 0), xy=(1.05, 0.0), arrowprops=dict(arrowstyle="->"), weight="bold") +# plt.title("Respaced timesteps with various shift value", weight="bold") +# plt.gca().set_xlim(0, 1.0) +# plt.gca().set_ylim(0, 1.0) +plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5) + +plt.ylabel("Respaced Timesteps", weight="bold") +plt.xlabel("Uniform Timesteps", weight="bold") +plt.legend(loc="upper left", fontsize="12") +plt.savefig("output/timeshift.pdf", bbox_inches="tight", pad_inches=0) \ No newline at end of file diff --git a/tools/figures/timeshift_fid.py b/tools/figures/timeshift_fid.py new file mode 100644 index 0000000..07a591c --- /dev/null +++ b/tools/figures/timeshift_fid.py @@ -0,0 +1,29 @@ +import scipy +import numpy as np +import matplotlib.pyplot as plt + +def timeshift(t, s=1.0): + return t/(t+(1-t)*s) + +data = { + "shift 1.0": [8.99, 6.36, 5.03, 4.21, 3.6, 3.23, 2.80], + "shift 1.5": [6.08, 4.26, 3.43, 2.99, 2.73, 2.54, 2.33], + "shift 2.0": [5.57, 3.81, 3.11, 2.75, 2.54, 2.43, 2.26], + "shift 3.0": [7.26, 4.48, 3.43, 2.97, 2.72, 2.57, 2.38], +} +# plt.rc('axes.spines', **{'bottom':True, 'left':True, 'right':False, 'top':False}) + +# colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"] + +colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"] +steps = [5, 6, 7, 8, 9, 10, 12] +for i ,(k, v)in enumerate(data.items()): + plt.plot(steps, v, color=colors[i], label=k, linewidth=4, marker="o") + +# plt.title("FID50K of different steps of different timeshift", weight="bold") +plt.ylabel("FID50K", weight="bold") +plt.xlabel("Num of inference steps", weight="bold") +plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5) +# plt.legend() +# plt.legend() +plt.savefig("output/timeshift_fid.pdf", bbox_inches="tight", pad_inches=0) \ No newline at end of file diff --git a/tools/fm_images.py b/tools/fm_images.py new file mode 100644 index 0000000..6657523 --- /dev/null +++ b/tools/fm_images.py @@ -0,0 +1,21 @@ +IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +PATH = "/mnt/bn/wangshuai6/neural_sampling_workdirs/expbaseline_adam2_timeshift1.5" + +import os +import pathlib +PATH = pathlib.Path(PATH) +images = [] + +# find images +for ext in IMG_EXTENSIONS: + images.extend(PATH.rglob(ext)) + +for image in images: + os.system(f"rm -f {image}") + diff --git a/tools/mm.py b/tools/mm.py new file mode 100644 index 0000000..25bba93 --- /dev/null +++ b/tools/mm.py @@ -0,0 +1,23 @@ +import torch +import time +import torch.nn as nn +import accelerate + + +if __name__ == "__main__": + model = nn.Linear(512, 512) + for p in model.parameters(): + p.requires_grad = False + accelerator = accelerate.Accelerator() + model = accelerator.prepare_model(model) + model.to(accelerator.device) + data = torch.randn(1024, 512).to(accelerator.device) + while True: + time.sleep(0.01) + accelerator.wait_for_everyone() + if torch.cuda.utilization() < 1.5: + with torch.no_grad(): + model(data) + else: + time.sleep(1) + # print(f"rank:{accelerator.process_index}->usage:{torch.cuda.utilization()}") \ No newline at end of file diff --git a/tools/sigmoid.py b/tools/sigmoid.py new file mode 100644 
index 0000000..90cec55 --- /dev/null +++ b/tools/sigmoid.py @@ -0,0 +1,20 @@ +import numpy as np +import matplotlib.pyplot as plt + + +def lw(x, b=0): + x = np.clip(x, a_min=0.001, a_max=0.999) + snr = x/(1-x) + logsnr = np.log(snr) + # print(logsnr) + # return logsnr + weight = 1 / (1 + np.exp(-logsnr - b))#*(1-x)**2 + return weight #/weight.max() + +x = np.arange(0.2, 0.8, 0.001) +print(1/(x*(1-x))) +for b in [0, 1, 2, 3]: + y = lw(x, b) + plt.plot(x, y, label=f"b={b}") +plt.legend() +plt.show() \ No newline at end of file diff --git a/tools/vae2dino.py b/tools/vae2dino.py new file mode 100644 index 0000000..136f5ff --- /dev/null +++ b/tools/vae2dino.py @@ -0,0 +1,173 @@ +import torch +import torch.nn as nn +import timm +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +import copy + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + return feature + +class MAE(nn.Module): + def __init__(self, model_id, weight_path:str): + super(MAE, self).__init__() + if os.path.isdir(weight_path): + weight_path = os.path.join(weight_path, "pytorch_model.bin") + self.encoder = timm.create_model( + model_id, + checkpoint_path=weight_path, + num_classes=0, + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.shifts = nn.Parameter(torch.tensor([0.0 + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([1.0 + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + + +from diffusers import AutoencoderKL + +from torchvision.datasets import ImageFolder, ImageNet + +import cv2 +import os +import numpy as np +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor + +from PIL import Image +import torch +import torchvision.transforms as tvtf + 
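+# The script below center-crops ImageNet train images to 256x256, feeds them
+# through the selected encoder (MAE here; the DINOv2/CLIP variants are commented
+# out), and prints per-channel std and mean of the patch features for the first
+# batch only before breaking out of the loop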
+IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +class NewImageFolder(ImageFolder): + def __getitem__(self, item): + path, target = self.samples[item] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + return sample, target, path + + +import time +class CenterCrop: + def __init__(self, size): + self.size = size + def __call__(self, image): + def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + return center_crop_arr(image, self.size) + +def save(obj, path): + dirname = os.path.dirname(path) + if not os.path.exists(dirname): + os.makedirs(dirname) + torch.save(obj, f'{path}') + +if __name__ == "__main__": + torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub' + os.environ["TORCH_HOME"] = torch_hub_dir + torch.hub.set_dir(torch_hub_dir) + + for split in ['train']: + train = split == 'train' + transforms = tvtf.Compose([ + CenterCrop(256), + tvtf.ToTensor(), + tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ]) + dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,) + B = 4 + dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0) + # dino = DINOv2("dinov2_vitb14") + dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae") + # dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai") + from accelerate import Accelerator + + accelerator = Accelerator() + dino, dataloader = accelerator.prepare(dino, dataloader) + rank = accelerator.process_index + + acc_mean = torch.zeros((768, ), device=accelerator.device) + acc_num = 0 + with torch.no_grad(): + for i, (images, labels, path_list) in enumerate(dataloader): + acc_num += len(images) + feature = dino(images) + stds = torch.std(feature, dim=[0, 2, 3]).tolist() + for std in stds: + print("{:.2f},".format(std), end='') + print() + means = torch.mean(feature, dim=[0, 2, 3]).tolist() + for mean in means: + print("{:.2f},".format(mean), end='') + break + + accelerator.wait_for_everyone() \ No newline at end of file diff --git a/tools/vis_timeshift.py b/tools/vis_timeshift.py new file mode 100644 index 0000000..496b922 --- /dev/null +++ b/tools/vis_timeshift.py @@ -0,0 +1,23 @@ +import scipy +import numpy as np +import matplotlib.pyplot as plt + +def timeshift(t, s=1.0): + return t/(t+(1-t)*s) + +def gaussian(t): + gs = 1+scipy.special.erf((t-t.mean())/t.std()) + +def rs2(t, s=2.0): + factor1 = 1.0 #s/(s+(1-s)*t)**2 + factor2 = np.log(t.clip(0.001, 0.999)/(1-t).clip(0.001, 0.999)) + return factor1*factor2 + + +t = np.linspace(0, 1, 100) +# plt.plot(t, timeshift(t, 1.0)) +respaced_t = timeshift(t, s=5) +delats = (respaced_t[1:] - respaced_t[:-1]) +# plt.plot(t, timeshift(t, 1.5)) 
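+# rs2(t) reduces to the clipped logit log(t/(1-t)) because factor1 is hard-coded
+# to 1.0; respaced_t and delats computed above are currently unused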
+plt.plot(t, rs2(t)) +plt.show() \ No newline at end of file