From 06499f1caad926bd0d98594d448205c4c3c3108f Mon Sep 17 00:00:00 2001
From: wangshuai6
Date: Wed, 9 Apr 2025 11:01:16 +0800
Subject: [PATCH] submit code

---
 .idea/.gitignore | 8 +
 .idea/DDT.iml | 8 +
 .../inspectionProfiles/profiles_settings.xml | 6 +
 .idea/modules.xml | 8 +
 .idea/vcs.xml | 6 +
 configs/repa_flatten_condit22_fixt_xl.yaml | 108 ++++
 configs/repa_flatten_condit22_fixt_xl512.yaml | 108 ++++
 configs/repa_flatten_dit_fixt_large.yaml | 99 ++++
 configs/repa_flatten_dit_fixt_xl.yaml | 99 ++++
 main.py | 86 ++++
 main.sh | 10 +
 requirements.txt | 3 +
 src/__init__.py | 0
 src/callbacks/__init__.py | 0
 src/callbacks/grad.py | 22 +
 src/callbacks/model_checkpoint.py | 25 +
 src/callbacks/save_images.py | 105 ++++
 src/callbacks/simple_ema.py | 79 +++
 src/data/__init__.py | 1 +
 src/data/dataset/__init__.py | 0
 src/data/dataset/celeba.py | 11 +
 src/data/dataset/imagenet.py | 82 ++++
 src/data/dataset/metric_dataset.py | 82 ++++
 src/data/dataset/randn.py | 41 ++
 src/data/var_training.py | 145 ++++++
 src/diffusion/__init__.py | 0
 src/diffusion/base/guidance.py | 60 +++
 src/diffusion/base/sampling.py | 31 ++
 src/diffusion/base/scheduling.py | 32 ++
 src/diffusion/base/training.py | 29 ++
 src/diffusion/ddpm/ddim_sampling.py | 40 ++
 src/diffusion/ddpm/scheduling.py | 102 ++++
 src/diffusion/ddpm/training.py | 83 ++++
 src/diffusion/ddpm/vp_sampling.py | 59 +++
 src/diffusion/flow_matching/adam_sampling.py | 107 ++++
 src/diffusion/flow_matching/sampling.py | 179 +++++++
 src/diffusion/flow_matching/scheduling.py | 39 ++
 src/diffusion/flow_matching/training.py | 55 +++
 src/diffusion/flow_matching/training_cos.py | 59 +++
 .../flow_matching/training_pyramid.py | 68 +++
 src/diffusion/flow_matching/training_repa.py | 142 ++++++
 .../flow_matching/training_repa_mask.py | 152 ++++++
 src/diffusion/pre_integral.py | 143 ++++++
 .../stateful_flow_matching/adam_sampling.py | 112 +++++
 .../bak/training_adv.py | 122 +++++
 .../bak/training_adv_x0.py | 127 +++++
 .../bak/training_mask_repa.py | 159 ++++++
 .../bak/training_patch_adv.py | 179 +++++++
 .../bak/training_repa_jit.py | 154 ++++++
 .../bak/training_self_consistent.py | 90 ++++
 .../bak/training_selflpips.py | 81 +++
 .../stateful_flow_matching/cm_sampling.py | 78 +++
 .../stateful_flow_matching/sampling.py | 103 ++++
 .../stateful_flow_matching/scheduling.py | 39 ++
 .../sharing_sampling.py | 149 ++++++
 .../stateful_flow_matching/training.py | 55 +++
 .../stateful_flow_matching/training_adv.py | 122 +++++
 .../training_distill_dino.py | 141 ++++++
 .../stateful_flow_matching/training_lpips.py | 71 +++
 .../training_lpips_lossweight.py | 74 +++
 .../stateful_flow_matching/training_repa.py | 157 ++++++
 .../training_repa_lpips.py | 170 +++++++
 src/lightning_data.py | 162 ++++++
 src/lightning_model.py | 123 +++++
 src/models/__init__.py | 0
 src/models/conditioner.py | 26 +
 src/models/denoiser/__init__.py | 0
 .../flatten_condit_encoder_catdecoder_fixt.py | 383 +++++++++++++++
 ...flatten_condit_encoder_unetdecoder_fixt.py | 447 +++++++++++++++++
 ...latten_condit_encoder_unetdecoder_fixt2.py | 448 +++++++++++++++++
 ...ten_condit_encoder_unetdecoder_woy_fixt.py | 464 ++++++++++++++++++
 src/models/denoiser/condit_dit.py | 274 +++++++++++
 .../denoiser/flatten_condit_catdit_fixt.py | 314 ++++++++++++
 .../denoiser/flatten_condit_conv_fixt.py | 340 +++++++++++++
 .../denoiser/flatten_condit_convnext_fixt.py | 339 +++++++++++++
 .../denoiser/flatten_condit_dit_fixt.py | 313 ++++++++++++
 .../denoiser/flatten_condit_dit_norm_fixt.py | 314 ++++++++++++
 .../flatten_condit_encoder_decoder_fixt.py
| 429 ++++++++++++++++ .../denoiser/flatten_condit_mlp_fixt.py | 334 +++++++++++++ .../flatten_condit_sdown2_dit_fixt.py | 321 ++++++++++++ src/models/denoiser/flatten_dit_fixt.py | 306 ++++++++++++ src/models/denoiser/flatten_dit_fixt_xvout.py | 311 ++++++++++++ .../flatten_sharepatch_condit_dit_fixt.py | 308 ++++++++++++ src/models/denoiser/flowdcn.py | 160 ++++++ src/models/encoder.py | 132 +++++ src/models/vae.py | 81 +++ src/ops/cuda_kernels/backward.cu | 346 +++++++++++++ src/ops/cuda_kernels/bak_forward.cu | 289 +++++++++++ src/ops/cuda_kernels/forward.cu | 309 ++++++++++++ src/ops/cuda_kernels/forward.py | 95 ++++ src/ops/cuda_kernels/function.py | 126 +++++ src/ops/cuda_kernels/setup.py | 59 +++ src/ops/triton_kernels/__init__.py | 0 src/ops/triton_kernels/backward.py | 124 +++++ src/ops/triton_kernels/forward.py | 94 ++++ src/ops/triton_kernels/function.py | 48 ++ src/plugins/__init__.py | 0 src/plugins/bd_env.py | 70 +++ src/utils/__init__.py | 0 src/utils/copy.py | 13 + src/utils/model_loader.py | 29 ++ src/utils/no_grad.py | 16 + src/utils/patch_bugs.py | 17 + tools/cache_imlatent3.py | 117 +++++ tools/cache_imlatent4.py | 123 +++++ tools/cat_images.py | 43 ++ tools/classifer_training.py | 353 +++++++++++++ tools/debug_env.sh | 4 + tools/dino_scale.py | 173 +++++++ tools/dino_scale2.py | 168 +++++++ tools/dp.py | 64 +++ tools/figures/base++.py | 64 +++ tools/figures/base.py | 57 +++ tools/figures/cfg.py | 32 ++ tools/figures/feat_vis.py | 42 ++ tools/figures/large++.py | 63 +++ tools/figures/log_snr.py | 18 + tools/figures/output/base++_FID.pdf | Bin 0 -> 17051 bytes tools/figures/output/base++_FID50K.pdf | Bin 0 -> 17760 bytes .../figures/output/base++_InceptionScore.pdf | Bin 0 -> 17533 bytes tools/figures/output/base++_Precision.pdf | Bin 0 -> 17753 bytes tools/figures/output/base++_Recall.pdf | Bin 0 -> 17311 bytes tools/figures/output/base_FID.pdf | Bin 0 -> 16301 bytes tools/figures/output/base_InceptionScore.pdf | Bin 0 -> 16380 bytes tools/figures/output/base_Precision.pdf | Bin 0 -> 16251 bytes tools/figures/output/base_Recall.pdf | Bin 0 -> 16433 bytes tools/figures/output/cfg.pdf | Bin 0 -> 21820 bytes tools/figures/output/large++_FID.pdf | Bin 0 -> 17440 bytes tools/figures/output/large++_FID50K.pdf | Bin 0 -> 18148 bytes .../figures/output/large++_InceptionScore.pdf | Bin 0 -> 17675 bytes tools/figures/output/large++_Precision.pdf | Bin 0 -> 17321 bytes tools/figures/output/large++_Recall.pdf | Bin 0 -> 17490 bytes tools/figures/output/logsnr.pdf | Bin 0 -> 15346 bytes tools/figures/output/mean_sim.png | Bin 0 -> 20988 bytes tools/figures/output/sota.pdf | Bin 0 -> 20836 bytes tools/figures/output/timeshift.pdf | Bin 0 -> 18730 bytes tools/figures/output/timeshift_fid.pdf | Bin 0 -> 16406 bytes tools/figures/sota.py | 95 ++++ tools/figures/timeshift.py | 26 + tools/figures/timeshift_fid.py | 29 ++ tools/fm_images.py | 21 + tools/mm.py | 23 + tools/sigmoid.py | 20 + tools/vae2dino.py | 173 +++++++ tools/vis_timeshift.py | 23 + 145 files changed, 14400 insertions(+) create mode 100644 .idea/.gitignore create mode 100644 .idea/DDT.iml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml create mode 100644 configs/repa_flatten_condit22_fixt_xl.yaml create mode 100644 configs/repa_flatten_condit22_fixt_xl512.yaml create mode 100644 configs/repa_flatten_dit_fixt_large.yaml create mode 100644 configs/repa_flatten_dit_fixt_xl.yaml create mode 100644 main.py create mode 100644 main.sh create 
mode 100644 requirements.txt create mode 100644 src/__init__.py create mode 100644 src/callbacks/__init__.py create mode 100644 src/callbacks/grad.py create mode 100644 src/callbacks/model_checkpoint.py create mode 100644 src/callbacks/save_images.py create mode 100644 src/callbacks/simple_ema.py create mode 100644 src/data/__init__.py create mode 100644 src/data/dataset/__init__.py create mode 100644 src/data/dataset/celeba.py create mode 100644 src/data/dataset/imagenet.py create mode 100644 src/data/dataset/metric_dataset.py create mode 100644 src/data/dataset/randn.py create mode 100644 src/data/var_training.py create mode 100644 src/diffusion/__init__.py create mode 100644 src/diffusion/base/guidance.py create mode 100644 src/diffusion/base/sampling.py create mode 100644 src/diffusion/base/scheduling.py create mode 100644 src/diffusion/base/training.py create mode 100644 src/diffusion/ddpm/ddim_sampling.py create mode 100644 src/diffusion/ddpm/scheduling.py create mode 100644 src/diffusion/ddpm/training.py create mode 100644 src/diffusion/ddpm/vp_sampling.py create mode 100644 src/diffusion/flow_matching/adam_sampling.py create mode 100644 src/diffusion/flow_matching/sampling.py create mode 100644 src/diffusion/flow_matching/scheduling.py create mode 100644 src/diffusion/flow_matching/training.py create mode 100644 src/diffusion/flow_matching/training_cos.py create mode 100644 src/diffusion/flow_matching/training_pyramid.py create mode 100644 src/diffusion/flow_matching/training_repa.py create mode 100644 src/diffusion/flow_matching/training_repa_mask.py create mode 100644 src/diffusion/pre_integral.py create mode 100644 src/diffusion/stateful_flow_matching/adam_sampling.py create mode 100644 src/diffusion/stateful_flow_matching/bak/training_adv.py create mode 100644 src/diffusion/stateful_flow_matching/bak/training_adv_x0.py create mode 100644 src/diffusion/stateful_flow_matching/bak/training_mask_repa.py create mode 100644 src/diffusion/stateful_flow_matching/bak/training_patch_adv.py create mode 100644 src/diffusion/stateful_flow_matching/bak/training_repa_jit.py create mode 100644 src/diffusion/stateful_flow_matching/bak/training_self_consistent.py create mode 100644 src/diffusion/stateful_flow_matching/bak/training_selflpips.py create mode 100644 src/diffusion/stateful_flow_matching/cm_sampling.py create mode 100644 src/diffusion/stateful_flow_matching/sampling.py create mode 100644 src/diffusion/stateful_flow_matching/scheduling.py create mode 100644 src/diffusion/stateful_flow_matching/sharing_sampling.py create mode 100644 src/diffusion/stateful_flow_matching/training.py create mode 100644 src/diffusion/stateful_flow_matching/training_adv.py create mode 100644 src/diffusion/stateful_flow_matching/training_distill_dino.py create mode 100644 src/diffusion/stateful_flow_matching/training_lpips.py create mode 100644 src/diffusion/stateful_flow_matching/training_lpips_lossweight.py create mode 100644 src/diffusion/stateful_flow_matching/training_repa.py create mode 100644 src/diffusion/stateful_flow_matching/training_repa_lpips.py create mode 100644 src/lightning_data.py create mode 100644 src/lightning_model.py create mode 100644 src/models/__init__.py create mode 100644 src/models/conditioner.py create mode 100644 src/models/denoiser/__init__.py create mode 100644 src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py create mode 100644 src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py create mode 100644 
src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py create mode 100644 src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py create mode 100644 src/models/denoiser/condit_dit.py create mode 100644 src/models/denoiser/flatten_condit_catdit_fixt.py create mode 100644 src/models/denoiser/flatten_condit_conv_fixt.py create mode 100644 src/models/denoiser/flatten_condit_convnext_fixt.py create mode 100644 src/models/denoiser/flatten_condit_dit_fixt.py create mode 100644 src/models/denoiser/flatten_condit_dit_norm_fixt.py create mode 100644 src/models/denoiser/flatten_condit_encoder_decoder_fixt.py create mode 100644 src/models/denoiser/flatten_condit_mlp_fixt.py create mode 100644 src/models/denoiser/flatten_condit_sdown2_dit_fixt.py create mode 100644 src/models/denoiser/flatten_dit_fixt.py create mode 100644 src/models/denoiser/flatten_dit_fixt_xvout.py create mode 100644 src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py create mode 100644 src/models/denoiser/flowdcn.py create mode 100644 src/models/encoder.py create mode 100644 src/models/vae.py create mode 100644 src/ops/cuda_kernels/backward.cu create mode 100644 src/ops/cuda_kernels/bak_forward.cu create mode 100644 src/ops/cuda_kernels/forward.cu create mode 100644 src/ops/cuda_kernels/forward.py create mode 100644 src/ops/cuda_kernels/function.py create mode 100644 src/ops/cuda_kernels/setup.py create mode 100644 src/ops/triton_kernels/__init__.py create mode 100644 src/ops/triton_kernels/backward.py create mode 100644 src/ops/triton_kernels/forward.py create mode 100644 src/ops/triton_kernels/function.py create mode 100644 src/plugins/__init__.py create mode 100644 src/plugins/bd_env.py create mode 100644 src/utils/__init__.py create mode 100644 src/utils/copy.py create mode 100644 src/utils/model_loader.py create mode 100644 src/utils/no_grad.py create mode 100644 src/utils/patch_bugs.py create mode 100644 tools/cache_imlatent3.py create mode 100644 tools/cache_imlatent4.py create mode 100644 tools/cat_images.py create mode 100644 tools/classifer_training.py create mode 100644 tools/debug_env.sh create mode 100644 tools/dino_scale.py create mode 100644 tools/dino_scale2.py create mode 100644 tools/dp.py create mode 100644 tools/figures/base++.py create mode 100644 tools/figures/base.py create mode 100644 tools/figures/cfg.py create mode 100644 tools/figures/feat_vis.py create mode 100644 tools/figures/large++.py create mode 100644 tools/figures/log_snr.py create mode 100644 tools/figures/output/base++_FID.pdf create mode 100644 tools/figures/output/base++_FID50K.pdf create mode 100644 tools/figures/output/base++_InceptionScore.pdf create mode 100644 tools/figures/output/base++_Precision.pdf create mode 100644 tools/figures/output/base++_Recall.pdf create mode 100644 tools/figures/output/base_FID.pdf create mode 100644 tools/figures/output/base_InceptionScore.pdf create mode 100644 tools/figures/output/base_Precision.pdf create mode 100644 tools/figures/output/base_Recall.pdf create mode 100644 tools/figures/output/cfg.pdf create mode 100644 tools/figures/output/large++_FID.pdf create mode 100644 tools/figures/output/large++_FID50K.pdf create mode 100644 tools/figures/output/large++_InceptionScore.pdf create mode 100644 tools/figures/output/large++_Precision.pdf create mode 100644 tools/figures/output/large++_Recall.pdf create mode 100644 tools/figures/output/logsnr.pdf create mode 100644 tools/figures/output/mean_sim.png create mode 100644 tools/figures/output/sota.pdf create mode 100644 
tools/figures/output/timeshift.pdf create mode 100644 tools/figures/output/timeshift_fid.pdf create mode 100644 tools/figures/sota.py create mode 100644 tools/figures/timeshift.py create mode 100644 tools/figures/timeshift_fid.py create mode 100644 tools/fm_images.py create mode 100644 tools/mm.py create mode 100644 tools/sigmoid.py create mode 100644 tools/vae2dino.py create mode 100644 tools/vis_timeshift.py diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/DDT.iml b/.idea/DDT.iml new file mode 100644 index 0000000..d0876a7 --- /dev/null +++ b/.idea/DDT.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..b19c6c8 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/configs/repa_flatten_condit22_fixt_xl.yaml b/configs/repa_flatten_condit22_fixt_xl.yaml new file mode 100644 index 0000000..6f3a6ce --- /dev/null +++ b/configs/repa_flatten_condit22_fixt_xl.yaml @@ -0,0 +1,108 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp repa_flatten_condit22_dit6_fixt_xl +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: bf16-mixed + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_flow + name: *exp + num_sanity_val_steps: 0 + max_steps: 4000000 + val_check_interval: 4000000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 10000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.vae.LatentVAE + init_args: + precompute: true + weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/ + denoiser: + class_path: src.models.denoiser.flatten_condit_dit_fixt.FlattenConDiT + init_args: + in_channels: 4 + patch_size: 2 + num_groups: 16 + hidden_size: &hidden_dim 1152 + num_blocks: 28 + num_cond_blocks: 22 + num_classes: 1000 + conditioner: + class_path: src.models.conditioner.LabelConditioner + init_args: + null_class: 1000 + diffusion_trainer: + class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + encoder_weight_path: dinov2_vitb14 + align_layer: 8 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler 
src.diffusion.stateful_flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler + init_args: + num_steps: 250 + guidance: 2.0 + timeshift: 1.5 + state_refresh_rate: 1 + guidance_interval_min: 0.3 + guidance_interval_max: 1.0 + scheduler: *scheduler + w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + last_step: 0.04 + step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-4 + betas: + - 0.9 + - 0.95 + weight_decay: 0.0 +data: + train_dataset: imagenet256 + train_root: /mnt/bn/wangshuai6/data/ImageNet/train + train_image_size: 256 + train_batch_size: 16 + eval_max_num_instances: 50000 + pred_batch_size: 64 + pred_num_workers: 4 + pred_seeds: null + pred_selected_classes: null + num_classes: 1000 + latent_shape: + - 4 + - 32 + - 32 \ No newline at end of file diff --git a/configs/repa_flatten_condit22_fixt_xl512.yaml b/configs/repa_flatten_condit22_fixt_xl512.yaml new file mode 100644 index 0000000..ce59338 --- /dev/null +++ b/configs/repa_flatten_condit22_fixt_xl512.yaml @@ -0,0 +1,108 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp res512_fromscratch_repa_flatten_condit22_dit6_fixt_xl +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: bf16-mixed + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_flow + name: *exp + num_sanity_val_steps: 0 + max_steps: 4000000 + val_check_interval: 4000000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 10000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.vae.LatentVAE + init_args: + precompute: true + weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/ + denoiser: + class_path: src.models.denoiser.flatten_condit_dit_fixt.FlattenConDiT + init_args: + in_channels: 4 + patch_size: 2 + num_groups: 16 + hidden_size: &hidden_dim 1152 + num_blocks: 28 + num_cond_blocks: 22 + num_classes: 1000 + conditioner: + class_path: src.models.conditioner.LabelConditioner + init_args: + null_class: 1000 + diffusion_trainer: + class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + encoder_weight_path: dinov2_vitb14 + align_layer: 8 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler + init_args: + num_steps: 250 + guidance: 3.0 + state_refresh_rate: 1 + guidance_interval_min: 0.3 + guidance_interval_max: 1.0 + timeshift: 1.0 + last_step: 0.04 + scheduler: *scheduler + w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler + guidance_fn: 
src.diffusion.base.guidance.simple_guidance_fn + step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-4 + betas: + - 0.9 + - 0.95 + weight_decay: 0.0 +data: + train_dataset: imagenet512 + train_root: /mnt/bn/wangshuai6/data/ImageNet/train + train_image_size: 512 + train_batch_size: 16 + eval_max_num_instances: 50000 + pred_batch_size: 32 + pred_num_workers: 4 + pred_seeds: null + pred_selected_classes: null + num_classes: 1000 + latent_shape: + - 4 + - 64 + - 64 \ No newline at end of file diff --git a/configs/repa_flatten_dit_fixt_large.yaml b/configs/repa_flatten_dit_fixt_large.yaml new file mode 100644 index 0000000..7dc9611 --- /dev/null +++ b/configs/repa_flatten_dit_fixt_large.yaml @@ -0,0 +1,99 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp repa_flatten_dit_fixt_large +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: bf16-mixed + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_flow + name: *exp + num_sanity_val_steps: 0 + max_steps: 400000 + val_check_interval: 100000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 10000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.vae.LatentVAE + init_args: + precompute: true + weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/ + denoiser: + class_path: src.models.denoiser.flatten_dit_fixt.FlattenDiT + init_args: + in_channels: 4 + patch_size: 2 + num_groups: 16 + hidden_size: &hidden_dim 1024 + num_blocks: 24 + num_classes: 1000 + conditioner: + class_path: src.models.conditioner.LabelConditioner + init_args: + null_class: 1000 + diffusion_trainer: + class_path: src.diffusion.flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + encoder_weight_path: dinov2_vitb14 + align_layer: 8 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.flow_matching.sampling.EulerSampler + init_args: + num_steps: 250 + guidance: 1.00 + scheduler: *scheduler + w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-4 + weight_decay: 0.0 +data: + train_dataset: imagenet256 + train_root: /mnt/bn/wangshuai6/data/ImageNet/train + train_image_size: 256 + train_batch_size: 32 + eval_max_num_instances: 50000 + pred_batch_size: 64 + pred_num_workers: 4 + pred_seeds: null + pred_selected_classes: null + num_classes: 1000 + latent_shape: + - 4 + - 32 + - 32 \ No newline at end of file diff --git 
a/configs/repa_flatten_dit_fixt_xl.yaml b/configs/repa_flatten_dit_fixt_xl.yaml new file mode 100644 index 0000000..2bc8606 --- /dev/null +++ b/configs/repa_flatten_dit_fixt_xl.yaml @@ -0,0 +1,99 @@ +# lightning.pytorch==2.4.0 +seed_everything: true +tags: + exp: &exp repa_flatten_dit_fixt_xl +torch_hub_dir: /mnt/bn/wangshuai6/torch_hub +huggingface_cache_dir: null +trainer: + default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: bf16-mixed + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: universal_flow + name: *exp + num_sanity_val_steps: 0 + max_steps: 400000 + val_check_interval: 100000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 10000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + plugins: + - src.plugins.bd_env.BDEnvironment +model: + vae: + class_path: src.models.vae.LatentVAE + init_args: + precompute: true + weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/ + denoiser: + class_path: src.models.denoiser.flatten_dit_fixt.FlattenDiT + init_args: + in_channels: 4 + patch_size: 2 + num_groups: 16 + hidden_size: &hidden_dim 1152 + num_blocks: 28 + num_classes: 1000 + conditioner: + class_path: src.models.conditioner.LabelConditioner + init_args: + null_class: 1000 + diffusion_trainer: + class_path: src.diffusion.flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + encoder_weight_path: dinov2_vitb14 + align_layer: 8 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.flow_matching.sampling.EulerSampler + init_args: + num_steps: 250 + guidance: 1.00 + scheduler: *scheduler + w_scheduler: src.diffusion.flow_matching.scheduling.LinearScheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + step_fn: src.diffusion.flow_matching.sampling.sde_preserve_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-4 + weight_decay: 0.0 +data: + train_dataset: imagenet256 + train_root: /mnt/bn/wangshuai6/data/ImageNet/train + train_image_size: 256 + train_batch_size: 32 + eval_max_num_instances: 50000 + pred_batch_size: 64 + pred_num_workers: 4 + pred_seeds: null + pred_selected_classes: null + num_classes: 1000 + latent_shape: + - 4 + - 32 + - 32 \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..85bf7cb --- /dev/null +++ b/main.py @@ -0,0 +1,86 @@ +import time +from typing import Any, Union + +import pylab as pl + +from src.utils.patch_bugs import * + +import os +import torch +from lightning import Trainer, LightningModule +from src.lightning_data import DataModule +from src.lightning_model import LightningModel +from lightning.pytorch.cli import LightningCLI, LightningArgumentParser, SaveConfigCallback + +import logging +logger = logging.getLogger("lightning.pytorch") +# log_path = os.path.join( f"log.txt") +# logger.addHandler(logging.FileHandler(log_path)) + +class ReWriteRootSaveConfigCallback(SaveConfigCallback): + def save_config(self, trainer: 
Trainer, pl_module: LightningModule, stage: str) -> None: + stamp = time.strftime('%y%m%d%H%M') + file_path = os.path.join(trainer.default_root_dir, f"config-{stage}-{stamp}.yaml") + self.parser.save( + self.config, file_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile + ) + + +class ReWriteRootDirCli(LightningCLI): + def before_instantiate_classes(self) -> None: + super().before_instantiate_classes() + config_trainer = self._get(self.config, "trainer", default={}) + + # predict path & logger check + if self.subcommand == "predict": + config_trainer.logger = None + + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + class TagsClass: + def __init__(self, exp:str): + ... + parser.add_class_arguments(TagsClass, nested_key="tags") + + def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None: + super().add_default_arguments_to_parser(parser) + parser.add_argument("--torch_hub_dir", type=str, default=None, help=("torch hub dir"),) + parser.add_argument("--huggingface_cache_dir", type=str, default=None, help=("huggingface hub dir"),) + + def instantiate_trainer(self, **kwargs: Any) -> Trainer: + config_trainer = self._get(self.config_init, "trainer", default={}) + default_root_dir = config_trainer.get("default_root_dir", None) + + if default_root_dir is None: + default_root_dir = os.path.join(os.getcwd(), "workdirs") + + dirname = "" + for v, k in self._get(self.config, "tags", default={}).items(): + dirname += f"{v}_{k}" + default_root_dir = os.path.join(default_root_dir, dirname) + is_resume = self._get(self.config_init, "ckpt_path", default=None) + if os.path.exists(default_root_dir) and "debug" not in default_root_dir: + if os.listdir(default_root_dir) and self.subcommand != "predict" and not is_resume: + raise FileExistsError(f"{default_root_dir} already exists") + + config_trainer.default_root_dir = default_root_dir + trainer = super().instantiate_trainer(**kwargs) + if trainer.is_global_zero: + os.makedirs(default_root_dir, exist_ok=True) + return trainer + + def instantiate_classes(self) -> None: + torch_hub_dir = self._get(self.config, "torch_hub_dir") + huggingface_cache_dir = self._get(self.config, "huggingface_cache_dir") + if huggingface_cache_dir is not None: + os.environ["HUGGINGFACE_HUB_CACHE"] = huggingface_cache_dir + if torch_hub_dir is not None: + os.environ["TORCH_HOME"] = torch_hub_dir + torch.hub.set_dir(torch_hub_dir) + super().instantiate_classes() + +if __name__ == "__main__": + + cli = ReWriteRootDirCli(LightningModel, DataModule, + auto_configure_optimizers=False, + save_config_callback=ReWriteRootSaveConfigCallback, + save_config_kwargs={"overwrite": True}) \ No newline at end of file diff --git a/main.sh b/main.sh new file mode 100644 index 0000000..39f803e --- /dev/null +++ b/main.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +export NCCL_HOSTID=${MY_POD_NAME} +export MASTER_ADDR=${ARNOLD_WORKER_0_HOST} +export MASTER_PORT=${ARNOLD_WORKER_0_PORT} +export NODE_RANK=${ARNOLD_ID} +export NUM_NODES=${ARNOLD_WORKER_NUM} + +python3 main.py fit -c $1 --trainer.num_nodes $NUM_NODES +# for pid in $(ps -ef | grep "yaml" | grep -v "grep" | awk '{print $2}'); do kill -9 $pid; done \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9061a84 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +lightning==2.5.0.post0 +omegaconf==2.3.0 +jsonargparse[signatures]>=4.27.7 \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file 
mode 100644 index 0000000..e69de29 diff --git a/src/callbacks/__init__.py b/src/callbacks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/callbacks/grad.py b/src/callbacks/grad.py new file mode 100644 index 0000000..f9155b6 --- /dev/null +++ b/src/callbacks/grad.py @@ -0,0 +1,22 @@ +import torch +import lightning.pytorch as pl +from lightning.pytorch.utilities import grad_norm +from torch.optim import Optimizer + +class GradientMonitor(pl.Callback): + """Logs the gradient norm""" + + def __init__(self, norm_type: int = 2): + norm_type = float(norm_type) + if norm_type <= 0: + raise ValueError(f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}") + self.norm_type = norm_type + + def on_before_optimizer_step( + self, trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + optimizer: Optimizer + ) -> None: + norms = grad_norm(pl_module, norm_type=self.norm_type) + max_grad = torch.tensor([v for k, v in norms.items() if k != f"grad_{self.norm_type}_norm_total"]).max() + pl_module.log_dict({'train/grad/max': max_grad, 'train/grad/total': norms[f"grad_{self.norm_type}_norm_total"]}) \ No newline at end of file diff --git a/src/callbacks/model_checkpoint.py b/src/callbacks/model_checkpoint.py new file mode 100644 index 0000000..019454e --- /dev/null +++ b/src/callbacks/model_checkpoint.py @@ -0,0 +1,25 @@ +import os.path +from typing import Optional, Dict, Any + +import lightning.pytorch as pl +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint +from soupsieve.util import lower + + +class CheckpointHook(ModelCheckpoint): + """Save checkpoint with only the incremental part of the model""" + def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: + self.dirpath = trainer.default_root_dir + self.exception_ckpt_path = os.path.join(self.dirpath, "on_exception.pt") + pl_module.strict_loading = False + + def on_save_checkpoint( + self, trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + checkpoint: Dict[str, Any] + ) -> None: + del checkpoint["callbacks"] + + # def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: + # if not "debug" in self.exception_ckpt_path: + # trainer.save_checkpoint(self.exception_ckpt_path) \ No newline at end of file diff --git a/src/callbacks/save_images.py b/src/callbacks/save_images.py new file mode 100644 index 0000000..c6cd32b --- /dev/null +++ b/src/callbacks/save_images.py @@ -0,0 +1,105 @@ +import lightning.pytorch as pl +from lightning.pytorch import Callback + + +import os.path +import numpy +from PIL import Image +from typing import Sequence, Any, Dict +from concurrent.futures import ThreadPoolExecutor + +from lightning.pytorch.utilities.types import STEP_OUTPUT +from lightning_utilities.core.rank_zero import rank_zero_info + +def process_fn(image, path): + Image.fromarray(image).save(path) + +class SaveImagesHook(Callback): + def __init__(self, save_dir="val", max_save_num=0, compressed=True): + self.save_dir = save_dir + self.max_save_num = max_save_num + self.compressed = compressed + + def save_start(self, target_dir): + self.target_dir = target_dir + self.executor_pool = ThreadPoolExecutor(max_workers=8) + if not os.path.exists(self.target_dir): + os.makedirs(self.target_dir, exist_ok=True) + else: + if os.listdir(target_dir) and "debug" not in str(target_dir): + raise FileExistsError(f'{self.target_dir} already exists and not empty!') + self.samples = [] + self._have_saved_num = 
0 + rank_zero_info(f"Save images to {self.target_dir}") + + def save_image(self, images, filenames): + images = images.permute(0, 2, 3, 1).cpu().numpy() + for sample, filename in zip(images, filenames): + if isinstance(filename, Sequence): + filename = filename[0] + path = f'{self.target_dir}/{filename}' + if self._have_saved_num >= self.max_save_num: + break + self.executor_pool.submit(process_fn, sample, path) + self._have_saved_num += 1 + + def process_batch( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + samples: STEP_OUTPUT, + batch: Any, + ) -> None: + b, c, h, w = samples.shape + xT, y, metadata = batch + all_samples = pl_module.all_gather(samples).view(-1, c, h, w) + self.save_image(samples, metadata) + if trainer.is_global_zero: + all_samples = all_samples.permute(0, 2, 3, 1).cpu().numpy() + self.samples.append(all_samples) + + def save_end(self): + if self.compressed and len(self.samples) > 0: + samples = numpy.concatenate(self.samples) + numpy.savez(f'{self.target_dir}/output.npz', arr_0=samples) + self.executor_pool.shutdown(wait=True) + self.samples = [] + self.target_dir = None + self._have_saved_num = 0 + self.executor_pool = None + + def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + target_dir = os.path.join(trainer.default_root_dir, self.save_dir, f"iter_{trainer.global_step}") + self.save_start(target_dir) + + def on_validation_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + return self.process_batch(trainer, pl_module, outputs, batch) + + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.save_end() + + def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + target_dir = os.path.join(trainer.default_root_dir, self.save_dir, "predict") + self.save_start(target_dir) + + def on_predict_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + samples: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + return self.process_batch(trainer, pl_module, samples, batch) + + def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.save_end() \ No newline at end of file diff --git a/src/callbacks/simple_ema.py b/src/callbacks/simple_ema.py new file mode 100644 index 0000000..28bf476 --- /dev/null +++ b/src/callbacks/simple_ema.py @@ -0,0 +1,79 @@ +from typing import Any, Dict + +import torch +import torch.nn as nn +import threading +import lightning.pytorch as pl +from lightning.pytorch import Callback +from lightning.pytorch.utilities.types import STEP_OUTPUT + +from src.utils.copy import swap_tensors + +class SimpleEMA(Callback): + def __init__(self, net:nn.Module, ema_net:nn.Module, + decay: float = 0.9999, + every_n_steps: int = 1, + eval_original_model:bool = False + ): + super().__init__() + self.decay = decay + self.every_n_steps = every_n_steps + self.eval_original_model = eval_original_model + self._stream = torch.cuda.Stream() + + self.net_params = list(net.parameters()) + self.ema_params = list(ema_net.parameters()) + + def swap_model(self): + for ema_p, p, in zip(self.ema_params, self.net_params): + swap_tensors(ema_p, p) + + def ema_step(self): + @torch.no_grad() + def ema_update(ema_model_tuple, current_model_tuple, decay): + torch._foreach_mul_(ema_model_tuple, decay) + 
torch._foreach_add_( + ema_model_tuple, current_model_tuple, alpha=(1.0 - decay), + ) + + if self._stream is not None: + self._stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._stream): + ema_update(self.ema_params, self.net_params, self.decay) + + + def on_train_batch_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int + ) -> None: + if trainer.global_step % self.every_n_steps == 0: + self.ema_step() + + def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self.eval_original_model: + self.swap_model() + + def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self.eval_original_model: + self.swap_model() + + def on_predict_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self.eval_original_model: + self.swap_model() + + def on_predict_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + if not self.eval_original_model: + self.swap_model() + + + def state_dict(self) -> Dict[str, Any]: + return { + "decay": self.decay, + "every_n_steps": self.every_n_steps, + "eval_original_model": self.eval_original_model, + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.decay = state_dict["decay"] + self.every_n_steps = state_dict["every_n_steps"] + self.eval_original_model = state_dict["eval_original_model"] + diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1 @@ + diff --git a/src/data/dataset/__init__.py b/src/data/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/data/dataset/celeba.py b/src/data/dataset/celeba.py new file mode 100644 index 0000000..30f5d3f --- /dev/null +++ b/src/data/dataset/celeba.py @@ -0,0 +1,11 @@ +from typing import Callable +from torchvision.datasets import CelebA + + +class LocalDataset(CelebA): + def __init__(self, root:str, ): + super(LocalDataset, self).__init__(root, "train") + + def __getitem__(self, idx): + data = super().__getitem__(idx) + return data \ No newline at end of file diff --git a/src/data/dataset/imagenet.py b/src/data/dataset/imagenet.py new file mode 100644 index 0000000..59d0547 --- /dev/null +++ b/src/data/dataset/imagenet.py @@ -0,0 +1,82 @@ +import torch +from PIL import Image +from torchvision.datasets import ImageFolder +from torchvision.transforms.functional import to_tensor +from torchvision.transforms import Normalize + +from src.data.dataset.metric_dataset import CenterCrop + +class LocalCachedDataset(ImageFolder): + def __init__(self, root, resolution=256): + super().__init__(root) + self.transform = CenterCrop(resolution) + self.cache_root = None + + def load_latent(self, latent_path): + pk_data = torch.load(latent_path) + mean = pk_data['mean'].to(torch.float32) + logvar = pk_data['logvar'].to(torch.float32) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + latent = mean + torch.randn_like(mean) * std + return latent + + def __getitem__(self, idx: int): + image_path, target = self.samples[idx] + latent_path = image_path.replace(self.root, self.cache_root) + ".pt" + + raw_image = Image.open(image_path).convert('RGB') + raw_image = self.transform(raw_image) + raw_image = to_tensor(raw_image) + if self.cache_root is not None: + latent = self.load_latent(latent_path) + else: + latent = 
raw_image + return raw_image, latent, target + +class ImageNet256(LocalCachedDataset): + def __init__(self, root, ): + super().__init__(root, 256) + self.cache_root = root + "_256_latent" + +class ImageNet512(LocalCachedDataset): + def __init__(self, root, ): + super().__init__(root, 512) + self.cache_root = root + "_512_latent" + +class PixImageNet(ImageFolder): + def __init__(self, root, resolution=256): + super().__init__(root) + self.transform = CenterCrop(resolution) + self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + def __getitem__(self, idx: int): + image_path, target = self.samples[idx] + raw_image = Image.open(image_path).convert('RGB') + raw_image = self.transform(raw_image) + raw_image = to_tensor(raw_image) + + normalized_image = self.normalize(raw_image) + return raw_image, normalized_image, target + +class PixImageNet64(PixImageNet): + def __init__(self, root, ): + super().__init__(root, 64) + +class PixImageNet128(PixImageNet): + def __init__(self, root, ): + super().__init__(root, 128) + + +class PixImageNet256(PixImageNet): + def __init__(self, root, ): + super().__init__(root, 256) + +class PixImageNet512(PixImageNet): + def __init__(self, root, ): + super().__init__(root, 512) + + + + + diff --git a/src/data/dataset/metric_dataset.py b/src/data/dataset/metric_dataset.py new file mode 100644 index 0000000..cbe7d66 --- /dev/null +++ b/src/data/dataset/metric_dataset.py @@ -0,0 +1,82 @@ +import pathlib + +import torch +import random +import numpy as np +from torchvision.io.image import read_image +import torchvision.transforms as tvtf +from torch.utils.data import Dataset + +class CenterCrop: + def __init__(self, size): + self.size = size + def __call__(self, image): + def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. 
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + return center_crop_arr(image, self.size) + + +from PIL import Image +IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +def test_collate(batch): + return torch.stack(batch) + +class ImageDataset(Dataset): + def __init__(self, root, image_size=(224, 224)): + self.root = pathlib.Path(root) + images = [] + for ext in IMG_EXTENSIONS: + images.extend(self.root.rglob(ext)) + random.shuffle(images) + self.images = list(map(lambda x: str(x), images)) + self.transform = tvtf.Compose( + [ + CenterCrop(image_size[0]), + tvtf.ToTensor(), + tvtf.Lambda(lambda x: (x*255).to(torch.uint8)), + tvtf.Lambda(lambda x: x.expand(3, -1, -1)) + ] + ) + self.size = image_size + + def __getitem__(self, idx): + try: + image = Image.open(self.images[idx]) + image = self.transform(image) + except Exception as e: + print(self.images[idx]) + image = torch.zeros(3, self.size[0], self.size[1], dtype=torch.uint8) + + # print(image) + metadata = dict( + path = self.images[idx], + root = self.root, + ) + return image #, metadata + + def __len__(self): + return len(self.images) \ No newline at end of file diff --git a/src/data/dataset/randn.py b/src/data/dataset/randn.py new file mode 100644 index 0000000..f9ec772 --- /dev/null +++ b/src/data/dataset/randn.py @@ -0,0 +1,41 @@ +import os.path +import random + +import torch +from torch.utils.data import Dataset + + + +class RandomNDataset(Dataset): + def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, selected_classes:list=None, seeds=None, max_num_instances=50000, ): + self.selected_classes = selected_classes + if selected_classes is not None: + num_classes = len(selected_classes) + max_num_instances = 10*num_classes + self.num_classes = num_classes + self.seeds = seeds + if seeds is not None: + self.max_num_instances = len(seeds)*num_classes + self.num_seeds = len(seeds) + else: + self.num_seeds = (max_num_instances + num_classes - 1) // num_classes + self.max_num_instances = self.num_seeds*num_classes + + self.latent_shape = latent_shape + + + def __getitem__(self, idx): + label = idx // self.num_seeds + if self.selected_classes: + label = self.selected_classes[label] + seed = random.randint(0, 1<<31) #idx % self.num_seeds + if self.seeds is not None: + seed = self.seeds[idx % self.num_seeds] + + # cls_dir = os.path.join(self.root, f"{label}") + filename = f"{label}_{seed}.png", + generator = torch.Generator().manual_seed(seed) + latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32) + return latent, label, filename + def __len__(self): + return self.max_num_instances \ No newline at end of file diff --git a/src/data/var_training.py b/src/data/var_training.py new file mode 100644 index 0000000..de7fb74 --- /dev/null +++ b/src/data/var_training.py @@ -0,0 +1,145 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling 
import BaseScheduler +import concurrent.futures +from concurrent.futures import ProcessPoolExecutor +from typing import List +from PIL import Image +import torch +import random +import numpy as np +import copy +import torchvision.transforms.functional as tvtf +from src.models.vae import uint82fp + + +def center_crop_arr(pil_image, width, height): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while pil_image.size[0] >= 2 * width and pil_image.size[1] >= 2 * height: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = max(width / pil_image.size[0], height / pil_image.size[1]) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + arr = np.array(pil_image) + crop_y = random.randint(0, (arr.shape[0] - height)) + crop_x = random.randint(0, (arr.shape[1] - width)) + return Image.fromarray(arr[crop_y: crop_y + height, crop_x: crop_x + width]) + +def process_fn(width, height, data, hflip=0.5): + image, label = data + if random.uniform(0, 1) > hflip: # hflip + image = tvtf.hflip(image) + image = center_crop_arr(image, width, height) # crop + image = np.array(image).transpose(2, 0, 1) + return image, label + +class VARCandidate: + def __init__(self, aspect_ratio, width, height, buffer, max_buffer_size=1024): + self.aspect_ratio = aspect_ratio + self.width = int(width) + self.height = int(height) + self.buffer = buffer + self.max_buffer_size = max_buffer_size + + def add_sample(self, data): + self.buffer.append(data) + self.buffer = self.buffer[-self.max_buffer_size:] + + def ready(self, batch_size): + return len(self.buffer) >= batch_size + + def get_batch(self, batch_size): + batch = self.buffer[:batch_size] + self.buffer = self.buffer[batch_size:] + batch = [copy.deepcopy(b.result()) for b in batch] + x, y = zip(*batch) + x = torch.stack([torch.from_numpy(im).cuda() for im in x], dim=0) + x = list(map(uint82fp, x)) + return x, y + +class VARTransformEngine: + def __init__(self, + base_image_size, + num_aspect_ratios, + min_aspect_ratio, + max_aspect_ratio, + num_workers = 8, + ): + self.base_image_size = base_image_size + self.num_aspect_ratios = num_aspect_ratios + self.min_aspect_ratio = min_aspect_ratio + self.max_aspect_ratio = max_aspect_ratio + self.aspect_ratios = np.linspace(self.min_aspect_ratio, self.max_aspect_ratio, self.num_aspect_ratios) + self.aspect_ratios = self.aspect_ratios.tolist() + self.candidates_pool = [] + for i in range(self.num_aspect_ratios): + candidate = VARCandidate( + aspect_ratio=self.aspect_ratios[i], + width=int(self.base_image_size * self.aspect_ratios[i] ** 0.5 // 16 * 16), + height=int(self.base_image_size * self.aspect_ratios[i] ** -0.5 // 16 * 16), + buffer=[], + max_buffer_size=1024 + ) + self.candidates_pool.append(candidate) + self.default_candidate = VARCandidate( + aspect_ratio=1.0, + width=self.base_image_size, + height=self.base_image_size, + buffer=[], + max_buffer_size=1024, + ) + self.executor_pool = ProcessPoolExecutor(max_workers=num_workers) + self._prefill_count = 100 + + def find_candidate(self, data): + image = data[0] + aspect_ratio = image.size[0] / image.size[1] + min_distance = 1000000 + min_candidate = None + for candidate in self.candidates_pool: + dis = abs(aspect_ratio - candidate.aspect_ratio) + if dis < min_distance: + min_distance = dis + min_candidate = candidate + return 
min_candidate + + + def __call__(self, batch_data): + self._prefill_count -= 1 + if isinstance(batch_data[0], torch.Tensor): + batch_data[0] = batch_data[0].unbind(0) + + batch_data = list(zip(*batch_data)) + for data in batch_data: + candidate = self.find_candidate(data) + future = self.executor_pool.submit(process_fn, candidate.width, candidate.height, data) + candidate.add_sample(future) + if self._prefill_count >= 0: + future = self.executor_pool.submit(process_fn, + self.default_candidate.width, + self.default_candidate.height, + data) + self.default_candidate.add_sample(future) + + batch_size = len(batch_data) + random.shuffle(self.candidates_pool) + for candidate in self.candidates_pool: + if candidate.ready(batch_size=batch_size): + return candidate.get_batch(batch_size=batch_size) + + # fallback to default 256 + for data in batch_data: + future = self.executor_pool.submit(process_fn, + self.default_candidate.width, + self.default_candidate.height, + data) + self.default_candidate.add_sample(future) + return self.default_candidate.get_batch(batch_size=batch_size) \ No newline at end of file diff --git a/src/diffusion/__init__.py b/src/diffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/diffusion/base/guidance.py b/src/diffusion/base/guidance.py new file mode 100644 index 0000000..07b4754 --- /dev/null +++ b/src/diffusion/base/guidance.py @@ -0,0 +1,60 @@ +import torch + +def simple_guidance_fn(out, cfg): + uncondition, condtion = out.chunk(2, dim=0) + out = uncondition + cfg * (condtion - uncondition) + return out + +def c3_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condtion = out.chunk(2, dim=0) + out = condtion + out[:, :3] = uncondition[:, :3] + cfg * (condtion[:, :3] - uncondition[:, :3]) + return out + +def c4_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def c4_p05_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.05 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def c4_p10_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.10 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def c4_p15_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.15 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def c4_p20_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? 
+ uncondition, condition = out.chunk(2, dim=0) + out = condition + out[:, :4] = uncondition[:, :4] + cfg * (condition[:, :4] - uncondition[:, :4]) + out[:, 4:] = uncondition[:, 4:] + 1.20 * (condition[:, 4:] - uncondition[:, 4:]) + return out + +def p4_guidance_fn(out, cfg): + # guidance function in DiT/SiT, seems like a bug not a feature? + uncondition, condtion = out.chunk(2, dim=0) + out = condtion + out[:, 4:] = uncondition[:, 4:] + cfg * (condtion[:, 4:] - uncondition[:, 4:]) + return out diff --git a/src/diffusion/base/sampling.py b/src/diffusion/base/sampling.py new file mode 100644 index 0000000..d8f9776 --- /dev/null +++ b/src/diffusion/base/sampling.py @@ -0,0 +1,31 @@ +from typing import Union, List + +import torch +import torch.nn as nn +from typing import Callable +from src.diffusion.base.scheduling import BaseScheduler + +class BaseSampler(nn.Module): + def __init__(self, + scheduler: BaseScheduler = None, + guidance_fn: Callable = None, + num_steps: int = 250, + guidance: Union[float, List[float]] = 1.0, + *args, + **kwargs + ): + super(BaseSampler, self).__init__() + self.num_steps = num_steps + self.guidance = guidance + self.guidance_fn = guidance_fn + self.scheduler = scheduler + + + def _impl_sampling(self, net, noise, condition, uncondition): + raise NotImplementedError + + def __call__(self, net, noise, condition, uncondition): + denoised = self._impl_sampling(net, noise, condition, uncondition) + return denoised + + diff --git a/src/diffusion/base/scheduling.py b/src/diffusion/base/scheduling.py new file mode 100644 index 0000000..05c7fb1 --- /dev/null +++ b/src/diffusion/base/scheduling.py @@ -0,0 +1,32 @@ +import torch +from torch import Tensor + +class BaseScheduler: + def alpha(self, t) -> Tensor: + ... + def sigma(self, t) -> Tensor: + ... + + def dalpha(self, t) -> Tensor: + ... + def dsigma(self, t) -> Tensor: + ... 
+ + def dalpha_over_alpha(self, t) -> Tensor: + return self.dalpha(t) / self.alpha(t) + + def dsigma_mul_sigma(self, t) -> Tensor: + return self.dsigma(t)*self.sigma(t) + + def drift_coefficient(self, t): + alpha, sigma = self.alpha(t), self.sigma(t) + dalpha, dsigma = self.dalpha(t), self.dsigma(t) + return dalpha/(alpha + 1e-6) + + def diffuse_coefficient(self, t): + alpha, sigma = self.alpha(t), self.sigma(t) + dalpha, dsigma = self.dalpha(t), self.dsigma(t) + return dsigma*sigma - dalpha/(alpha + 1e-6)*sigma**2 + + def w(self, t): + return self.sigma(t) diff --git a/src/diffusion/base/training.py b/src/diffusion/base/training.py new file mode 100644 index 0000000..8f6d0e0 --- /dev/null +++ b/src/diffusion/base/training.py @@ -0,0 +1,29 @@ +import time + +import torch +import torch.nn as nn + +class BaseTrainer(nn.Module): + def __init__(self, + null_condition_p=0.1, + log_var=False, + ): + super(BaseTrainer, self).__init__() + self.null_condition_p = null_condition_p + self.log_var = log_var + + def preproprocess(self, raw_iamges, x, condition, uncondition): + bsz = x.shape[0] + if self.null_condition_p > 0: + mask = torch.rand((bsz), device=condition.device) < self.null_condition_p + mask = mask.expand_as(condition) + condition[mask] = uncondition[mask] + return raw_iamges, x, condition + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + raise NotImplementedError + + def __call__(self, net, ema_net, raw_images, x, condition, uncondition): + raw_images, x, condition = self.preproprocess(raw_images, x, condition, uncondition) + return self._impl_trainstep(net, ema_net, raw_images, x, condition) + diff --git a/src/diffusion/ddpm/ddim_sampling.py b/src/diffusion/ddpm/ddim_sampling.py new file mode 100644 index 0000000..0db2a1d --- /dev/null +++ b/src/diffusion/ddpm/ddim_sampling.py @@ -0,0 +1,40 @@ +import torch +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + +import logging +logger = logging.getLogger(__name__) + +class DDIMSampler(BaseSampler): + def __init__( + self, + train_num_steps=1000, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.train_num_steps = train_num_steps + assert self.scheduler is not None + + def _impl_sampling(self, net, noise, condition, uncondition): + batch_size = noise.shape[0] + steps = torch.linspace(0.0, self.train_num_steps-1, self.num_steps, device=noise.device) + steps = torch.flip(steps, dims=[0]) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = x0 = noise + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + t_cur = t_cur.repeat(batch_size) + t_next = t_next.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + alpha = self.scheduler.alpha(t_cur) + sigma_next = self.scheduler.sigma(t_next) + alpha_next = self.scheduler.alpha(t_next) + cfg_x = torch.cat([x, x], dim=0) + t = t_cur.repeat(2) + out = net(cfg_x, t, cfg_condition) + out = self.guidance_fn(out, self.guidance) + x0 = (x - sigma * out) / alpha + x = alpha_next * x0 + sigma_next * out + return x0 \ No newline at end of file diff --git a/src/diffusion/ddpm/scheduling.py b/src/diffusion/ddpm/scheduling.py new file mode 100644 index 0000000..aff1523 --- /dev/null +++ b/src/diffusion/ddpm/scheduling.py @@ -0,0 +1,102 @@ +import math +import torch +from src.diffusion.base.scheduling import * + + +class DDPMScheduler(BaseScheduler): + def __init__( + self, + beta_min=0.0001, + beta_max=0.02, + num_steps=1000, + ): + super().__init__() + self.beta_min = beta_min + 
self.beta_max = beta_max + self.num_steps = num_steps + + self.betas_table = torch.linspace(self.beta_min, self.beta_max, self.num_steps, device="cuda") + self.alphas_table = torch.cumprod(1-self.betas_table, dim=0) + self.sigmas_table = 1-self.alphas_table + + + def beta(self, t) -> Tensor: + t = t.to(torch.long) + return self.betas_table[t].view(-1, 1, 1, 1) + + def alpha(self, t) -> Tensor: + t = t.to(torch.long) + return self.alphas_table[t].view(-1, 1, 1, 1)**0.5 + + def sigma(self, t) -> Tensor: + t = t.to(torch.long) + return self.sigmas_table[t].view(-1, 1, 1, 1)**0.5 + + def dsigma(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def dalpha_over_alpha(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dsigma_mul_sigma(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dalpha(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def drift_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def diffuse_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def w(self, t): + raise NotImplementedError("wrong usage") + + +class VPScheduler(BaseScheduler): + def __init__( + self, + beta_min=0.1, + beta_max=20, + ): + super().__init__() + self.beta_min = beta_min + self.beta_d = beta_max - beta_min + def beta(self, t) -> Tensor: + t = torch.clamp(t, min=1e-3, max=1) + return (self.beta_min + (self.beta_d * t)).view(-1, 1, 1, 1) + + def sigma(self, t) -> Tensor: + t = torch.clamp(t, min=1e-3, max=1) + inter_beta:Tensor = 0.5*self.beta_d*t**2 + self.beta_min* t + return (1-torch.exp_(-inter_beta)).sqrt().view(-1, 1, 1, 1) + + def dsigma(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def dalpha_over_alpha(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dsigma_mul_sigma(self, t) ->Tensor: + raise NotImplementedError("wrong usage") + + def dalpha(self, t) -> Tensor: + raise NotImplementedError("wrong usage") + + def alpha(self, t) -> Tensor: + t = torch.clamp(t, min=1e-3, max=1) + inter_beta: Tensor = 0.5 * self.beta_d * t ** 2 + self.beta_min * t + return torch.exp(-0.5*inter_beta).view(-1, 1, 1, 1) + + def drift_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def diffuse_coefficient(self, t): + raise NotImplementedError("wrong usage") + + def w(self, t): + return self.diffuse_coefficient(t) + + + diff --git a/src/diffusion/ddpm/training.py b/src/diffusion/ddpm/training.py new file mode 100644 index 0000000..3e0d0ec --- /dev/null +++ b/src/diffusion/ddpm/training.py @@ -0,0 +1,83 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class VPTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + train_max_t=1000, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.train_max_t = train_max_t + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, 
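# Standalone sketch checking the variance-preserving property of the two schedules
# above: both the discrete DDPM table (cumprod of 1 - beta) and the continuous VP
# schedule satisfy alpha(t)**2 + sigma(t)**2 == 1. Default hyperparameters from the
# classes above are reused here.
import torch

betas = torch.linspace(1e-4, 0.02, 1000)
alpha_bar = torch.cumprod(1 - betas, dim=0)
assert torch.allclose(alpha_bar.sqrt() ** 2 + (1 - alpha_bar).sqrt() ** 2,
                      torch.ones_like(alpha_bar))

beta_min, beta_d = 0.1, 20 - 0.1
t = torch.rand(8).clamp(min=1e-3)
inter_beta = 0.5 * beta_d * t ** 2 + beta_min * t
alpha_c = torch.exp(-0.5 * inter_beta)
sigma_c = (1 - torch.exp(-inter_beta)).sqrt()
assert torch.allclose(alpha_c ** 2 + sigma_c ** 2, torch.ones_like(t), atol=1e-6)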
x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + sigma = self.scheduler.sigma(t) + x_t = alpha * x + noise * sigma + out = net(x_t, t*self.train_max_t, y) + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - noise)**2 + + out = dict( + loss=loss.mean(), + ) + return out + + +class DDPMTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn: Callable = constant, + train_max_t=1000, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.train_max_t = train_max_t + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + t = torch.randint(0, self.train_max_t, (batch_size,)) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + sigma = self.scheduler.sigma(t) + x_t = alpha * x + noise * sigma + out = net(x_t, t, y) + weight = self.loss_weight_fn(alpha, sigma) + loss = weight * (out - noise) ** 2 + + out = dict( + loss=loss.mean(), + ) + return out \ No newline at end of file diff --git a/src/diffusion/ddpm/vp_sampling.py b/src/diffusion/ddpm/vp_sampling.py new file mode 100644 index 0000000..250b32d --- /dev/null +++ b/src/diffusion/ddpm/vp_sampling.py @@ -0,0 +1,59 @@ +import torch + +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * +from typing import Callable + +def ode_step_fn(x, eps, beta, sigma, dt): + return x + (-0.5*beta*x + 0.5*eps*beta/sigma)*dt + +def sde_step_fn(x, eps, beta, sigma, dt): + return x + (-0.5*beta*x + eps*beta/sigma)*dt + torch.sqrt(dt.abs()*beta)*torch.randn_like(x) + +import logging +logger = logging.getLogger(__name__) + +class VPEulerSampler(BaseSampler): + def __init__( + self, + train_max_t=1000, + guidance_fn: Callable = None, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.guidance_fn = guidance_fn + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.train_max_t = train_max_t + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + assert self.last_step > 0.0 + assert self.scheduler is not None + + def _impl_sampling(self, net, noise, condition, uncondition): + batch_size = noise.shape[0] + steps = torch.linspace(1.0, self.last_step, self.num_steps, device=noise.device) + steps = torch.cat([steps, torch.tensor([0.0], device=noise.device)], dim=0) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + beta = self.scheduler.beta(t_cur) + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + out = net(cfg_x, cfg_t*self.train_max_t, cfg_condition) + eps = self.guidance_fn(out, self.guidance) + if i < self.num_steps -1 : + x0 = self.last_step_fn(x, eps, beta, sigma, -t_cur[0]) + x = self.step_fn(x, eps, beta, sigma, dt) + else: + x = x0 = self.last_step_fn(x, eps, beta, sigma, -self.last_step) + return x \ No newline at end of file diff --git a/src/diffusion/flow_matching/adam_sampling.py b/src/diffusion/flow_matching/adam_sampling.py new file mode 100644 index 0000000..15d0c78 --- /dev/null +++ 
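# Minimal epsilon-prediction training-step sketch in the spirit of DDPMTrainer
# above. The one-layer conv "denoiser" is an illustrative stand-in and ignores
# (t, y), unlike the real net(x_t, t, y) interface used by the trainer.
import torch
import torch.nn as nn

net = nn.Conv2d(4, 4, 3, padding=1)
opt = torch.optim.AdamW(net.parameters(), lr=1e-4)
betas = torch.linspace(1e-4, 0.02, 1000)
alpha_bar = torch.cumprod(1 - betas, dim=0)

x = torch.randn(2, 4, 8, 8)                    # clean latents
t = torch.randint(0, 1000, (2,))
alpha = alpha_bar[t].view(-1, 1, 1, 1).sqrt()
sigma = (1 - alpha_bar[t]).view(-1, 1, 1, 1).sqrt()
noise = torch.randn_like(x)
x_t = alpha * x + sigma * noise
loss = ((net(x_t) - noise) ** 2).mean()        # regress the added noise
loss.backward()
opt.step()
opt.zero_grad()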
b/src/diffusion/flow_matching/adam_sampling.py @@ -0,0 +1,107 @@ +import math +from src.diffusion.base.sampling import * +from src.diffusion.base.scheduling import * +from src.diffusion.pre_integral import * + +from typing import Callable, List, Tuple + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + +def t2snr(t): + if isinstance(t, torch.Tensor): + return (t.clip(min=1e-8)/(1-t + 1e-8)) + if isinstance(t, List) or isinstance(t, Tuple): + return [t2snr(t) for t in t] + t = max(t, 1e-8) + return (t/(1-t + 1e-8)) + +def t2logsnr(t): + if isinstance(t, torch.Tensor): + return torch.log(t.clip(min=1e-3)/(1-t + 1e-3)) + if isinstance(t, List) or isinstance(t, Tuple): + return [t2logsnr(t) for t in t] + t = max(t, 1e-3) + return math.log(t/(1-t + 1e-3)) + +def t2isnr(t): + return 1/t2snr(t) + +def nop(t): + return t + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +import logging +logger = logging.getLogger(__name__) + +class AdamLMSampler(BaseSampler): + def __init__( + self, + order: int = 2, + timeshift: float = 1.0, + lms_transform_fn: Callable = nop, + w_scheduler: BaseScheduler = None, + step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.w_scheduler = w_scheduler + + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + self.order = order + self.lms_transform_fn = lms_transform_fn + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, timeshift) + self.timedeltas = timesteps[1:] - self.timesteps[:-1] + self._reparameterize_coeffs() + + def _reparameterize_coeffs(self): + solver_coeffs = [[] for _ in range(self.num_steps)] + for i in range(0, self.num_steps): + pre_vs = [1.0, ]*(i+1) + pre_ts = self.lms_transform_fn(self.timesteps[:i+1]) + int_t_start = self.lms_transform_fn(self.timesteps[i]) + int_t_end = self.lms_transform_fn(self.timesteps[i+1]) + + order_annealing = self.order #self.num_steps - i + order = min(self.order, i + 1, order_annealing) + + _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end) + solver_coeffs[i] = coeffs + self.solver_coeffs = solver_coeffs + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = x0 = noise + pred_trajectory = [] + t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype) + timedeltas = self.timedeltas + solver_coeffs = self.solver_coeffs + for i in range(self.num_steps): + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + out = net(cfg_x, cfg_t, cfg_condition) + out = self.guidance_fn(out, self.guidances[i]) + pred_trajectory.append(out) + out = torch.zeros_like(out) + order = len(self.solver_coeffs[i]) + for j in range(order): + out += solver_coeffs[i][j] * pred_trajectory[-order:][j] + v = out + dt = timedeltas[i] + x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0) + x = self.step_fn(x, v, dt, s=0, w=0) + t_cur += dt + return x \ No newline at end of file diff --git a/src/diffusion/flow_matching/sampling.py b/src/diffusion/flow_matching/sampling.py new file mode 100644 index 0000000..62bdd8b --- /dev/null +++ b/src/diffusion/flow_matching/sampling.py @@ -0,0 +1,179 @@ +import torch + +from src.diffusion.base.guidance import * +from src.diffusion.base.scheduling 
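# Sketch of the timestep respacing used by the samplers above: shift_respace_fn
# with shift > 1 squeezes a uniform grid toward t = 0 (the noise end in this
# codebase, where alpha(t) = t), so more, smaller steps are spent early in the
# trajectory. Printed values are approximate.
import torch

def shift_respace_fn(t, shift=3.0):
    return t / (t + (1 - t) * shift)

uniform = torch.linspace(0.0, 1.0, 6)
print(uniform.tolist())                        # [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
print(shift_respace_fn(uniform).tolist())      # ~[0.0, 0.08, 0.18, 0.33, 0.57, 1.0]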
import * +from src.diffusion.base.sampling import * + +from typing import Callable + + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + +def sde_mean_step_fn(x, v, dt, s, w): + return x + v * dt + s * w * dt + +def sde_step_fn(x, v, dt, s, w): + return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x) + +def sde_preserve_step_fn(x, v, dt, s, w): + return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x) + + +import logging +logger = logging.getLogger(__name__) + +class EulerSampler(BaseSampler): + def __init__( + self, + w_scheduler: BaseScheduler = None, + timeshift=1.0, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.w_scheduler = w_scheduler + self.timeshift = timeshift + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + assert self.last_step > 0.0 + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + if self.w_scheduler is not None: + if self.step_fn == ode_step_fn: + logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur) + dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur) + if self.w_scheduler: + w = self.w_scheduler.w(t_cur) + else: + w = 0.0 + + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + out = net(cfg_x, cfg_t, cfg_condition) + out = self.guidance_fn(out, self.guidance) + v = out + s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma) + if i < self.num_steps -1 : + x = self.step_fn(x, v, dt, s=s, w=w) + else: + x = self.last_step_fn(x, v, dt, s=s, w=w) + return x + + +class HeunSampler(BaseSampler): + def __init__( + self, + scheduler: BaseScheduler = None, + w_scheduler: BaseScheduler = None, + exact_henu=False, + timeshift=1.0, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.scheduler = scheduler + self.exact_henu = exact_henu + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.w_scheduler = w_scheduler + self.timeshift = timeshift + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + assert self.last_step > 0.0 + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + if 
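# Sketch of the duplicated-batch classifier-free guidance pattern the samplers
# above rely on. The actual guidance_fn lives in src/diffusion/base/guidance.py
# (not shown in this hunk); the usual "uncond + w * (cond - uncond)" combination
# is assumed here, and the random tensor stands in for net(cfg_x, cfg_t, cfg_condition).
import torch

def cfg_combine(out, guidance):
    uncond, cond = out.chunk(2, dim=0)
    return uncond + guidance * (cond - uncond)

x = torch.randn(4, 4, 8, 8)
condition, uncondition = torch.randn(4, 16), torch.zeros(4, 16)
cfg_x = torch.cat([x, x], dim=0)               # both branches in one forward pass
cfg_condition = torch.cat([uncondition, condition], dim=0)
out = torch.randn_like(cfg_x)                  # stand-in for the network output
v = cfg_combine(out, guidance=2.0)
x = x + v * 0.05                               # ode_step_fn with dt = 0.05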
self.w_scheduler is not None: + if self.step_fn == ode_step_fn: + logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Henu sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + v_hat, s_hat = 0.0, 0.0 + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + alpha_over_dalpha = 1/self.scheduler.dalpha_over_alpha(t_cur) + dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur) + t_hat = t_next + t_hat = t_hat.repeat(batch_size) + sigma_hat = self.scheduler.sigma(t_hat) + alpha_over_dalpha_hat = 1 / self.scheduler.dalpha_over_alpha(t_hat) + dsigma_mul_sigma_hat = self.scheduler.dsigma_mul_sigma(t_hat) + + if self.w_scheduler: + w = self.w_scheduler.w(t_cur) + else: + w = 0.0 + if i == 0 or self.exact_henu: + cfg_x = torch.cat([x, x], dim=0) + cfg_t_cur = t_cur.repeat(2) + out = net(cfg_x, cfg_t_cur, cfg_condition) + out = self.guidance_fn(out, self.guidance) + v = out + s = ((alpha_over_dalpha)*v - x)/(sigma**2 - (alpha_over_dalpha)*dsigma_mul_sigma) + else: + v = v_hat + s = s_hat + x_hat = self.step_fn(x, v, dt, s=s, w=w) + # henu correct + if i < self.num_steps -1: + cfg_x_hat = torch.cat([x_hat, x_hat], dim=0) + cfg_t_hat = t_hat.repeat(2) + out = net(cfg_x_hat, cfg_t_hat, cfg_condition) + out = self.guidance_fn(out, self.guidance) + v_hat = out + s_hat = ((alpha_over_dalpha_hat)* v_hat - x_hat) / (sigma_hat ** 2 - (alpha_over_dalpha_hat) * dsigma_mul_sigma_hat) + v = (v + v_hat) / 2 + s = (s + s_hat) / 2 + x = self.step_fn(x, v, dt, s=s, w=w) + else: + x = self.last_step_fn(x, v, dt, s=s, w=w) + return x \ No newline at end of file diff --git a/src/diffusion/flow_matching/scheduling.py b/src/diffusion/flow_matching/scheduling.py new file mode 100644 index 0000000..a82cd3a --- /dev/null +++ b/src/diffusion/flow_matching/scheduling.py @@ -0,0 +1,39 @@ +import math +import torch +from src.diffusion.base.scheduling import * + + +class LinearScheduler(BaseScheduler): + def alpha(self, t) -> Tensor: + return (t).view(-1, 1, 1, 1) + def sigma(self, t) -> Tensor: + return (1-t).view(-1, 1, 1, 1) + def dalpha(self, t) -> Tensor: + return torch.full_like(t, 1.0).view(-1, 1, 1, 1) + def dsigma(self, t) -> Tensor: + return torch.full_like(t, -1.0).view(-1, 1, 1, 1) + +# SoTA for ImageNet! 
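# Scalar sketch of the Heun predictor-corrector idea behind HeunSampler above:
# take an Euler trial step, re-evaluate the velocity at the predicted point,
# then advance with the average of the two slopes (second order vs. Euler's first).
import math

def integrate(f, x0, steps, heun=True):
    x, t, dt = x0, 0.0, 1.0 / steps
    for _ in range(steps):
        v = f(t, x)
        x_hat = x + v * dt                     # Euler predictor
        v_hat = f(t + dt, x_hat)               # slope at the predicted point
        x = x + (0.5 * (v + v_hat) if heun else v) * dt
        t += dt
    return x

exact = math.exp(-1.0)
print(abs(integrate(lambda t, x: -x, 1.0, 10, heun=False) - exact))  # ~1.9e-2
print(abs(integrate(lambda t, x: -x, 1.0, 10, heun=True) - exact))   # ~7e-4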
+class GVPScheduler(BaseScheduler): + def alpha(self, t) -> Tensor: + return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) + def sigma(self, t) -> Tensor: + return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) + def dalpha(self, t) -> Tensor: + return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) + def dsigma(self, t) -> Tensor: + return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) + def w(self, t): + return torch.sin(t)**2 + +class ConstScheduler(BaseScheduler): + def w(self, t): + return torch.ones(1, 1, 1, 1).to(t.device, t.dtype) + +from src.diffusion.ddpm.scheduling import VPScheduler +class VPBetaScheduler(VPScheduler): + def w(self, t): + return self.beta(t).view(-1, 1, 1, 1) + + + diff --git a/src/diffusion/flow_matching/training.py b/src/diffusion/flow_matching/training.py new file mode 100644 index 0000000..55c964d --- /dev/null +++ b/src/diffusion/flow_matching/training.py @@ -0,0 +1,55 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class FlowMatchingTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + out = net(x_t, t, y) + + weight = self.loss_weight_fn(alpha, sigma) + + loss = weight*(out - v_t)**2 + + out = dict( + loss=loss.mean(), + ) + return out \ No newline at end of file diff --git a/src/diffusion/flow_matching/training_cos.py b/src/diffusion/flow_matching/training_cos.py new file mode 100644 index 0000000..aff30a7 --- /dev/null +++ b/src/diffusion/flow_matching/training_cos.py @@ -0,0 +1,59 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class COSTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = 
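# Sketch of the linear interpolation that FlowMatchingTrainer above trains
# against: with alpha(t) = t and sigma(t) = 1 - t, x_t runs on a straight line
# from pure noise at t = 0 to data at t = 1, and the regression target
# v_t = dalpha * x + dsigma * noise is the constant velocity x - noise.
import torch

x, noise = torch.randn(2, 4), torch.randn(2, 4)
t = torch.tensor([0.0, 1.0]).view(-1, 1)
x_t = t * x + (1 - t) * noise
v_t = x - noise                                # constant along the path
assert torch.allclose(x_t[0], noise[0])        # t = 0 endpoint is noise
assert torch.allclose(x_t[1], x[1])            # t = 1 endpoint is data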
torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + out = net(x_t, t, y) + + weight = self.loss_weight_fn(alpha, sigma) + + fm_loss = weight*(out - v_t)**2 + cos_sim = torch.nn.functional.cosine_similarity(out, v_t, dim=1) + cos_loss = 1 - cos_sim + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + cos_loss.mean(), + ) + return out \ No newline at end of file diff --git a/src/diffusion/flow_matching/training_pyramid.py b/src/diffusion/flow_matching/training_pyramid.py new file mode 100644 index 0000000..be2bd94 --- /dev/null +++ b/src/diffusion/flow_matching/training_pyramid.py @@ -0,0 +1,68 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class PyramidTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + + output_pyramid = [] + def feature_hook(module, input, output): + output_pyramid.extend(output) + handle = net.decoder.register_forward_hook(feature_hook) + net(x_t, t, y) + handle.remove() + + loss = 0.0 + out_dict = dict() + + cur_v_t = v_t + for i in range(len(output_pyramid)): + cur_out = output_pyramid[i] + loss_i = (cur_v_t - cur_out) ** 2 + loss += loss_i.mean() + out_dict["loss_{}".format(i)] = loss_i.mean() + cur_v_t = torch.nn.functional.interpolate(cur_v_t, scale_factor=0.5, mode='bilinear', align_corners=False) + out_dict["loss"] = loss + return out_dict + diff --git a/src/diffusion/flow_matching/training_repa.py b/src/diffusion/flow_matching/training_repa.py new file mode 100644 index 0000000..e9a6788 --- /dev/null +++ b/src/diffusion/flow_matching/training_repa.py @@ -0,0 +1,142 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def 
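# Sketch of the forward-hook pattern the pyramid and REPA trainers above use to
# grab intermediate activations without changing the model's forward(). The tiny
# MLP here only illustrates the mechanism.
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 16), nn.SiLU(), nn.Linear(16, 4))
captured = []
handle = net[0].register_forward_hook(lambda mod, inp, out: captured.append(out))
_ = net(torch.randn(2, 8))
handle.remove()                                # detach the hook once the pass is done
print(captured[0].shape)                       # torch.Size([2, 16])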
maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load( + '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', + weight_path, + source="local", + skip_validation=True + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class REPATrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + encoder_weight_path=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.encoder = DINOv2(encoder_weight_path) + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + src_feature = [] + def forward_hook(net, input, output): + src_feature.append(output) + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + + out = net(x_t, t, y) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = weight*(out - v_t)**2 + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff 
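# Minimal sketch of the REPA alignment objective above: a small MLP projects a
# denoiser hidden state (captured via a forward hook) into the encoder's feature
# space, and a (1 - cosine similarity) loss pulls it toward frozen DINOv2 patch
# tokens. Shapes and the random tensors are illustrative stand-ins.
import torch
import torch.nn as nn
import torch.nn.functional as F

proj = nn.Sequential(nn.Linear(256, 256), nn.SiLU(), nn.Linear(256, 768))
denoiser_feat = torch.randn(2, 196, 256)       # hooked activations of the denoiser
encoder_feat = torch.randn(2, 196, 768)        # stand-in for DINOv2 x_norm_patchtokens
cos_loss = (1 - F.cosine_similarity(proj(denoiser_feat), encoder_feat, dim=-1)).mean()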
--git a/src/diffusion/flow_matching/training_repa_mask.py b/src/diffusion/flow_matching/training_repa_mask.py new file mode 100644 index 0000000..f8c4edb --- /dev/null +++ b/src/diffusion/flow_matching/training_repa_mask.py @@ -0,0 +1,152 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load( + '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', + weight_path, + source="local", + skip_validation=True + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class REPATrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + mask_ratio=0.0, + mask_patch_size=2, + encoder_weight_path=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.mask_ratio = mask_ratio + self.mask_patch_size = mask_patch_size + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.encoder = DINOv2(encoder_weight_path) + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + + 
patch_mask = torch.rand((batch_size, 1, height//self.mask_patch_size, width//self.mask_patch_size), device=x.device) + patch_mask = (patch_mask < self.mask_ratio).float() + mask = torch.nn.functional.interpolate(patch_mask, size=(height, width), mode='nearest') + masked_x = x*(1-mask)# + torch.randn_like(x)*(mask) + + x_t = alpha*masked_x + sigma*noise + v_t = dalpha*x + dsigma*noise + + src_feature = [] + def forward_hook(net, input, output): + src_feature.append(output) + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + + v_t_out, x0_out = net(x_t, t, y) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = (1-mask)*weight*(v_t_out - v_t)**2/(1-mask.mean()) + mask_loss = mask*weight*(x0_out - x)**2/(mask.mean()) + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + mask_loss=mask_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + mask_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/diffusion/pre_integral.py b/src/diffusion/pre_integral.py new file mode 100644 index 0000000..848533a --- /dev/null +++ b/src/diffusion/pre_integral.py @@ -0,0 +1,143 @@ +import torch + +# lagrange interpolation +def lagrange_preint_o1(t1, v1, int_t_start, int_t_end): + ''' + lagrange interpolation of order 1 + Args: + t1: timestepx + v1: value field at t1 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1 = (int_t_end-int_t_start) + return int1*v1, (int1/int1, ) + +def lagrange_preint_o2(t1, t2, v1, v2, int_t_start, int_t_end): + ''' + lagrange interpolation of order 2 + Args: + t1: timestepx + t2: timestepy + v1: value field at t1 + v2: value field at t2 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1 = 0.5/(t1-t2)*((int_t_end-t2)**2 - (int_t_start-t2)**2) + int2 = 0.5/(t2-t1)*((int_t_end-t1)**2 - (int_t_start-t1)**2) + int_sum = int1+int2 + return int1*v1 + int2*v2, (int1/int_sum, int2/int_sum) + +def lagrange_preint_o3(t1, t2, t3, v1, v2, v3, int_t_start, int_t_end): + ''' + lagrange interpolation of order 3 + Args: + t1: timestepx + t2: timestepy + t3: timestepz + v1: value field at t1 + v2: value field at t2 + v3: value field at t3 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1_denom = (t1-t2)*(t1-t3) + int1_end = 1/3*(int_t_end)**3 - 1/2*(t2+t3)*(int_t_end)**2 + (t2*t3)*int_t_end + int1_start = 1/3*(int_t_start)**3 - 1/2*(t2+t3)*(int_t_start)**2 + (t2*t3)*int_t_start + int1 = (int1_end - int1_start)/int1_denom + int2_denom = (t2-t1)*(t2-t3) + int2_end = 1/3*(int_t_end)**3 - 1/2*(t1+t3)*(int_t_end)**2 + (t1*t3)*int_t_end + int2_start = 1/3*(int_t_start)**3 - 1/2*(t1+t3)*(int_t_start)**2 + (t1*t3)*int_t_start + int2 = (int2_end - int2_start)/int2_denom + int3_denom = (t3-t1)*(t3-t2) + int3_end = 1/3*(int_t_end)**3 - 1/2*(t1+t2)*(int_t_end)**2 + (t1*t2)*int_t_end + int3_start = 1/3*(int_t_start)**3 - 1/2*(t1+t2)*(int_t_start)**2 + (t1*t2)*int_t_start + int3 = (int3_end - int3_start)/int3_denom + int_sum = int1+int2+int3 + 
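# Sketch of the patch-level masking used just above: a Bernoulli mask is drawn on
# a coarse patch grid and nearest-neighbour upsampled to the latent resolution so
# whole patches are dropped together. mask_ratio and the sizes are example values.
import torch
import torch.nn.functional as F

b, c, h, w, patch, mask_ratio = 2, 4, 32, 32, 2, 0.3
x = torch.randn(b, c, h, w)
patch_mask = (torch.rand(b, 1, h // patch, w // patch) < mask_ratio).float()
mask = F.interpolate(patch_mask, size=(h, w), mode='nearest')
masked_x = x * (1 - mask)                      # masked positions are zeroed out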
return int1*v1 + int2*v2 + int3*v3, (int1/int_sum, int2/int_sum, int3/int_sum) + +def larange_preint_o4(t1, t2, t3, t4, v1, v2, v3, v4, int_t_start, int_t_end): + ''' + lagrange interpolation of order 4 + Args: + t1: timestepx + t2: timestepy + t3: timestepz + t4: timestepw + v1: value field at t1 + v2: value field at t2 + v3: value field at t3 + v4: value field at t4 + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + int1_denom = (t1-t2)*(t1-t3)*(t1-t4) + int1_end = 1/4*(int_t_end)**4 - 1/3*(t2+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_end**2 - t2*t3*t4*int_t_end + int1_start = 1/4*(int_t_start)**4 - 1/3*(t2+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t2*t3 + t2*t4)*int_t_start**2 - t2*t3*t4*int_t_start + int1 = (int1_end - int1_start)/int1_denom + int2_denom = (t2-t1)*(t2-t3)*(t2-t4) + int2_end = 1/4*(int_t_end)**4 - 1/3*(t1+t3+t4)*(int_t_end)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_end**2 - t1*t3*t4*int_t_end + int2_start = 1/4*(int_t_start)**4 - 1/3*(t1+t3+t4)*(int_t_start)**3 + 1/2*(t3*t4 + t1*t3 + t1*t4)*int_t_start**2 - t1*t3*t4*int_t_start + int2 = (int2_end - int2_start)/int2_denom + int3_denom = (t3-t1)*(t3-t2)*(t3-t4) + int3_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t4)*(int_t_end)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_end**2 - t1*t2*t4*int_t_end + int3_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t4)*(int_t_start)**3 + 1/2*(t4*t2 + t1*t2 + t1*t4)*int_t_start**2 - t1*t2*t4*int_t_start + int3 = (int3_end - int3_start)/int3_denom + int4_denom = (t4-t1)*(t4-t2)*(t4-t3) + int4_end = 1/4*(int_t_end)**4 - 1/3*(t1+t2+t3)*(int_t_end)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_end**2 - t1*t2*t3*int_t_end + int4_start = 1/4*(int_t_start)**4 - 1/3*(t1+t2+t3)*(int_t_start)**3 + 1/2*(t3*t2 + t1*t2 + t1*t3)*int_t_start**2 - t1*t2*t3*int_t_start + int4 = (int4_end - int4_start)/int4_denom + int_sum = int1+int2+int3+int4 + return int1*v1 + int2*v2 + int3*v3 + int4*v4, (int1/int_sum, int2/int_sum, int3/int_sum, int4/int_sum) + + +def lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end): + ''' + lagrange interpolation + Args: + order: order of interpolation + pre_vs: value field at pre_ts + pre_ts: timesteps + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + order = min(order, len(pre_vs), len(pre_ts)) + if order == 1: + return lagrange_preint_o1(pre_ts[-1], pre_vs[-1], int_t_start, int_t_end) + elif order == 2: + return lagrange_preint_o2(pre_ts[-2], pre_ts[-1], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end) + elif order == 3: + return lagrange_preint_o3(pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end) + elif order == 4: + return larange_preint_o4(pre_ts[-4], pre_ts[-3], pre_ts[-2], pre_ts[-1], pre_vs[-4], pre_vs[-3], pre_vs[-2], pre_vs[-1], int_t_start, int_t_end) + else: + raise ValueError('Invalid order') + + +def polynomial_integral(coeffs, int_t_start, int_t_end): + ''' + polynomial integral + Args: + coeffs: coefficients of the polynomial + int_t_start: intergation start time + int_t_end: intergation end time + Returns: + integrated value + ''' + orders = len(coeffs) + int_val = 0 + for o in range(orders): + int_val += coeffs[o]/(o+1)*(int_t_end**(o+1)-int_t_start**(o+1)) + return int_val + diff --git a/src/diffusion/stateful_flow_matching/adam_sampling.py b/src/diffusion/stateful_flow_matching/adam_sampling.py new file mode 100644 index 0000000..fb2e95b --- /dev/null +++ 
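# Quick numerical check sketch for the pre-integration helpers above: the order-2
# rule integrates the Lagrange interpolant through two (t, v) samples, so it is
# exact for a linear velocity. The normalized-coefficient return value of the
# repo's function is omitted in this simplified copy.
def lagrange_int_o2(t1, t2, v1, v2, a, b):
    int1 = 0.5 / (t1 - t2) * ((b - t2) ** 2 - (a - t2) ** 2)
    int2 = 0.5 / (t2 - t1) * ((b - t1) ** 2 - (a - t1) ** 2)
    return int1 * v1 + int2 * v2

v = lambda t: 2 * t + 1                        # antiderivative: t**2 + t
approx = lagrange_int_o2(0.1, 0.5, v(0.1), v(0.5), 0.2, 0.6)
exact = (0.6 ** 2 + 0.6) - (0.2 ** 2 + 0.2)
assert abs(approx - exact) < 1e-9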
b/src/diffusion/stateful_flow_matching/adam_sampling.py @@ -0,0 +1,112 @@ +import math +from src.diffusion.base.sampling import * +from src.diffusion.base.scheduling import * +from src.diffusion.pre_integral import * + +from typing import Callable, List, Tuple + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + +def t2snr(t): + if isinstance(t, torch.Tensor): + return (t.clip(min=1e-8)/(1-t + 1e-8)) + if isinstance(t, List) or isinstance(t, Tuple): + return [t2snr(t) for t in t] + t = max(t, 1e-8) + return (t/(1-t + 1e-8)) + +def t2logsnr(t): + if isinstance(t, torch.Tensor): + return torch.log(t.clip(min=1e-3)/(1-t + 1e-3)) + if isinstance(t, List) or isinstance(t, Tuple): + return [t2logsnr(t) for t in t] + t = max(t, 1e-3) + return math.log(t/(1-t + 1e-3)) + +def t2isnr(t): + return 1/t2snr(t) + +def nop(t): + return t + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +import logging +logger = logging.getLogger(__name__) + +class AdamLMSampler(BaseSampler): + def __init__( + self, + order: int = 2, + timeshift: float = 1.0, + state_refresh_rate: int = 1, + lms_transform_fn: Callable = nop, + w_scheduler: BaseScheduler = None, + step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.w_scheduler = w_scheduler + self.state_refresh_rate = state_refresh_rate + + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + self.order = order + self.lms_transform_fn = lms_transform_fn + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, timeshift) + self.timedeltas = timesteps[1:] - self.timesteps[:-1] + self._reparameterize_coeffs() + + def _reparameterize_coeffs(self): + solver_coeffs = [[] for _ in range(self.num_steps)] + for i in range(0, self.num_steps): + pre_vs = [1.0, ]*(i+1) + pre_ts = self.lms_transform_fn(self.timesteps[:i+1]) + int_t_start = self.lms_transform_fn(self.timesteps[i]) + int_t_end = self.lms_transform_fn(self.timesteps[i+1]) + + order_annealing = self.order #self.num_steps - i + order = min(self.order, i + 1, order_annealing) + + _, coeffs = lagrange_preint(order, pre_vs, pre_ts, int_t_start, int_t_end) + solver_coeffs[i] = coeffs + self.solver_coeffs = solver_coeffs + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = x0 = noise + state = None + pred_trajectory = [] + t_cur = torch.zeros([batch_size,]).to(noise.device, noise.dtype) + timedeltas = self.timedeltas + solver_coeffs = self.solver_coeffs + for i in range(self.num_steps): + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + if i % self.state_refresh_rate == 0: + state = None + out, state = net(cfg_x, cfg_t, cfg_condition, state) + out = self.guidance_fn(out, self.guidances[i]) + pred_trajectory.append(out) + out = torch.zeros_like(out) + order = len(self.solver_coeffs[i]) + for j in range(order): + out += solver_coeffs[i][j] * pred_trajectory[-order:][j] + v = out + dt = timedeltas[i] + x0 = self.step_fn(x, v, 1-t_cur[0], s=0, w=0) + x = self.step_fn(x, v, dt, s=0, w=0) + t_cur += dt + return x \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/bak/training_adv.py b/src/diffusion/stateful_flow_matching/bak/training_adv.py new file mode 
100644 index 0000000..4792950 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/bak/training_adv.py @@ -0,0 +1,122 @@ +import torch +import math +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class Discriminator(nn.Module): + def __init__(self, in_channels, hidden_size): + super().__init__() + self.head = nn.Sequential( + nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1 + ) + + def forward(self, feature): + B, L, C = feature.shape + H = W = int(math.sqrt(L)) + feature = feature.permute(0, 2, 1) + feature = feature.view(B, C, H, W) + out = self.head(feature).sigmoid().clamp(0.01, 0.99) + return out + +class AdvTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + adv_weight=1.0, + adv_encoder_layer=4, + adv_in_channels=768, + adv_hidden_size=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.adv_weight = adv_weight + self.adv_encoder_layer = adv_encoder_layer + + self.dis_head = Discriminator( + in_channels=adv_in_channels, + hidden_size=adv_hidden_size, + ) + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + adv_feature = [] + def forward_hook(net, input, output): + adv_feature.append(output) + handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook) + + out, _ = net(x_t, t, y) + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - v_t)**2 + + pred_x0 = (x_t + out * sigma) + pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma + real_feature = adv_feature.pop() + net(pred_xt, t, y, classify_layer=self.adv_encoder_layer) + fake_feature = adv_feature.pop() + handle.remove() + + + real_score_gan = self.dis_head(real_feature.detach()) + fake_score_gan = self.dis_head(fake_feature.detach()) + fake_score = self.dis_head(fake_feature) + + loss_gan = -torch.log(1 - fake_score_gan) - 
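# Sketch of the token-to-feature-map reshape the Discriminator above depends on:
# a (B, L, C) sequence of transformer tokens is viewed as a square (B, C, H, W)
# map (L must be a perfect square) before the conv head scores it.
import math
import torch

tokens = torch.randn(2, 256, 768)              # B, L, C
B, L, C = tokens.shape
H = W = int(math.sqrt(L))
fmap = tokens.permute(0, 2, 1).view(B, C, H, W)
assert fmap.shape == (2, 768, 16, 16)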
torch.log(real_score_gan) + acc_real = (real_score_gan > 0.5).float() + acc_fake = (fake_score_gan < 0.5).float() + loss_adv = -torch.log(fake_score) + loss_adv_hack = torch.log(fake_score_gan) + + out = dict( + adv_loss=loss_adv.mean(), + gan_loss=loss_gan.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(), + acc_real=acc_real.mean(), + acc_fake=acc_fake.mean(), + ) + return out diff --git a/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py b/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py new file mode 100644 index 0000000..2843c04 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/bak/training_adv_x0.py @@ -0,0 +1,127 @@ +import torch +import math +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class Discriminator(nn.Module): + def __init__(self, in_channels, hidden_size): + super().__init__() + self.head = nn.Sequential( + nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1 + ) + + def forward(self, feature): + B, L, C = feature.shape + H = W = int(math.sqrt(L)) + feature = feature.permute(0, 2, 1) + feature = feature.view(B, C, H, W) + out = self.head(feature).sigmoid().clamp(0.01, 0.99) + return out + +class AdvTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + adv_weight=1.0, + lpips_weight=1.0, + adv_encoder_layer=4, + adv_in_channels=768, + adv_hidden_size=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.adv_weight = adv_weight + self.lpips_weight = lpips_weight + self.adv_encoder_layer = adv_encoder_layer + + self.dis_head = Discriminator( + in_channels=adv_in_channels, + hidden_size=adv_hidden_size, + ) + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype) + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + out, _ = net(x_t, t, 
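# Sketch of the clamped, sigmoid-output GAN objective used by the adversarial
# trainers above: the discriminator head minimises -log D(real) - log(1 - D(fake))
# on detached features, while the generator side minimises -log D(fake) through
# the denoiser. Random scores stand in for dis_head outputs.
import torch

real_score = torch.rand(4, 1).clamp(0.01, 0.99)
fake_score = torch.rand(4, 1).clamp(0.01, 0.99)
loss_gan = (-torch.log(real_score) - torch.log(1 - fake_score)).mean()
loss_adv = (-torch.log(fake_score)).mean()     # non-saturating generator loss
acc_real = (real_score > 0.5).float().mean()
acc_fake = (fake_score < 0.5).float().mean()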
y) + pred_x0 = (x_t + out * sigma) + + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - v_t)**2 + with torch.no_grad(): + _, real_features = net(x, clean_t, y, classify_layer=self.adv_encoder_layer) + _, fake_features = net(pred_x0, clean_t, y, classify_layer=self.adv_encoder_layer) + + real_score_gan = self.dis_head(real_features[-1].detach()) + fake_score_gan = self.dis_head(fake_features[-1].detach()) + fake_score = self.dis_head(fake_features[-1]) + + loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan) + acc_real = (real_score_gan > 0.5).float() + acc_fake = (fake_score_gan < 0.5).float() + loss_adv = -torch.log(fake_score) + loss_adv_hack = torch.log(fake_score_gan) + + lpips_loss = [] + for r, f in zip(real_features, fake_features): + r = torch.nn.functional.normalize(r, dim=-1) + f = torch.nn.functional.normalize(f, dim=-1) + lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean()) + lpips_loss = sum(lpips_loss) + + + out = dict( + adv_loss=loss_adv.mean(), + gan_loss=loss_gan.mean(), + lpips_loss=lpips_loss.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean() + self.lpips_weight*lpips_loss.mean(), + acc_real=acc_real.mean(), + acc_fake=acc_fake.mean(), + ) + return out diff --git a/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py b/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py new file mode 100644 index 0000000..849ee4b --- /dev/null +++ b/src/diffusion/stateful_flow_matching/bak/training_mask_repa.py @@ -0,0 +1,159 @@ +import random + +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load( + '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', + weight_path, + source="local", + skip_validation=True + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class MaskREPATrainer(BaseTrainer): 
+ def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + encoder_weight_path=None, + mask_groups=4, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.mask_groups = mask_groups + self.encoder = DINOv2(encoder_weight_path) + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def fetch_mask(self, length=256, groups=4, device=torch.device('cuda')): + mask = torch.zeros(1, length, length, device=device, dtype=torch.bool) + random_seq = torch.randperm(length, device=device) + for i in range(groups): + group_start = (length+groups-1)//groups*i + group_end = (length+groups-1)//groups*(i+1) + group_random_seq = random_seq[group_start:group_end] + y, x = torch.meshgrid(group_random_seq, group_random_seq) + mask[:, y, x] = True + return mask + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + src_feature = [] + def forward_hook(net, input, output): + src_feature.append(output) + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + mask_groups = random.randint(1, self.mask_groups) + mask = self.fetch_mask(length=256, groups=mask_groups, device=x.device) + out, _ = net(x_t, t, y, mask=mask) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = weight*(out - v_t)**2 + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py b/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py new file mode 100644 index 0000000..229680c --- /dev/null +++ b/src/diffusion/stateful_flow_matching/bak/training_patch_adv.py @@ -0,0 +1,179 @@ +import torch +import math +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def 
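# Small sketch of the grouped attention mask built by fetch_mask above: token
# indices are randomly partitioned into `groups` buckets and positions may only
# attend within their own bucket (True = allowed), giving a permuted block mask.
# The reading of True as "attend" is an assumption based on how the mask is built here.
import torch

length, groups = 8, 2
mask = torch.zeros(1, length, length, dtype=torch.bool)
perm = torch.randperm(length)
group_size = (length + groups - 1) // groups
for i in range(groups):
    idx = perm[i * group_size:(i + 1) * group_size]
    y, x = torch.meshgrid(idx, idx, indexing='ij')
    mask[:, y, x] = True
print(mask[0].sum().item())                    # 2 groups * 4 * 4 = 32 allowed pairs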
maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t, mul=1000): + t_freq = self.timestep_embedding(t * mul, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class BatchNormWithTimeEmbedding(nn.Module): + def __init__(self, num_features): + super().__init__() + # self.bn = nn.BatchNorm2d(num_features, affine=False) + self.bn = nn.GroupNorm(16, num_features, affine=False) + # self.bn = nn.SyncBatchNorm(num_features, affine=False) + self.embedder = TimestepEmbedder(num_features * 2) + # nn.init.zeros_(self.embedder.mlp[-1].weight) + nn.init.trunc_normal_(self.embedder.mlp[-1].weight, std=0.01) + nn.init.zeros_(self.embedder.mlp[-1].bias) + + def forward(self, x, t): + embed = self.embedder(t) + embed = embed[:, :, None, None] + gamma, beta = embed.chunk(2, dim=1) + gamma = 1.0 + gamma + normed = self.bn(x) + out = normed * gamma + beta + return out + +class DisBlock(nn.Module): + def __init__(self, in_channels, hidden_size): + super().__init__() + self.conv = nn.Conv2d( + kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=4, padding=0 + ) + self.norm = BatchNormWithTimeEmbedding(hidden_size) + self.act = nn.SiLU() + def forward(self, x, t): + x = self.conv(x) + x = self.norm(x, t) + x = self.act(x) + return x + + +class Discriminator(nn.Module): + def __init__(self, num_blocks, in_channels, hidden_size): + super().__init__() + self.blocks = nn.ModuleList() + for i in range(num_blocks): + self.blocks.append( + DisBlock( + in_channels=in_channels, + hidden_size=hidden_size, + ) + ) + in_channels = hidden_size + self.classifier = nn.Conv2d( + kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=1 + ) + def forward(self, feature, t): + B, C, H, W = feature.shape + for block in self.blocks: + feature = block(feature, t) + out = self.classifier(feature).view(B, -1) + out = out.sigmoid().clamp(0.01, 0.99) + return out + +class AdvTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + adv_weight=1.0, + adv_blocks=3, + adv_in_channels=3, + adv_hidden_size=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.adv_weight = adv_weight + + self.discriminator = Discriminator( + num_blocks=adv_blocks, + in_channels=adv_in_channels*2, + hidden_size=adv_hidden_size, + ) + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = 
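# Sketch of the sinusoidal timestep embedding behind TimestepEmbedder above (the
# same construction as the GLIDE reference it links): half cosine, half sine, with
# geometrically spaced frequencies; the trainer passes continuous t scaled by
# mul=1000 before embedding.
import math
import torch

def timestep_embedding(t, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = timestep_embedding(torch.tensor([0.0, 250.0, 999.0]), dim=256)
print(emb.shape)                               # torch.Size([3, 256])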
x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + out, _ = net(x_t, t, y) + pred_x0 = x_t + sigma * out + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - v_t)**2 + + real_feature = torch.cat([x_t, x], dim=1) + fake_feature = torch.cat([x_t, pred_x0], dim=1) + + real_score_gan = self.discriminator(real_feature.detach(), t) + fake_score_gan = self.discriminator(fake_feature.detach(), t) + fake_score = self.discriminator(fake_feature, t) + + loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan) + acc_real = (real_score_gan > 0.5).float() + acc_fake = (fake_score_gan < 0.5).float() + loss_adv = -torch.log(fake_score) + loss_adv_hack = torch.log(fake_score_gan) + + out = dict( + adv_loss=loss_adv.mean(), + gan_loss=loss_gan.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(), + acc_real=acc_real.mean(), + acc_fake=acc_fake.mean(), + ) + return out diff --git a/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py b/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py new file mode 100644 index 0000000..e84e81f --- /dev/null +++ b/src/diffusion/stateful_flow_matching/bak/training_repa_jit.py @@ -0,0 +1,154 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load( + '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', + weight_path, + source="local", + skip_validation=True + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class 
REPAJiTTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + jit_deltas=0.01, + encoder_weight_path=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.jit_deltas = jit_deltas + self.encoder = DINOv2(encoder_weight_path) + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + t2 = base_t + (torch.rand_like(base_t)-0.5) * self.jit_deltas + t2 = torch.clip(t2, 0, 1) + alpha = self.scheduler.alpha(t2) + dalpha = self.scheduler.dalpha(t2) + sigma = self.scheduler.sigma(t2) + dsigma = self.scheduler.dsigma(t2) + x_t2 = alpha * x + noise * sigma + v_t2 = dalpha * x + dsigma * noise + + src_feature = [] + def forward_hook(net, input, output): + src_feature.append(output) + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + + _, s = net(x_t, t, y, only_s=True) + out, _ = net(x_t2, t2, y, s) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = weight*(out - v_t2)**2 + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py b/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py new file mode 100644 index 0000000..d7e741d --- /dev/null +++ b/src/diffusion/stateful_flow_matching/bak/training_self_consistent.py @@ -0,0 +1,90 @@ +import torch +import math +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class SelfConsistentTrainer(BaseTrainer): + def __init__( + self, + scheduler: 
BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + lpips_weight=1.0, + lpips_encoder_layer=4, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.lpips_encoder_layer = lpips_encoder_layer + self.lpips_weight = lpips_weight + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + real_features = [] + def forward_hook(net, input, output): + real_features.append(output) + handles = [] + for i in range(self.lpips_encoder_layer): + handle = net.encoder.blocks[i].register_forward_hook(forward_hook) + handles.append(handle) + + out, _ = net(x_t, t, y) + + for handle in handles: + handle.remove() + + pred_x0 = (x_t + out * sigma) + pred_xt = alpha * pred_x0 + noise * sigma + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - v_t)**2 + + _, fake_features = net(pred_xt, t, y, classify_layer=self.lpips_encoder_layer) + + lpips_loss = [] + for r, f in zip(real_features, fake_features): + r = torch.nn.functional.normalize(r, dim=-1) + f = torch.nn.functional.normalize(f, dim=-1) + lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean()) + lpips_loss = sum(lpips_loss) + + + out = dict( + lpips_loss=lpips_loss.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + self.lpips_weight*lpips_loss.mean(), + ) + return out diff --git a/src/diffusion/stateful_flow_matching/bak/training_selflpips.py b/src/diffusion/stateful_flow_matching/bak/training_selflpips.py new file mode 100644 index 0000000..580775b --- /dev/null +++ b/src/diffusion/stateful_flow_matching/bak/training_selflpips.py @@ -0,0 +1,81 @@ +import torch +import math +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class SelfLPIPSTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + lpips_weight=1.0, + lpips_encoder_layer=4, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.lpips_encoder_layer = lpips_encoder_layer + self.lpips_weight = lpips_weight + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + clean_t = torch.full((batch_size,), 1.0).to(x.device, x.dtype) + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = 
self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + out, _ = net(x_t, t, y) + pred_x0 = (x_t + out * sigma) + pred_xt = alpha * pred_x0 + noise * sigma + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - v_t)**2 + with torch.no_grad(): + _, real_features = net(x, clean_t, y, classify_layer=self.lpips_encoder_layer) + _, fake_features = net(pred_x0, clean_t, y, classify_layer=self.lpips_encoder_layer) + + + lpips_loss = [] + for r, f in zip(real_features, fake_features): + r = torch.nn.functional.normalize(r, dim=-1) + f = torch.nn.functional.normalize(f, dim=-1) + lpips_loss.append(torch.sum((r - f)**2, dim=-1).mean()) + lpips_loss = sum(lpips_loss) + + + out = dict( + lpips_loss=lpips_loss.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + self.lpips_weight*lpips_loss.mean(), + ) + return out diff --git a/src/diffusion/stateful_flow_matching/cm_sampling.py b/src/diffusion/stateful_flow_matching/cm_sampling.py new file mode 100644 index 0000000..5254db5 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/cm_sampling.py @@ -0,0 +1,78 @@ +import torch + +from src.diffusion.base.guidance import * +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + + +import logging +logger = logging.getLogger(__name__) + +class CMSampler(BaseSampler): + def __init__( + self, + w_scheduler: BaseScheduler = None, + timeshift=1.0, + guidance_interval_min: float = 0.0, + guidance_interval_max: float = 1.0, + state_refresh_rate=1, + last_step=None, + step_fn=None, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.last_step = last_step + self.timeshift = timeshift + self.state_refresh_rate = state_refresh_rate + self.guidance_interval_min = guidance_interval_min + self.guidance_interval_max = guidance_interval_max + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + assert self.last_step > 0.0 + assert self.scheduler is not None + + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + state = None + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + cfg_t = t_cur.repeat(batch_size*2) + cfg_x = torch.cat([x, x], dim=0) + if i % self.state_refresh_rate == 0: + state = None + out, state = net(cfg_x, cfg_t, cfg_condition, state) + if t_cur > self.guidance_interval_min and t_cur < self.guidance_interval_max: + out = self.guidance_fn(out, self.guidance) + else: + out = self.guidance_fn(out, 1.0) + v = out + + x0 = x + v * (1-t_cur) + alpha_next = self.scheduler.alpha(t_next) + sigma_next = self.scheduler.sigma(t_next) + x = alpha_next * x0 + sigma_next * torch.randn_like(x) + # print(alpha_next, sigma_next) + return x \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/sampling.py b/src/diffusion/stateful_flow_matching/sampling.py new file mode 100644 index 0000000..5fdfdb2 --- /dev/null +++ 
b/src/diffusion/stateful_flow_matching/sampling.py @@ -0,0 +1,103 @@ +import torch + +from src.diffusion.base.guidance import * +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + +def sde_mean_step_fn(x, v, dt, s, w): + return x + v * dt + s * w * dt + +def sde_step_fn(x, v, dt, s, w): + return x + v*dt + s * w* dt + torch.sqrt(2*w*dt)*torch.randn_like(x) + +def sde_preserve_step_fn(x, v, dt, s, w): + return x + v*dt + 0.5*s*w* dt + torch.sqrt(w*dt)*torch.randn_like(x) + + +import logging +logger = logging.getLogger(__name__) + +class EulerSampler(BaseSampler): + def __init__( + self, + w_scheduler: BaseScheduler = None, + timeshift=1.0, + guidance_interval_min: float = 0.0, + guidance_interval_max: float = 1.0, + state_refresh_rate=1, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.w_scheduler = w_scheduler + self.timeshift = timeshift + self.state_refresh_rate = state_refresh_rate + self.guidance_interval_min = guidance_interval_min + self.guidance_interval_max = guidance_interval_max + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + assert self.last_step > 0.0 + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + if self.w_scheduler is not None: + if self.step_fn == ode_step_fn: + logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + state = None + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + sigma = self.scheduler.sigma(t_cur) + dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur) + dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur) + if self.w_scheduler: + w = self.w_scheduler.w(t_cur) + else: + w = 0.0 + + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + if i % self.state_refresh_rate == 0: + state = None + out, state = net(cfg_x, cfg_t, cfg_condition, state) + if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max: + out = self.guidance_fn(out, self.guidance) + else: + out = self.guidance_fn(out, 1.0) + v = out + s = ((1/dalpha_over_alpha)*v - x)/(sigma**2 - (1/dalpha_over_alpha)*dsigma_mul_sigma) + if i < self.num_steps -1 : + x = self.step_fn(x, v, dt, s=s, w=w) + else: + x = self.last_step_fn(x, v, dt, s=s, w=w) + return x \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/scheduling.py b/src/diffusion/stateful_flow_matching/scheduling.py new file mode 100644 index 0000000..a82cd3a --- /dev/null +++ b/src/diffusion/stateful_flow_matching/scheduling.py @@ -0,0 +1,39 @@ +import math +import torch +from src.diffusion.base.scheduling import * + + 
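+# Descriptive note (added for clarity): the schedulers below define the
+# interpolation path used by the flow-matching trainers and samplers in this
+# package: x_t = alpha(t) * x_0 + sigma(t) * noise, with velocity target
+# v_t = dalpha(t) * x_0 + dsigma(t) * noise.
+# LinearScheduler is the rectified-flow (linear) path with alpha = t and
+# sigma = 1 - t, so the target reduces to v_t = x_0 - noise.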
+class LinearScheduler(BaseScheduler): + def alpha(self, t) -> Tensor: + return (t).view(-1, 1, 1, 1) + def sigma(self, t) -> Tensor: + return (1-t).view(-1, 1, 1, 1) + def dalpha(self, t) -> Tensor: + return torch.full_like(t, 1.0).view(-1, 1, 1, 1) + def dsigma(self, t) -> Tensor: + return torch.full_like(t, -1.0).view(-1, 1, 1, 1) + +# SoTA for ImageNet! +class GVPScheduler(BaseScheduler): + def alpha(self, t) -> Tensor: + return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) + def sigma(self, t) -> Tensor: + return torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) + def dalpha(self, t) -> Tensor: + return -torch.sin(t * (math.pi / 2)).view(-1, 1, 1, 1) + def dsigma(self, t) -> Tensor: + return torch.cos(t * (math.pi / 2)).view(-1, 1, 1, 1) + def w(self, t): + return torch.sin(t)**2 + +class ConstScheduler(BaseScheduler): + def w(self, t): + return torch.ones(1, 1, 1, 1).to(t.device, t.dtype) + +from src.diffusion.ddpm.scheduling import VPScheduler +class VPBetaScheduler(VPScheduler): + def w(self, t): + return self.beta(t).view(-1, 1, 1, 1) + + + diff --git a/src/diffusion/stateful_flow_matching/sharing_sampling.py b/src/diffusion/stateful_flow_matching/sharing_sampling.py new file mode 100644 index 0000000..f372028 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/sharing_sampling.py @@ -0,0 +1,149 @@ +import torch + +from src.diffusion.base.guidance import * +from src.diffusion.base.scheduling import * +from src.diffusion.base.sampling import * + +from typing import Callable + + +def shift_respace_fn(t, shift=3.0): + return t / (t + (1 - t) * shift) + +def ode_step_fn(x, v, dt, s, w): + return x + v * dt + + +import logging +logger = logging.getLogger(__name__) + +class EulerSampler(BaseSampler): + def __init__( + self, + w_scheduler: BaseScheduler = None, + timeshift=1.0, + guidance_interval_min: float = 0.0, + guidance_interval_max: float = 1.0, + state_refresh_rate=1, + step_fn: Callable = ode_step_fn, + last_step=None, + last_step_fn: Callable = ode_step_fn, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.step_fn = step_fn + self.last_step = last_step + self.last_step_fn = last_step_fn + self.w_scheduler = w_scheduler + self.timeshift = timeshift + self.state_refresh_rate = state_refresh_rate + self.guidance_interval_min = guidance_interval_min + self.guidance_interval_max = guidance_interval_max + + if self.last_step is None or self.num_steps == 1: + self.last_step = 1.0 / self.num_steps + + timesteps = torch.linspace(0.0, 1 - self.last_step, self.num_steps) + timesteps = torch.cat([timesteps, torch.tensor([1.0])], dim=0) + self.timesteps = shift_respace_fn(timesteps, self.timeshift) + + assert self.last_step > 0.0 + assert self.scheduler is not None + assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ] + if self.w_scheduler is not None: + if self.step_fn == ode_step_fn: + logger.warning("current sampler is ODE sampler, but w_scheduler is enabled") + + # init recompute + self.num_recompute_timesteps = int(self.num_steps / self.state_refresh_rate) + self.recompute_timesteps = list(range(self.num_steps)) + + def sharing_dp(self, net, noise, condition, uncondition): + _, C, H, W = noise.shape + B = 8 + template_noise = torch.randn((B, C, H, W), generator=torch.Generator("cuda").manual_seed(0), device=noise.device) + template_condition = torch.randint(0, 1000, (B,), generator=torch.Generator("cuda").manual_seed(0), device=condition.device) + template_uncondition = torch.full((B, ), 1000, device=condition.device) + _, state_list = 
self._impl_sampling(net, template_noise, template_condition, template_uncondition) + states = torch.stack(state_list) + N, B, L, C = states.shape + states = states.view(N, B*L, C ) + states = states.permute(1, 0, 2) + states = torch.nn.functional.normalize(states, dim=-1) + with torch.autocast(device_type="cuda", dtype=torch.float64): + sim = torch.bmm(states, states.transpose(1, 2)) + sim = torch.mean(sim, dim=0).cpu() + error_map = (1-sim).tolist() + + # init cum-error + for i in range(1, self.num_steps): + for j in range(0, i): + error_map[i][j] = error_map[i-1][j] + error_map[i][j] + + # init dp and force 0 start + C = [[0.0, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)] + P = [[-1, ] * (self.num_steps + 1) for _ in range(self.num_recompute_timesteps+1)] + for i in range(1, self.num_steps+1): + C[1][i] = error_map[i - 1][0] + P[1][i] = 0 + + # dp state + for step in range(2, self.num_recompute_timesteps+1): + for i in range(step, self.num_steps+1): + min_value = 99999 + min_index = -1 + for j in range(step-1, i): + value = C[step-1][j] + error_map[i-1][j] + if value < min_value: + min_value = value + min_index = j + C[step][i] = min_value + P[step][i] = min_index + + # trace back + timesteps = [self.num_steps,] + for i in range(self.num_recompute_timesteps, 0, -1): + idx = timesteps[-1] + timesteps.append(P[i][idx]) + timesteps.reverse() + + print("recompute timesteps solved by DP: ", timesteps) + return timesteps[:-1] + + def _impl_sampling(self, net, noise, condition, uncondition): + """ + sampling process of Euler sampler + - + """ + batch_size = noise.shape[0] + steps = self.timesteps.to(noise.device) + cfg_condition = torch.cat([uncondition, condition], dim=0) + x = noise + state = None + pooled_state_list = [] + for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])): + dt = t_next - t_cur + t_cur = t_cur.repeat(batch_size) + cfg_x = torch.cat([x, x], dim=0) + cfg_t = t_cur.repeat(2) + if i in self.recompute_timesteps: + state = None + out, state = net(cfg_x, cfg_t, cfg_condition, state) + if t_cur[0] > self.guidance_interval_min and t_cur[0] < self.guidance_interval_max: + out = self.guidance_fn(out, self.guidance) + else: + out = self.guidance_fn(out, 1.0) + v = out + if i < self.num_steps -1 : + x = self.step_fn(x, v, dt, s=0.0, w=0.0) + else: + x = self.last_step_fn(x, v, dt, s=0.0, w=0.0) + pooled_state_list.append(state) + return x, pooled_state_list + + def __call__(self, net, noise, condition, uncondition): + if len(self.recompute_timesteps) != self.num_recompute_timesteps: + self.recompute_timesteps = self.sharing_dp(net, noise, condition, uncondition) + denoised, _ = self._impl_sampling(net, noise, condition, uncondition) + return denoised \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/training.py b/src/diffusion/stateful_flow_matching/training.py new file mode 100644 index 0000000..4c49e1e --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training.py @@ -0,0 +1,55 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class FlowMatchingTrainer(BaseTrainer): + def __init__( + self, + scheduler: 
BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + out, _ = net(x_t, t, y) + + weight = self.loss_weight_fn(alpha, sigma) + + loss = weight*(out - v_t)**2 + + out = dict( + loss=loss.mean(), + ) + return out \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/training_adv.py b/src/diffusion/stateful_flow_matching/training_adv.py new file mode 100644 index 0000000..4792950 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training_adv.py @@ -0,0 +1,122 @@ +import torch +import math +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class Discriminator(nn.Module): + def __init__(self, in_channels, hidden_size): + super().__init__() + self.head = nn.Sequential( + nn.Conv2d(kernel_size=4, in_channels=in_channels, out_channels=hidden_size, stride=2, padding=1), # 16x16 -> 8x8 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1), # 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.Conv2d(kernel_size=4, in_channels=hidden_size, out_channels=hidden_size, stride=2, padding=1),# 8x8 -> 4x4 + nn.GroupNorm(num_groups=32, num_channels=hidden_size), + nn.SiLU(), + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(kernel_size=1, in_channels=hidden_size, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1 + ) + + def forward(self, feature): + B, L, C = feature.shape + H = W = int(math.sqrt(L)) + feature = feature.permute(0, 2, 1) + feature = feature.view(B, C, H, W) + out = self.head(feature).sigmoid().clamp(0.01, 0.99) + return out + +class AdvTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + adv_weight=1.0, + adv_encoder_layer=4, + adv_in_channels=768, + adv_hidden_size=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.adv_weight = adv_weight + self.adv_encoder_layer = adv_encoder_layer + + self.dis_head = Discriminator( + in_channels=adv_in_channels, + hidden_size=adv_hidden_size, + ) + + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = 
torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + adv_feature = [] + def forward_hook(net, input, output): + adv_feature.append(output) + handle = net.encoder.blocks[self.adv_encoder_layer - 1].register_forward_hook(forward_hook) + + out, _ = net(x_t, t, y) + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - v_t)**2 + + pred_x0 = (x_t + out * sigma) + pred_xt = alpha * pred_x0 + torch.randn_like(pred_x0) * sigma + real_feature = adv_feature.pop() + net(pred_xt, t, y, classify_layer=self.adv_encoder_layer) + fake_feature = adv_feature.pop() + handle.remove() + + + real_score_gan = self.dis_head(real_feature.detach()) + fake_score_gan = self.dis_head(fake_feature.detach()) + fake_score = self.dis_head(fake_feature) + + loss_gan = -torch.log(1 - fake_score_gan) - torch.log(real_score_gan) + acc_real = (real_score_gan > 0.5).float() + acc_fake = (fake_score_gan < 0.5).float() + loss_adv = -torch.log(fake_score) + loss_adv_hack = torch.log(fake_score_gan) + + out = dict( + adv_loss=loss_adv.mean(), + gan_loss=loss_gan.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + (loss_adv.mean() + loss_adv_hack.mean())*self.adv_weight + loss_gan.mean(), + acc_real=acc_real.mean(), + acc_fake=acc_fake.mean(), + ) + return out diff --git a/src/diffusion/stateful_flow_matching/training_distill_dino.py b/src/diffusion/stateful_flow_matching/training_distill_dino.py new file mode 100644 index 0000000..c6a2937 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training_distill_dino.py @@ -0,0 +1,141 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load( + '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', + weight_path, + source="local", + skip_validation=True + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bilinear') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = 
self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class DistillDINOTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + encoder_weight_path=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.encoder = DINOv2(encoder_weight_path) + self.proj_encoder_dim = proj_encoder_dim + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + sigma = self.scheduler.sigma(t) + + x_t = alpha * x + noise * sigma + + _, s = net(x_t, t, y) + src_feature = self.proj(s) + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + if dst_feature.shape[1] != src_feature.shape[1]: + dst_length = dst_feature.shape[1] + rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5 + dst_height = (dst_length)**0.5 * (height/width)**0.5 + dst_width = (dst_length)**0.5 * (width/height)**0.5 + dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim) + dst_feature = dst_feature.permute(0, 3, 1, 2) + dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False) + dst_feature = dst_feature.permute(0, 2, 3, 1) + dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + out = dict( + cos_loss=cos_loss.mean(), + loss=cos_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/diffusion/stateful_flow_matching/training_lpips.py b/src/diffusion/stateful_flow_matching/training_lpips.py new file mode 100644 index 0000000..a3cd2a2 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training_lpips.py @@ -0,0 +1,71 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class LPIPSTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + 
loss_weight_fn:Callable=constant, + lognorm_t=False, + lpips_weight=1.0, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.lpips_weight = lpips_weight + self.lpips = _NoTrainLpips(net="vgg") + self.lpips = self.lpips.to(torch.bfloat16) + # self.lpips = torch.compile(self.lpips) + no_grad(self.lpips) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + out, _ = net(x_t, t, y) + weight = self.loss_weight_fn(alpha, sigma) + loss = weight*(out - v_t)**2 + + pred_x0 = (x_t + out*sigma) + target_x0 = x + # fixbug lpips std + lpips = self.lpips(pred_x0*0.5, target_x0*0.5) + + out = dict( + lpips_loss=lpips.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + lpips.mean()*self.lpips_weight, + ) + return out + + def state_dict(self, *args, destination=None, prefix='', keep_vars=False): + return \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py b/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py new file mode 100644 index 0000000..e0233ea --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training_lpips_lossweight.py @@ -0,0 +1,74 @@ +import torch +from typing import Callable +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from src.utils.no_grad import no_grad +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + +class LPIPSTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + lognorm_t=False, + lpips_weight=1.0, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = False + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.lpips_weight = lpips_weight + self.lpips = _NoTrainLpips(net="vgg") + self.lpips = self.lpips.to(torch.bfloat16) + # self.lpips = torch.compile(self.lpips) + no_grad(self.lpips) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size = x.shape[0] + if self.lognorm_t: + t = torch.randn(batch_size).to(x.device, x.dtype).sigmoid() + else: + t = torch.rand(batch_size).to(x.device, x.dtype) + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + w = self.scheduler.w(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + out, _ = net(x_t, t, y) + + fm_weight = t*(1-t)**2/0.25 + lpips_weight = t + + loss = (out - v_t)**2 * fm_weight[:, None, None, None] + + pred_x0 = (x_t + out*sigma) + target_x0 = x + # fixbug lpips std + lpips = self.lpips(pred_x0*0.5, target_x0*0.5)*lpips_weight[:, None, None, None] + + out = dict( + 
lpips_loss=lpips.mean(), + fm_loss=loss.mean(), + loss=loss.mean() + lpips.mean()*self.lpips_weight, + ) + return out + + def state_dict(self, *args, destination=None, prefix='', keep_vars=False): + return \ No newline at end of file diff --git a/src/diffusion/stateful_flow_matching/training_repa.py b/src/diffusion/stateful_flow_matching/training_repa.py new file mode 100644 index 0000000..a5a28db --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training_repa.py @@ -0,0 +1,157 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load( + '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', + weight_path, + source="local", + skip_validation=True + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class REPATrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + encoder_weight_path=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.encoder = DINOv2(encoder_weight_path) + self.proj_encoder_dim = proj_encoder_dim + no_grad(self.encoder) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = 
base_t + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + src_feature = [] + def forward_hook(net, input, output): + src_feature.append(output) + + if getattr(net, "blocks", None) is not None: + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + else: + handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + + out, _ = net(x_t, t, y) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + if dst_feature.shape[1] != src_feature.shape[1]: + dst_length = dst_feature.shape[1] + rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5 + dst_height = (dst_length)**0.5 * (height/width)**0.5 + dst_width = (dst_length)**0.5 * (width/height)**0.5 + dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim) + dst_feature = dst_feature.permute(0, 3, 1, 2) + dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False) + dst_feature = dst_feature.permute(0, 2, 3, 1) + dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = weight*(out - v_t)**2 + + out = dict( + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/diffusion/stateful_flow_matching/training_repa_lpips.py b/src/diffusion/stateful_flow_matching/training_repa_lpips.py new file mode 100644 index 0000000..5a11207 --- /dev/null +++ b/src/diffusion/stateful_flow_matching/training_repa_lpips.py @@ -0,0 +1,170 @@ +import torch +import copy +import timm +from torch.nn import Parameter + +from src.utils.no_grad import no_grad +from typing import Callable, Iterator, Tuple +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +from src.diffusion.base.training import * +from src.diffusion.base.scheduling import BaseScheduler +from torchmetrics.image.lpip import _NoTrainLpips + +def inverse_sigma(alpha, sigma): + return 1/sigma**2 +def snr(alpha, sigma): + return alpha/sigma +def minsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, min=threshold) +def maxsnr(alpha, sigma, threshold=5): + return torch.clip(alpha/sigma, max=threshold) +def constant(alpha, sigma): + return 1 + + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load( + '/mnt/bn/wangshuai6/torch_hub/facebookresearch_dinov2_main', + weight_path, + source="local", + skip_validation=True + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = 
timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + return feature + + +class REPALPIPSTrainer(BaseTrainer): + def __init__( + self, + scheduler: BaseScheduler, + loss_weight_fn:Callable=constant, + feat_loss_weight: float=0.5, + lognorm_t=False, + lpips_weight=1.0, + encoder_weight_path=None, + align_layer=8, + proj_denoiser_dim=256, + proj_hidden_dim=256, + proj_encoder_dim=256, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + self.lognorm_t = lognorm_t + self.scheduler = scheduler + self.loss_weight_fn = loss_weight_fn + self.feat_loss_weight = feat_loss_weight + self.align_layer = align_layer + self.encoder = DINOv2(encoder_weight_path) + self.proj_encoder_dim = proj_encoder_dim + no_grad(self.encoder) + + self.lpips_weight = lpips_weight + self.lpips = _NoTrainLpips(net="vgg") + self.lpips = self.lpips.to(torch.bfloat16) + no_grad(self.lpips) + + self.proj = nn.Sequential( + nn.Sequential( + nn.Linear(proj_denoiser_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_hidden_dim), + nn.SiLU(), + nn.Linear(proj_hidden_dim, proj_encoder_dim), + ) + ) + + def _impl_trainstep(self, net, ema_net, raw_images, x, y): + batch_size, c, height, width = x.shape + if self.lognorm_t: + base_t = torch.randn((batch_size), device=x.device, dtype=x.dtype).sigmoid() + else: + base_t = torch.rand((batch_size), device=x.device, dtype=x.dtype) + t = base_t + + noise = torch.randn_like(x) + alpha = self.scheduler.alpha(t) + dalpha = self.scheduler.dalpha(t) + sigma = self.scheduler.sigma(t) + dsigma = self.scheduler.dsigma(t) + + x_t = alpha * x + noise * sigma + v_t = dalpha * x + dsigma * noise + + src_feature = [] + def forward_hook(net, input, output): + src_feature.append(output) + if getattr(net, "blocks", None) is not None: + handle = net.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + else: + handle = net.encoder.blocks[self.align_layer - 1].register_forward_hook(forward_hook) + + out, _ = net(x_t, t, y) + src_feature = self.proj(src_feature[0]) + handle.remove() + + with torch.no_grad(): + dst_feature = self.encoder(raw_images) + + if dst_feature.shape[1] != src_feature.shape[1]: + dst_length = dst_feature.shape[1] + rescale_ratio = (src_feature.shape[1] / dst_feature.shape[1])**0.5 + dst_height = (dst_length)**0.5 * (height/width)**0.5 + dst_width = (dst_length)**0.5 * (width/height)**0.5 + dst_feature = dst_feature.view(batch_size, int(dst_height), int(dst_width), self.proj_encoder_dim) + dst_feature = dst_feature.permute(0, 3, 1, 2) + dst_feature = torch.nn.functional.interpolate(dst_feature, scale_factor=rescale_ratio, mode='bilinear', align_corners=False) + dst_feature = dst_feature.permute(0, 2, 3, 1) + dst_feature = dst_feature.view(batch_size, -1, self.proj_encoder_dim) + + cos_sim = torch.nn.functional.cosine_similarity(src_feature, dst_feature, dim=-1) + cos_loss = 1 - cos_sim + + weight = self.loss_weight_fn(alpha, sigma) + fm_loss = weight*(out - v_t)**2 + + pred_x0 = (x_t + out * 
sigma) + target_x0 = x + # fixbug lpips std + lpips = self.lpips(pred_x0 * 0.5, target_x0 * 0.5) + + out = dict( + lpips_loss=lpips.mean(), + fm_loss=fm_loss.mean(), + cos_loss=cos_loss.mean(), + loss=fm_loss.mean() + self.feat_loss_weight*cos_loss.mean() + self.lpips_weight*lpips.mean(), + ) + return out + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + self.proj.state_dict( + destination=destination, + prefix=prefix + "proj.", + keep_vars=keep_vars) + diff --git a/src/lightning_data.py b/src/lightning_data.py new file mode 100644 index 0000000..9f75a42 --- /dev/null +++ b/src/lightning_data.py @@ -0,0 +1,162 @@ +from typing import Any +import torch +import copy +import lightning.pytorch as pl +from lightning.pytorch.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS +from torch.utils.data import DataLoader +from src.data.dataset.randn import RandomNDataset +from src.data.var_training import VARTransformEngine + +def collate_fn(batch): + new_batch = copy.deepcopy(batch) + new_batch = list(zip(*new_batch)) + for i in range(len(new_batch)): + if isinstance(new_batch[i][0], torch.Tensor): + try: + new_batch[i] = torch.stack(new_batch[i], dim=0) + except: + print("Warning: could not stack tensors") + return new_batch + +class DataModule(pl.LightningDataModule): + def __init__(self, + train_root, + test_nature_root, + test_gen_root, + train_image_size=64, + train_batch_size=64, + train_num_workers=8, + var_transform_engine: VARTransformEngine = None, + train_prefetch_factor=2, + train_dataset: str = None, + eval_batch_size=32, + eval_num_workers=4, + eval_max_num_instances=50000, + pred_batch_size=32, + pred_num_workers=4, + pred_seeds:str=None, + pred_selected_classes=None, + num_classes=1000, + latent_shape=(4,64,64), + ): + super().__init__() + pred_seeds = list(map(lambda x: int(x), pred_seeds.strip().split(","))) if pred_seeds is not None else None + + self.train_root = train_root + self.train_image_size = train_image_size + self.train_dataset = train_dataset + # stupid data_convert override, just to make nebular happy + self.train_batch_size = train_batch_size + self.train_num_workers = train_num_workers + self.train_prefetch_factor = train_prefetch_factor + + self.test_nature_root = test_nature_root + self.test_gen_root = test_gen_root + self.eval_max_num_instances = eval_max_num_instances + self.pred_seeds = pred_seeds + self.num_classes = num_classes + self.latent_shape = latent_shape + + self.eval_batch_size = eval_batch_size + self.pred_batch_size = pred_batch_size + + self.pred_num_workers = pred_num_workers + self.eval_num_workers = eval_num_workers + + self.pred_selected_classes = pred_selected_classes + + self._train_dataloader = None + self.var_transform_engine = var_transform_engine + + def setup(self, stage: str) -> None: + if stage == "fit": + assert self.train_dataset is not None + if self.train_dataset == "pix_imagenet64": + from src.data.dataset.imagenet import PixImageNet64 + self.train_dataset = PixImageNet64( + root=self.train_root, + ) + elif self.train_dataset == "pix_imagenet128": + from src.data.dataset.imagenet import PixImageNet128 + self.train_dataset = PixImageNet128( + root=self.train_root, + ) + elif self.train_dataset == "imagenet256": + from src.data.dataset.imagenet import ImageNet256 + self.train_dataset = ImageNet256( + root=self.train_root, + ) + elif self.train_dataset == "pix_imagenet256": + from src.data.dataset.imagenet import PixImageNet256 + self.train_dataset = PixImageNet256( + root=self.train_root, + ) + 
elif self.train_dataset == "imagenet512": + from src.data.dataset.imagenet import ImageNet512 + self.train_dataset = ImageNet512( + root=self.train_root, + ) + elif self.train_dataset == "pix_imagenet512": + from src.data.dataset.imagenet import PixImageNet512 + self.train_dataset = PixImageNet512( + root=self.train_root, + ) + else: + raise NotImplementedError("no such dataset") + + def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: + if self.var_transform_engine and self.trainer.training: + batch = self.var_transform_engine(batch) + return batch + + def train_dataloader(self) -> TRAIN_DATALOADERS: + global_rank = self.trainer.global_rank + world_size = self.trainer.world_size + from torch.utils.data import DistributedSampler + sampler = DistributedSampler(self.train_dataset, num_replicas=world_size, rank=global_rank, shuffle=True) + self._train_dataloader = DataLoader( + self.train_dataset, + self.train_batch_size, + timeout=6000, + num_workers=self.train_num_workers, + prefetch_factor=self.train_prefetch_factor, + sampler=sampler, + collate_fn=collate_fn, + ) + return self._train_dataloader + + def val_dataloader(self) -> EVAL_DATALOADERS: + global_rank = self.trainer.global_rank + world_size = self.trainer.world_size + self.eval_dataset = RandomNDataset( + latent_shape=self.latent_shape, + num_classes=self.num_classes, + max_num_instances=self.eval_max_num_instances, + ) + from torch.utils.data import DistributedSampler + sampler = DistributedSampler(self.eval_dataset, num_replicas=world_size, rank=global_rank, shuffle=False) + return DataLoader(self.eval_dataset, self.eval_batch_size, + num_workers=self.eval_num_workers, + prefetch_factor=2, + collate_fn=collate_fn, + sampler=sampler + ) + + def predict_dataloader(self) -> EVAL_DATALOADERS: + global_rank = self.trainer.global_rank + world_size = self.trainer.world_size + self.pred_dataset = RandomNDataset( + seeds= self.pred_seeds, + max_num_instances=50000, + num_classes=self.num_classes, + selected_classes=self.pred_selected_classes, + latent_shape=self.latent_shape, + ) + from torch.utils.data import DistributedSampler + sampler = DistributedSampler(self.pred_dataset, num_replicas=world_size, rank=global_rank, shuffle=False) + return DataLoader(self.pred_dataset, batch_size=self.pred_batch_size, + num_workers=self.pred_num_workers, + prefetch_factor=4, + collate_fn=collate_fn, + sampler=sampler + ) diff --git a/src/lightning_model.py b/src/lightning_model.py new file mode 100644 index 0000000..4602e82 --- /dev/null +++ b/src/lightning_model.py @@ -0,0 +1,123 @@ +from typing import Callable, Iterable, Any, Optional, Union, Sequence, Mapping, Dict +import os.path +import copy +import torch +import torch.nn as nn +import lightning.pytorch as pl +from lightning.pytorch.utilities.types import OptimizerLRScheduler, STEP_OUTPUT +from torch.optim.lr_scheduler import LRScheduler +from torch.optim import Optimizer +from lightning.pytorch.callbacks import Callback + + +from src.models.vae import BaseVAE, fp2uint8 +from src.models.conditioner import BaseConditioner +from src.utils.model_loader import ModelLoader +from src.callbacks.simple_ema import SimpleEMA +from src.diffusion.base.sampling import BaseSampler +from src.diffusion.base.training import BaseTrainer +from src.utils.no_grad import no_grad, filter_nograd_tensors +from src.utils.copy import copy_params + +EMACallable = Callable[[nn.Module, nn.Module], SimpleEMA] +OptimizerCallable = Callable[[Iterable], Optimizer] +LRSchedulerCallable = Callable[[Optimizer], 
LRScheduler] + + +class LightningModel(pl.LightningModule): + def __init__(self, + vae: BaseVAE, + conditioner: BaseConditioner, + denoiser: nn.Module, + diffusion_trainer: BaseTrainer, + diffusion_sampler: BaseSampler, + ema_tracker: Optional[EMACallable] = None, + optimizer: OptimizerCallable = None, + lr_scheduler: LRSchedulerCallable = None, + ): + super().__init__() + self.vae = vae + self.conditioner = conditioner + self.denoiser = denoiser + self.ema_denoiser = copy.deepcopy(self.denoiser) + self.diffusion_sampler = diffusion_sampler + self.diffusion_trainer = diffusion_trainer + self.ema_tracker = ema_tracker + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + # self.model_loader = ModelLoader() + + self._strict_loading = False + + def configure_model(self) -> None: + self.trainer.strategy.barrier() + # self.denoiser = self.model_loader.load(self.denoiser) + copy_params(src_model=self.denoiser, dst_model=self.ema_denoiser) + + # self.denoiser = torch.compile(self.denoiser) + # disable grad for conditioner and vae + no_grad(self.conditioner) + no_grad(self.vae) + no_grad(self.diffusion_sampler) + no_grad(self.ema_denoiser) + + def configure_callbacks(self) -> Union[Sequence[Callback], Callback]: + ema_tracker = self.ema_tracker(self.denoiser, self.ema_denoiser) + return [ema_tracker] + + def configure_optimizers(self) -> OptimizerLRScheduler: + params_denoiser = filter_nograd_tensors(self.denoiser.parameters()) + params_trainer = filter_nograd_tensors(self.diffusion_trainer.parameters()) + optimizer: torch.optim.Optimizer = self.optimizer([*params_trainer, *params_denoiser]) + if self.lr_scheduler is None: + return dict( + optimizer=optimizer + ) + else: + lr_scheduler = self.lr_scheduler(optimizer) + return dict( + optimizer=optimizer, + lr_scheduler=lr_scheduler + ) + + def training_step(self, batch, batch_idx): + raw_images, x, y = batch + with torch.no_grad(): + x = self.vae.encode(x) + condition, uncondition = self.conditioner(y) + loss = self.diffusion_trainer(self.denoiser, self.ema_denoiser, raw_images, x, condition, uncondition) + self.log_dict(loss, prog_bar=True, on_step=True, sync_dist=False) + return loss["loss"] + + def predict_step(self, batch, batch_idx): + xT, y, metadata = batch + with torch.no_grad(): + condition, uncondition = self.conditioner(y) + # Sample images: + samples = self.diffusion_sampler(self.denoiser, xT, condition, uncondition) + samples = self.vae.decode(samples) + # fp32 -1,1 -> uint8 0,255 + samples = fp2uint8(samples) + return samples + + def validation_step(self, batch, batch_idx): + samples = self.predict_step(batch, batch_idx) + return samples + + def state_dict(self, *args, destination=None, prefix="", keep_vars=False): + if destination is None: + destination = {} + self._save_to_state_dict(destination, prefix, keep_vars) + self.denoiser.state_dict( + destination=destination, + prefix=prefix+"denoiser.", + keep_vars=keep_vars) + self.ema_denoiser.state_dict( + destination=destination, + prefix=prefix+"ema_denoiser.", + keep_vars=keep_vars) + self.diffusion_trainer.state_dict( + destination=destination, + prefix=prefix+"diffusion_trainer.", + keep_vars=keep_vars) + return destination \ No newline at end of file diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/conditioner.py b/src/models/conditioner.py new file mode 100644 index 0000000..a68fad3 --- /dev/null +++ b/src/models/conditioner.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + 
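+# Descriptive note (added for clarity): a conditioner returns a
+# (condition, uncondition) pair per batch; the samplers concatenate the two
+# along the batch dimension for classifier-free guidance. LabelConditioner
+# passes integer class labels through as a LongTensor and uses a fixed
+# "null" class id for the unconditional branch.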
+class BaseConditioner(nn.Module): + def __init__(self): + super(BaseConditioner, self).__init__() + + def _impl_condition(self, y): + ... + def _impl_uncondition(self, y): + ... + def __call__(self, y): + condition = self._impl_condition(y) + uncondition = self._impl_uncondition(y) + return condition, uncondition + +class LabelConditioner(BaseConditioner): + def __init__(self, null_class): + super().__init__() + self.null_condition = null_class + + def _impl_condition(self, y): + return torch.tensor(y).long().cuda() + + def _impl_uncondition(self, y): + return torch.full((len(y),), self.null_condition, dtype=torch.long).cuda() \ No newline at end of file diff --git a/src/models/denoiser/__init__.py b/src/models/denoiser/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py new file mode 100644 index 0000000..3581446 --- /dev/null +++ b/src/models/denoiser/bak/flatten_condit_encoder_catdecoder_fixt.py @@ -0,0 +1,383 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention +from src.utils.model_loader import ModelLoader +from src.utils.no_grad import no_grad + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
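+        # freqs is a geometric progression from 1 down to ~1/max_period over dim//2 bins,
+        # and args pairs every timestep with every frequency; the cos/sin of these phases,
+        # concatenated below, form the sinusoidal embedding that self.mlp projects in forward().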
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenDiTEncoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) + self.precompute_pos[(height, 
width)] = (pos_rope, pos_ape) + return (pos_rope, pos_ape) + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + def forward(self, x, t, y, mask=None): + B, _, H, W = x.shape + pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + s = self.s_embedder(x) + # s = s + pos_ape.to(s.dtype) + for i in range(self.num_blocks): + s = self.blocks[i](s, c, pos_rope, mask) + return None, s + + +class FlattenDiTDecoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size) + self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y, s, mask=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) + c = 
torch.nn.functional.silu(t + y) + x = torch.cat([x, s], dim=-1) + x = self.x_embedder(x) + for i in range(self.num_blocks): + x = self.blocks[i](x, c, pos, None) + x = self.final_layer(x, c) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, + stride=self.patch_size) + return x + + + +class FlattenDiT(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + ModelLoader().load(encoder) + self.encoder = self.encoder.to(torch.bfloat16) + no_grad(self.encoder) + + def forward(self, x, t, y, s=None): + if s is None: + with torch.no_grad(): + _, s = self.encoder(x, t, y) + x = self.decoder(x, t, y, s) + return x, s + diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py new file mode 100644 index 0000000..733ce4a --- /dev/null +++ b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt.py @@ -0,0 +1,447 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention +from src.utils.model_loader import ModelLoader +from src.utils.no_grad import no_grad + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +class ResBlock(nn.Module): + def __init__(self, dim:int, groups:int=8, hidden_dim:int=256): + super().__init__() + self.conv1 = nn.Conv2d(dim, dim, 3, padding=1) + self.conv2 = nn.Conv2d(dim, dim, 3, padding=1) + self.norm1 = nn.GroupNorm(groups, dim) + self.norm2 = nn.GroupNorm(groups, dim) + self.embed_proj = nn.Linear(hidden_dim, dim) + + def forward(self, x, c): + c = self.embed_proj(c)[:, :, None, None] + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = torch.nn.functional.silu(x) + x = x * c + x = self.conv2(x) + x = self.norm2(x) + x = torch.nn.functional.silu(x) + return residual + x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return 
freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenDiTEncoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, 
hidden_size) + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) + self.precompute_pos[(height, width)] = (pos_rope, pos_ape) + return (pos_rope, pos_ape) + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + def forward(self, x, t, y, mask=None, classify_layer=None): + B, _, H, W = x.shape + pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + s = self.s_embedder(x) + # s = s + pos_ape.to(s.dtype) + classify_feats = [] + for i in range(self.num_blocks): + s = self.blocks[i](s, c, pos_rope, mask) + if classify_layer is not None and i < classify_layer: + classify_feats.append(s) + if i == classify_layer - 1: + return _, classify_feats + return None, s + + +class FlattenDiTDecoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_mid_blocks=18, + num_res_blocks=[1, 1, 1], + num_res_channels=[64, 384, 768], + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_mid_blocks = num_mid_blocks + self.num_res_blocks = num_res_blocks + self.num_res_channels = num_res_channels + self.patch_size = 2**(len(num_res_blocks)) + + self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + + self.down_res_blocks = nn.ModuleList() + previous_channel = self.in_channels + for num, channels in zip(num_res_blocks, num_res_channels): + self.down_res_blocks.append( + nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0), + ) + self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) + previous_channel = channels + self.up_res_blocks = [] + previous_channel = self.in_channels + for num, channels in zip(num_res_blocks, num_res_channels): + self.up_res_blocks.append( + nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0) + ) + self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) + previous_channel = channels + self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1]) + + self.blocks = nn.ModuleList([ + 
FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks) + ]) + + self.initialize_weights() + self.precompute_pos = dict() + self.weight_path = weight_path + self.load_ema = load_ema + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + def forward(self, x, t, y, s, mask=None): + B, _, H, W = x.shape + t = self.t_embedder(t.view(-1)).view(B, self.hidden_size) + y = self.y_embedder(y).view(B, self.hidden_size) + s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) + c = torch.nn.functional.silu(t + y) + + residual = [] + for i, block in enumerate(self.down_res_blocks): + if isinstance(block, nn.Conv2d): + residual.append(x) + x = block(x) + else: + x = block(x, c) + + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = x.view(B, self.hidden_size, -1).transpose(1, 2) + mid_c = torch.nn.functional.silu(t[:, None, :] + s) + for i in range(self.num_mid_blocks): + x = self.blocks[i](x, mid_c, pos, None) + x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size) + + residual[0] = 0.0 + for i, block in enumerate(self.up_res_blocks): + if isinstance(block, nn.ConvTranspose2d): + x = block(x) + residual.pop() + else: + x = block(x, c) + return x + + +class FlattenDiT(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + ModelLoader().load(encoder) + self.encoder = self.encoder.to(torch.bfloat16) + no_grad(self.encoder) + + def forward(self, x, t, y, s=None, classify_layer=None): + if s is None: + _, s = self.encoder(x, t, y, classify_layer=classify_layer) + if classify_layer is not None: + return None, s + x = self.decoder(x, t, y, s) + return x, s + +class FlattenDiT_jointtraining(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def forward(self, x, t, y, s=None): + if s is None: + _, s = self.encoder(x, t, y) + x = self.decoder(x, t, y, s) + return x, s + diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py new file mode 100644 index 0000000..6e9adbc --- /dev/null +++ b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_fixt2.py @@ -0,0 +1,448 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention +from src.utils.model_loader import ModelLoader +from src.utils.no_grad import no_grad + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = 
nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +class ResBlock(nn.Module): + def __init__(self, dim:int, groups:int=8, hidden_dim:int=256): + super().__init__() + self.conv1 = nn.Conv2d(dim, dim, 3, padding=1) + self.conv2 = nn.Conv2d(dim, dim, 3, padding=1) + self.norm1 = nn.GroupNorm(groups, dim) + self.norm2 = nn.GroupNorm(groups, dim) + self.embed_proj = nn.Linear(hidden_dim, dim) + + def forward(self, x, c): + c = self.embed_proj(c)[:, :, None, None] + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = torch.nn.functional.silu(x) + x = x * c + x = self.conv2(x) + x = self.norm2(x) + x = torch.nn.functional.silu(x) + return residual + x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, 
theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), 
shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenDiTEncoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) + self.precompute_pos[(height, width)] = (pos_rope, pos_ape) + return (pos_rope, pos_ape) + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + def forward(self, x, t, y, mask=None, classify_layer=None): + B, _, H, W = x.shape + pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + s = self.s_embedder(x) + # s = s + pos_ape.to(s.dtype) + classify_feats = [] + for i in range(self.num_blocks): + s = self.blocks[i](s, c, pos_rope, mask) + if classify_layer is not None and i < classify_layer: + classify_feats.append(s) + if i == classify_layer - 1: + return _, classify_feats + return None, s + + +class FlattenDiTDecoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_mid_blocks=18, + num_res_blocks=[1, 1, 1], + num_res_channels=[64, 384, 768], + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_mid_blocks = num_mid_blocks + self.num_res_blocks = num_res_blocks + self.num_res_channels = num_res_channels + self.patch_size = 2**(len(num_res_blocks)) + + self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + + self.down_res_blocks = nn.ModuleList() + previous_channel = self.in_channels + for 
num, channels in zip(num_res_blocks, num_res_channels): + self.down_res_blocks.append( + nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0), + ) + self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) + previous_channel = channels + self.up_res_blocks = [] + previous_channel = self.in_channels + for num, channels in zip(num_res_blocks, num_res_channels): + self.up_res_blocks.append( + nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0) + ) + self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) + previous_channel = channels + self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1]) + + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks) + ]) + + self.initialize_weights() + self.precompute_pos = dict() + self.weight_path = weight_path + self.load_ema = load_ema + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + def forward(self, x, t, y, s, mask=None): + B, _, H, W = x.shape + t = self.t_embedder(t.view(-1)).view(B, self.hidden_size) + y = self.y_embedder(y).view(B, self.hidden_size) + s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) + c = torch.nn.functional.silu(t + y) + + residual = [] + for i, block in enumerate(self.down_res_blocks): + if isinstance(block, nn.Conv2d): + residual.append(x) + x = block(x) + else: + x = block(x, c) + + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = x.view(B, self.hidden_size, -1).transpose(1, 2) + mid_c = torch.nn.functional.silu(t[:, None, :] + s) + for i in range(self.num_mid_blocks): + x = self.blocks[i](x, mid_c, pos, None) + x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size) + + residual[0] = 0.0 + for i, block in enumerate(self.up_res_blocks): + if isinstance(block, nn.ConvTranspose2d): + x = block(x) + residual.pop() + else: + x = block(x, c) + return x + + +class FlattenDiT(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + ModelLoader().load(encoder, "encoder.") + ModelLoader().load(decoder, "decoder.") + self.encoder = self.encoder.to(torch.bfloat16) + no_grad(self.encoder) + + def forward(self, x, t, y, s=None, classify_layer=None): + if s is None: + _, s = self.encoder(x, t, y, classify_layer=classify_layer) + if classify_layer is not None: + return None, s + x = self.decoder(x, t, y, s) + return x, s + +class FlattenDiT_jointtraining(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def forward(self, x, t, y, s=None): + if s is None: + _, s = self.encoder(x, t, y) + x = self.decoder(x, t, y, s) + return x, s + diff --git a/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py new file mode 100644 index 
0000000..537078a --- /dev/null +++ b/src/models/denoiser/bak/flatten_condit_encoder_unetdecoder_woy_fixt.py @@ -0,0 +1,464 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention +from src.utils.model_loader import ModelLoader +from src.utils.no_grad import no_grad + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + 
self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +class ResBlock(nn.Module): + def __init__(self, dim:int, groups:int=8, hidden_dim:int=256): + super().__init__() + self.conv1 = nn.Conv2d(dim, dim, 3, padding=1) + self.conv2 = nn.Conv2d(dim, dim, 3, padding=1) + self.norm1 = nn.GroupNorm(groups, dim) + self.norm2 = nn.GroupNorm(groups, dim) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = torch.nn.functional.silu(x) + x = self.conv2(x) + x = self.norm2(x) + x = torch.nn.functional.silu(x) + return residual + x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = 
self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenDiTEncoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) + self.precompute_pos[(height, width)] = (pos_rope, pos_ape) + return (pos_rope, pos_ape) + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + def forward(self, x, t, y, mask=None, classify_layer=None): + B, _, H, W = x.shape + pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + s = self.s_embedder(x) + # s = s + pos_ape.to(s.dtype) + classify_feats = [] + for i in range(self.num_blocks): + s = self.blocks[i](s, c, pos_rope, mask) + if classify_layer is not None and i < classify_layer: + classify_feats.append(s) + if i == classify_layer - 1: + return _, classify_feats + return None, s + + +class FlattenDiTDecoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, 
+ num_mid_blocks=18, + num_res_blocks=[1, 1, 1], + num_res_channels=[64, 384, 768], + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_mid_blocks = num_mid_blocks + self.num_res_blocks = num_res_blocks + self.num_res_channels = num_res_channels + self.patch_size = 2**(len(num_res_blocks)) + + self.t_embedder = TimestepEmbedder(hidden_size) + + self.down_res_blocks = nn.ModuleList() + previous_channel = self.in_channels + for num, channels in zip(num_res_blocks, num_res_channels): + self.down_res_blocks.append( + nn.Conv2d(previous_channel, channels, kernel_size=2, stride=2, padding=0), + ) + self.down_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) + previous_channel = channels + self.up_res_blocks = [] + previous_channel = self.in_channels + for num, channels in zip(num_res_blocks, num_res_channels): + self.up_res_blocks.append( + nn.ConvTranspose2d(channels, previous_channel, kernel_size=2, stride=2, padding=0) + ) + self.up_res_blocks.extend([ResBlock(channels, hidden_dim=hidden_size) for _ in range(num)]) + previous_channel = channels + self.up_res_blocks = nn.ModuleList(self.up_res_blocks[::-1]) + + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_mid_blocks) + ]) + + self.initialize_weights() + self.precompute_pos = dict() + self.weight_path = weight_path + self.load_ema = load_ema + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + # Zero-out adaLN modulation layers in SiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + for block in self.down_res_blocks: + if isinstance(block, ResBlock): + nn.init.constant_(block.conv1.weight, 0) + nn.init.constant_(block.conv1.bias, 0) + nn.init.constant_(block.norm1.weight, 0) + nn.init.constant_(block.norm2.weight, 0) + nn.init.constant_(block.conv2.weight, 0) + nn.init.constant_(block.conv2.bias, 0) + + for block in self.up_res_blocks: + if isinstance(block, ResBlock): + nn.init.constant_(block.conv1.weight, 0) + nn.init.constant_(block.conv1.bias, 0) + nn.init.constant_(block.norm1.weight, 0) + nn.init.constant_(block.norm2.weight, 0) + nn.init.constant_(block.conv2.weight, 0) + nn.init.constant_(block.conv2.bias, 0) + + def forward(self, x, t, y, s, mask=None): + B, _, H, W = x.shape + t = self.t_embedder(t.view(-1)).view(B, self.hidden_size) + s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) + + residual = [] + for i, block in enumerate(self.down_res_blocks): + if isinstance(block, nn.Conv2d): + residual.append(x) + x = block(x) + else: + x = block(x) + + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = x.view(B, self.hidden_size, -1).transpose(1, 2) + mid_c = torch.nn.functional.silu(t[:, None, :] + s) + for i in range(self.num_mid_blocks): + x = 
self.blocks[i](x, mid_c, pos, None) + x = x.transpose(1, 2).view(B, self.hidden_size, H//self.patch_size, W//self.patch_size) + + residual[0] = 0.0 + for i, block in enumerate(self.up_res_blocks): + if isinstance(block, nn.ConvTranspose2d): + x = block(x) + residual.pop() + else: + x = block(x) + return x + + +class FlattenDiT(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + ModelLoader().load(encoder, "encoder.") + ModelLoader().load(decoder, "decoder.") + self.encoder = self.encoder.to(torch.bfloat16) + no_grad(self.encoder) + + def forward(self, x, t, y, s=None, classify_layer=None): + if s is None: + _, s = self.encoder(x, t, y, classify_layer=classify_layer) + if classify_layer is not None: + return None, s + x = self.decoder(x, t, y, s) + return x, s + +class FlattenDiT_jointtraining(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def forward(self, x, t, y, s=None): + if s is None: + _, s = self.encoder(x, t, y) + x = self.decoder(x, t, y, s) + return x, s + diff --git a/src/models/denoiser/condit_dit.py b/src/models/denoiser/condit_dit.py new file mode 100644 index 0000000..48d6b0e --- /dev/null +++ b/src/models/denoiser/condit_dit.py @@ -0,0 +1,274 @@ +import torch +import torch.nn as nn +import math + +from numba.cuda.cudadrv.devicearray import lru_cache +from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + +from torch.nn.attention import SDPBackend, sdpa_kernel + +flex_attention = torch.compile(flex_attention) + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = False, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
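+        # args: (..., half); the cos/sin of these geometrically spaced frequencies form the sinusoidal timestep embedding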
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm = nn.LayerNorm(hidden_size , elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = self.norm(x) + x = modulate(x, shift, scale) + x = self.linear(x) + return x + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + self.fc1 = nn.Linear(dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, dim) + self.act = nn.GELU(approximate="tanh") + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale: float=16): + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + x_pos = x_pos.reshape(-1) + y_pos = y_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + freqs_cis = torch.cat([x_freqs.sin(), x_freqs.cos(), y_freqs.sin(), y_freqs.cos()], dim=1) + return freqs_cis + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + # import pdb; pdb.set_trace() + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q).to(q.dtype) + k = self.k_norm(k).to(k.dtype) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + # x = flex_attention(q, k, v, block_mask=mask) + # with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + x = 
scaled_dot_product_attention(q, k, v, mask) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class DiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size , elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=groups, qkv_bias=True, qk_norm=False) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class ConDiT(nn.Module): + def __init__( + self, + in_channels=4, + out_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels * patch_size ** 2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes + 1, hidden_size) + self.final_layer = FinalLayer(hidden_size, out_channels * patch_size ** 2) + self.num_cond_blocks = num_cond_blocks + + + self.weight_path = weight_path + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + DiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + + @lru_cache + def fetch_pos(self, height, width, device): + pos = precompute_freqs_cis_2d(self.hidden_size, height//self.patch_size, width//self.patch_size).to(device)[None, ...] 
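+        # (1, N, hidden_size) sin/cos table for the token grid; lru_cache memoizes one table per (height, width, device)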
+ return pos + + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y, s=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H, W, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + + if s is None: + # semantic encoder + s = self.s_embedder(x) + pos + c = nn.functional.silu(t + y) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s diff --git a/src/models/denoiser/flatten_condit_catdit_fixt.py b/src/models/denoiser/flatten_condit_catdit_fixt.py new file mode 100644 index 0000000..22a0fd5 --- /dev/null +++ b/src/models/denoiser/flatten_condit_catdit_fixt.py @@ -0,0 +1,314 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2 + hidden_size, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) 
in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None, mask=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.s_embedder(x) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos, mask) + # s = nn.functional.silu(t + s) + s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6) + x = torch.cat((x, s), dim=-1) + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, c, pos, None) + x = self.final_layer(x, c) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_condit_conv_fixt.py b/src/models/denoiser/flatten_condit_conv_fixt.py new file mode 100644 index 0000000..219db4c --- /dev/null +++ b/src/models/denoiser/flatten_condit_conv_fixt.py @@ -0,0 +1,340 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + 
super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> 
Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + +class FlattenConvBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, kernel_size=3): + super().__init__() + self.hidden_size = hidden_size + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = nn.Conv2d(hidden_size, hidden_size, groups=groups, kernel_size=kernel_size, stride=1, padding=kernel_size//2) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + attn_x = modulate(self.norm1(x), 
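+        # NOTE: the grouped-conv token mixing below assumes a fixed 16x16 token grid (256 tokens)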
shift_msa, scale_msa) + attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous() + attn_x = self.attn(attn_x) + attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2) + x = x + gate_msa * attn_x + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + kernel_size=3, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([]) + for i in range(self.num_cond_blocks): + self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups)) + for i in range(self.num_blocks-self.num_cond_blocks): + self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups, kernel_size)) + + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, 
self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.s_embedder(x) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos, None) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_condit_convnext_fixt.py b/src/models/denoiser/flatten_condit_convnext_fixt.py new file mode 100644 index 0000000..cf9c214 --- /dev/null +++ b/src/models/denoiser/flatten_condit_convnext_fixt.py @@ -0,0 +1,339 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + +class FlattenConvBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.hidden_size = hidden_size + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = nn.Conv2d(hidden_size, hidden_size, groups=hidden_size, kernel_size=7, stride=1, padding=3) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + attn_x = modulate(self.norm1(x), shift_msa, scale_msa) + attn_x = attn_x.transpose(1, 2).view(-1, self.hidden_size, 16, 16).contiguous() + attn_x = self.attn(attn_x) + attn_x = attn_x.view(-1, self.hidden_size, 256).transpose(1, 2) + x = x + gate_msa * attn_x + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + 
deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([]) + for i in range(self.num_cond_blocks): + self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups)) + for i in range(self.num_blocks-self.num_cond_blocks): + self.blocks.append(FlattenConvBlock(self.hidden_size, self.num_groups)) + + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.s_embedder(x) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos, None) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git 
a/src/models/denoiser/flatten_condit_dit_fixt.py b/src/models/denoiser/flatten_condit_dit_fixt.py new file mode 100644 index 0000000..15557f3 --- /dev/null +++ b/src/models/denoiser/flatten_condit_dit_fixt.py @@ -0,0 +1,313 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, 
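+        # SwiGLU-style MLP: w2(silu(w1(x)) * w3(x)); hidden_dim is scaled by 2/3 as in LLaMA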
hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 
6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None, mask=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.s_embedder(x) + for i in range(self.num_cond_blocks): + s = 
self.blocks[i](s, c, pos, mask) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_condit_dit_norm_fixt.py b/src/models/denoiser/flatten_condit_dit_norm_fixt.py new file mode 100644 index 0000000..28034e3 --- /dev/null +++ b/src/models/denoiser/flatten_condit_dit_norm_fixt.py @@ -0,0 +1,314 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
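+        # args pairs each flattened timestep with every frequency; concatenating cos(args) and sin(args) below gives the sinusoidal embedding, zero-padded when dim is odd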
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in 
self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None, mask=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.s_embedder(x) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos, mask) + s = torch.nn.functional.normalize(s, dim=-1, p=2, eps=1e-6) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py b/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py new file mode 100644 index 0000000..9a5e4fd --- /dev/null +++ b/src/models/denoiser/flatten_condit_encoder_decoder_fixt.py @@ -0,0 +1,429 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention +from src.utils.model_loader import ModelLoader +from src.utils.no_grad import no_grad + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class 
TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def 
apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenDiTEncoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + 
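# the encoder embeds noisy patches (s_embedder) and conditions them on timestep and class label; the state s it returns is what the decoder below consumes +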
self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos_rope = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + pos_ape = precompute_freqs_cis_2d(self.hidden_size*2, height, width).to(device) + self.precompute_pos[(height, width)] = (pos_rope, pos_ape) + return (pos_rope, pos_ape) + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + def forward(self, x, t, y, mask=None): + B, _, H, W = x.shape + pos_rope, pos_ape = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + s = self.s_embedder(x) + # s = s + pos_ape.to(s.dtype) + for i in range(self.num_blocks): + s = self.blocks[i](s, c, pos_rope, mask) + return None, s + + +class FlattenDiTDecoder(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)] + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + 
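# zero-initializing the final projection and its adaLN modulation makes the decoder output start at exactly zero, following the usual DiT-style initialization +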
nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y, s, mask=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + # s = torch.nn.functional.normalize(s, dim=-1, eps=1e-6) + s = torch.nn.functional.silu(t + s) + x = self.x_embedder(x) + for i in range(self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, + stride=self.patch_size) + return x + + + +class FlattenDiT(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + joint_training=False, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + + ModelLoader().load(encoder) + if not joint_training: + self.encoder = self.encoder.to(torch.bfloat16) + no_grad(self.encoder) + + self.joint_training = joint_training + + def forward(self, x, t, y, s=None): + if s is None: + _, s = self.encoder(x, t, y) + x = self.decoder(x, t, y, s) + return x, s + +class FlattenDiTScalingEncoder(nn.Module): + def __init__( + self, + encoder:FlattenDiTEncoder, + decoder:FlattenDiTDecoder, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + no_grad(self.decoder) + + if self.encoder.weight_path: + weight = torch.load(self.encoder.weight_path, map_location=torch.device('cpu')) + if self.encoder.load_ema: + prefix = "ema_denoiser." + else: + prefix = "denoiser." + for k, v in self.encoder.state_dict().items(): + try: + v.copy_(weight["state_dict"][prefix+k]) + except: + print(f"Failed to copy {prefix+k} to denoiser weight") + + if self.decoder.weight_path: + weight = torch.load(self.decoder.weight_path, map_location=torch.device('cpu')) + if self.decoder.load_ema: + prefix = "ema_denoiser." + else: + prefix = "denoiser." + for k, v in self.decoder.state_dict().items(): + if "blocks." 
in k: + blockid = int(k.split("blocks.")[-1][0]) + k = k.replace(f"blocks.{blockid}", f"blocks.{int(blockid)+8}") + try: + v.copy_(weight["state_dict"][prefix+k]) + except: + print(f"Failed to copy {prefix+k} to denoiser weight") + self.decoder = decoder.to(torch.bfloat16) + + def forward(self, x, t, y, s=None): + if s is None: + _, s = self.encoder(x, t, y) + x = self.decoder(x, t, y, s) + return x, s + diff --git a/src/models/denoiser/flatten_condit_mlp_fixt.py b/src/models/denoiser/flatten_condit_mlp_fixt.py new file mode 100644 index 0000000..40735e4 --- /dev/null +++ b/src/models/denoiser/flatten_condit_mlp_fixt.py @@ -0,0 +1,334 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
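+        # note: max_period is 10 here rather than the 10000 used by the original DiT timestep embedder, presumably because the flow-matching timesteps fed in lie in [0, 1]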
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + +class FlattenMLPBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = FeedForward(hidden_size, mlp_hidden_dim) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups 
= num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([]) + for i in range(self.num_cond_blocks): + self.blocks.append(FlattenDiTBlock(self.hidden_size, self.num_groups)) + for i in range(self.num_blocks-self.num_cond_blocks): + self.blocks.append(FlattenMLPBlock(self.hidden_size, self.num_groups)) + + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.s_embedder(x) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos, None) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py b/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py new file mode 100644 index 0000000..bcf3315 --- /dev/null +++ b/src/models/denoiser/flatten_condit_sdown2_dit_fixt.py @@ -0,0 +1,321 @@ +import functools +from typing import Tuple +import torch +import 
torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + 
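# 2D rotary positions: each spatial axis gets dim/4 log-spaced frequencies; the per-axis phase terms are stacked per position and flattened into (height*width, dim/2) complex entries +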
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * 
self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.s_embedder = Embed(in_channels*patch_size**4, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.s_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.s_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None, mask=None): + B, _, H, W = x.shape + pos_x = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + pos_s = self.fetch_pos(H//self.patch_size//2, W//self.patch_size//2, x.device) + s = torch.nn.functional.unfold(x, kernel_size=self.patch_size*2, stride=self.patch_size*2).transpose(1, 2) + s = self.s_embedder(s) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos_s, mask) + s = s.view(B, H//self.patch_size//2, W//self.patch_size//2, self.hidden_size) + s = torch.permute(s, (0, 3, 1, 2)) + s = torch.nn.functional.interpolate(s, scale_factor=2, 
mode='bilinear', align_corners=False) + s = torch.permute(s, (0, 2, 3, 1)) + s = s.view(B, -1, self.hidden_size) + s = nn.functional.silu(t + s) + + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos_x, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/flatten_dit_fixt.py b/src/models/denoiser/flatten_dit_fixt.py new file mode 100644 index 0000000..9412d6e --- /dev/null +++ b/src/models/denoiser/flatten_dit_fixt.py @@ -0,0 +1,306 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] 
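+        # t[..., None] * freqs[None, ...] broadcasts each flattened timestep against all half frequencies, one phase per (timestep, frequency) pair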
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + 
num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device, dtype): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device, dtype) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // 
self.num_groups, height, width).to(device, dtype) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y, masks=None): + if masks is None: + masks = [None, ]*self.num_blocks + if isinstance(masks, torch.Tensor): + masks = masks.unbind(0) + if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks: + masks = masks + [None]*(self.num_blocks-len(masks)) + + B, _, H, W = x.shape + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + x = self.x_embedder(x) + pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype) + B, L, C = x.shape + t = self.t_embedder(t.view(-1)).view(B, -1, C) + y = self.y_embedder(y).view(B, 1, C) + condition = nn.functional.silu(t + y) + for i, block in enumerate(self.blocks): + x = block(x, condition, pos, masks[i]) + x = self.final_layer(x, condition) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x \ No newline at end of file diff --git a/src/models/denoiser/flatten_dit_fixt_xvout.py b/src/models/denoiser/flatten_dit_fixt_xvout.py new file mode 100644 index 0000000..4df3393 --- /dev/null +++ b/src/models/denoiser/flatten_dit_fixt_xvout.py @@ -0,0 +1,311 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + 
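FlattenDiT.forward above patchifies with unfold and unpatchifies with fold; with non-overlapping patches this is an exact round trip, as the following sketch (not part of the patch) checks:

import torch

x = torch.randn(2, 4, 32, 32)
p = 2
tokens = torch.nn.functional.unfold(x, kernel_size=p, stride=p).transpose(1, 2)   # (B, L, C*p*p)
recon = torch.nn.functional.fold(tokens.transpose(1, 2), (32, 32), kernel_size=p, stride=p)
assert torch.allclose(recon, x)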
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return (self.weight * hidden_states).to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = 
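The sinusoidal timestep embedding above uses max_period=10 rather than the usual 10000, presumably because flow-matching timesteps lie in [0, 1] rather than [0, 1000]. A standalone sketch of the same computation:

import math
import torch

t = torch.tensor([0.0, 0.5, 1.0])
dim, max_period = 8, 10
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
args = t[..., None] * freqs[None, ...]
emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)    # (3, 8)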
torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, 2*in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device, dtype): + 
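FeedForward above is a SwiGLU MLP: the requested hidden width is rescaled by 2/3, so the three projections cost roughly the same as a plain two-layer MLP of the requested width. A quick arithmetic sketch (parameter counts, biases ignored):

dim, mlp_ratio = 1152, 4.0
requested = int(dim * mlp_ratio)                      # 4608
hidden = int(2 * requested / 3)                       # 3072
swiglu_params = 2 * dim * hidden + hidden * dim       # w1 + w3 + w2  -> 10,616,832
plain_mlp_params = 2 * dim * requested                # up + down     -> 10,616,832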
if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device, dtype) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device, dtype) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, y, masks=None): + if masks is None: + masks = [None, ]*self.num_blocks + if isinstance(masks, torch.Tensor): + masks = masks.unbind(0) + if isinstance(masks, (tuple, list)) and len(masks) < self.num_blocks: + masks = masks + [None]*(self.num_blocks-len(masks)) + + B, _, H, W = x.shape + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + x = self.x_embedder(x) + pos = self.fetch_pos(H // self.patch_size, W // self.patch_size, x.device, x.dtype) + B, L, C = x.shape + t = self.t_embedder(t.view(-1)).view(B, -1, C) + y = self.y_embedder(y).view(B, 1, C) + condition = nn.functional.silu(t + y) + for i, block in enumerate(self.blocks): + x = block(x, condition, pos, masks[i]) + x = self.final_layer(x, condition) + x0, v = x.chunk(2, dim=-1) + x0 = torch.nn.functional.fold(x0.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + v = torch.nn.functional.fold(v.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + if self.training: + return v, x0 + else: + return v \ No newline at end of file diff --git a/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py b/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py new file mode 100644 index 0000000..4e570b0 --- /dev/null +++ b/src/models/denoiser/flatten_sharepatch_condit_dit_fixt.py @@ -0,0 +1,308 @@ +import functools +from typing import Tuple +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from torch.nn.modules.module import T + +# from torch.nn.attention.flex_attention import flex_attention, create_block_mask +from torch.nn.functional import scaled_dot_product_attention + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class Embed(nn.Module): + def __init__( + self, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Linear(in_chans, embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = self.proj(x) + x = self.norm(x) + return x + +class TimestepEmbedder(nn.Module): + + def __init__(self, 
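The xvout FlattenDiT above predicts two fields from one final projection: the output channels are chunked into an x0 head and a velocity head, and only the velocity is returned at inference. A hypothetical call (sizes illustrative, not the config used in the patch):

import torch

model = FlattenDiT(in_channels=4, num_groups=6, hidden_size=384, num_blocks=2, patch_size=2)
x = torch.randn(2, 4, 32, 32)
t = torch.rand(2)
y = torch.randint(0, 1000, (2,))

model.train()
v, x0 = model(x, t, y)        # both (2, 4, 32, 32)
model.eval()
v = model(x, t, y)            # velocity only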
hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class LabelEmbedder(nn.Module): + def __init__(self, num_classes, hidden_size): + super().__init__() + self.embedding_table = nn.Embedding(num_classes, hidden_size) + self.num_classes = num_classes + + def forward(self, labels,): + embeddings = self.embedding_table(labels) + return embeddings + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 2*hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + return x + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: 
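LabelEmbedder is instantiated elsewhere in this patch with num_classes+1 entries; the extra index presumably serves as the null/unconditional label (e.g. for classifier-free-guidance style label dropout). A sketch of that convention (illustrative, not part of the patch):

import torch

num_classes = 1000
embedder = LabelEmbedder(num_classes + 1, hidden_size=384)
labels = torch.randint(0, num_classes, (8,))
drop = torch.rand(8) < 0.1                                   # hypothetical 10% label dropout
labels = torch.where(drop, torch.full_like(labels, num_classes), labels)
y = embedder(labels)                                         # (8, 384)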
torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + freqs_cis = freqs_cis[None, :, None, :] + # xq : B N H Hc + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # B N H Hc/2 + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # B, N, H, Hc + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class RAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = RMSNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, pos, mask) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 1, 3, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # B N H Hc + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rotary_emb(q, k, freqs_cis=pos) + q = q.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2) # B, H, N, Hc + k = k.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() # B, H, N, Hc + v = v.view(B, -1, self.num_heads, C // self.num_heads).transpose(1, 2).contiguous() + + x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class FlattenDiTBlock(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, ): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False) + self.norm2 = RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos, mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1) + x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask) + x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FlattenConDiT(nn.Module): + def __init__( + self, + in_channels=4, + num_groups=12, + hidden_size=1152, + num_blocks=18, + num_cond_blocks=4, + patch_size=2, + num_classes=1000, + learn_sigma=True, + deep_supervision=0, + weight_path=None, + load_ema=False, + ): + super().__init__() + self.deep_supervision = deep_supervision + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.num_blocks = num_blocks + self.num_cond_blocks = num_cond_blocks + self.patch_size = patch_size + self.x_embedder = Embed(in_channels*patch_size**2, hidden_size, bias=True) + self.t_embedder = 
TimestepEmbedder(hidden_size) + self.y_embedder = LabelEmbedder(num_classes+1, hidden_size) + + self.final_layer = FinalLayer(hidden_size, in_channels*patch_size**2) + + self.weight_path = weight_path + + self.load_ema = load_ema + self.blocks = nn.ModuleList([ + FlattenDiTBlock(self.hidden_size, self.num_groups) for _ in range(self.num_blocks) + ]) + self.initialize_weights() + self.precompute_pos = dict() + + def fetch_pos(self, height, width, device): + if (height, width) in self.precompute_pos: + return self.precompute_pos[(height, width)].to(device) + else: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width).to(device) + self.precompute_pos[(height, width)] = pos + return pos + + def initialize_weights(self): + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + + # Initialize label embedding table: + nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # # Zero-out adaLN modulation layers in SiT blocks: + # for block in self.blocks: + # nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + # nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + + def forward(self, x, t, y, s=None, mask=None): + B, _, H, W = x.shape + pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device) + x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size) + y = self.y_embedder(y).view(B, 1, self.hidden_size) + c = nn.functional.silu(t + y) + if s is None: + s = self.x_embedder(x) + for i in range(self.num_cond_blocks): + s = self.blocks[i](s, c, pos, mask) + s = nn.functional.silu(t + s) + + x = self.x_embedder(x) + for i in range(self.num_cond_blocks, self.num_blocks): + x = self.blocks[i](x, s, pos, None) + x = self.final_layer(x, s) + x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return x, s \ No newline at end of file diff --git a/src/models/denoiser/flowdcn.py b/src/models/denoiser/flowdcn.py new file mode 100644 index 0000000..92e2237 --- /dev/null +++ b/src/models/denoiser/flowdcn.py @@ -0,0 +1,160 @@ +import torch +import torch.nn as nn +import math + +from torch.nn.init import zeros_ +from src.models.denoiser.base_model import BaseModel +from src.ops.triton_kernels.function import DCNFunction + +def modulate(x, shift, scale): + return x * (1 + scale[:, None, None]) + shift[:, None, None] + +class PatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer = None, + bias: bool = True, + ): + super().__init__() + self.patch_size = patch_size + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + def forward(self, x): + x = 
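FlattenConDiT above splits the trunk: the first num_cond_blocks build a condition state s from the noisy input, the remaining blocks are modulated by s, and both the denoised output and s are returned. Since forward also accepts s, a caller can presumably compute s once and feed it back on later calls (sketch only; the actual samplers live under src/diffusion/stateful_flow_matching):

import torch

model = FlattenConDiT(in_channels=4, num_groups=6, hidden_size=384, num_blocks=6, num_cond_blocks=2)
x = torch.randn(2, 4, 32, 32)
y = torch.randint(0, 1000, (2,))

out, s = model(x, torch.full((2,), 0.9), y)         # first call: build the condition state
out2, _ = model(x, torch.full((2,), 0.8), y, s=s)   # later call: reuse the cached state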
self.proj(x) + x = self.norm(x) + return x + +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + def forward(self, x): + b, h, w, c = x.shape + x = x.view(b, h*w, c) + x = self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x)) + x = x.view(b, h, w, c) + return x + + +class MultiScaleDCN(nn.Module): + def __init__(self, in_channels, groups, channels, kernels, deformable_biass=True): + super().__init__() + self.in_channels = in_channels + self.groups = groups + self.channels = channels + self.kernels = kernels + self.v = nn.Linear(in_channels, groups * channels, bias=True) + self.qk_deformables = nn.Linear(in_channels, groups * kernels * 2, bias=True) + self.qk_scales = nn.Linear(in_channels, groups * kernels, bias=False) + self.qk_weights = nn.Linear(in_channels, groups*kernels, bias=True) + self.out = nn.Linear(groups * channels, in_channels) + self.deformables_prior = nn.Parameter(torch.randn((1, 1, 1, 1, kernels, 2)), requires_grad=False) + self.deformables_scale = nn.Parameter(torch.ones((1, 1, 1, groups, 1, 1)), requires_grad=True) + self.max_scale = 6 + self._init_weights() + def _init_weights(self): + zeros_(self.qk_deformables.weight.data) + zeros_(self.qk_scales.weight.data) + zeros_(self.qk_deformables.bias.data) + zeros_(self.qk_weights.weight.data) + zeros_(self.v.bias.data) + zeros_(self.out.bias.data) + num_prior = int(self.kernels ** 0.5) + dx = torch.linspace(-1, 1, num_prior, device="cuda") + dy = torch.linspace(-1, 1, num_prior, device="cuda") + dxy = torch.meshgrid([dx, dy], indexing="xy") + dxy = torch.stack(dxy, dim=-1) + dxy = dxy.view(-1, 2) + self.deformables_prior.data[..., :num_prior*num_prior, :] = dxy + for i in range(self.groups): + scale = (i+1)/self.groups - 0.0001 + inv_scale = math.log((scale)/(1-scale)) + self.deformables_scale.data[..., i, :, :] = inv_scale + def forward(self, x): + B, H, W, _ = x.shape + v = self.v(x).view(B, H, W, self.groups, self.channels) + deformables = self.qk_deformables(x).view(B, H, W, self.groups, self.kernels, 2) + scale = self.qk_scales(x).view(B, H, W, self.groups, self.kernels, 1) + self.deformables_scale + deformables = (deformables + self.deformables_prior ) * scale.sigmoid()*self.max_scale + weights = self.qk_weights(x).view(B, H, W, self.groups, self.kernels) + out = DCNFunction.apply(v, deformables, weights) + out = out.view(B, H, W, -1) + out = self.out(out) + return out + +class FlowDCNBlock(nn.Module): + def __init__(self, hidden_size, groups, kernels=9, mlp_ratio=4.0, deformable_biass=True): + super().__init__() + self.norm1 = RMSNorm(hidden_size, eps=1e-6) + self.attn = MultiScaleDCN(hidden_size, groups=groups, channels=hidden_size//groups, kernels=kernels, deformable_biass=deformable_biass) + self.norm2 = 
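MultiScaleDCN above delegates the actual sampling to DCNFunction (the Triton/CUDA kernels elsewhere in this patch). As a readability aid, here is a naive pure-PyTorch reference of the same bilinear gather-and-weight operation; it is a sketch for illustration, not the code path used in the patch, and boundary/rounding conventions only approximate the kernels':

import torch

def dcn_reference(v, deformables, weights):
    # v: (B, H, W, G, C)  deformables: (B, H, W, G, K, 2)  weights: (B, H, W, G, K)
    B, H, W, G, C = v.shape
    base_x = torch.arange(W, device=v.device).view(1, 1, W, 1, 1)
    base_y = torch.arange(H, device=v.device).view(1, H, 1, 1, 1)
    x = deformables[..., 0] + base_x                  # absolute sample columns, (B, H, W, G, K)
    y = deformables[..., 1] + base_y                  # absolute sample rows
    x0, y0 = x.floor().long(), y.floor().long()
    out = torch.zeros(B, H, W, G, C, dtype=v.dtype, device=v.device)
    b_idx = torch.arange(B, device=v.device).view(B, 1, 1, 1, 1)
    g_idx = torch.arange(G, device=v.device).view(1, 1, 1, G, 1)
    for dy in (0, 1):
        for dx in (0, 1):
            xi, yi = x0 + dx, y0 + dy
            wgt = (1 - (x - xi).abs()) * (1 - (y - yi).abs())       # bilinear corner weight
            valid = (xi >= 0) & (xi < W) & (yi >= 0) & (yi < H)     # zero outside the grid
            xi_c, yi_c = xi.clamp(0, W - 1), yi.clamp(0, H - 1)
            gathered = v[b_idx, yi_c, xi_c, g_idx]                   # (B, H, W, G, K, C)
            contrib = (weights * wgt * valid).unsqueeze(-1) * gathered
            out = out + contrib.sum(dim=4)                           # sum over the K kernel points
    return out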
RMSNorm(hidden_size, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp = FeedForward(hidden_size, mlp_hidden_dim) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) + x = x + gate_msa[:, None, None] * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp[:, None, None] * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + + + +class FlowDCN(BaseModel): + def __init__(self, deformable_biass=True, *args, **kwargs): + super().__init__(*args, **kwargs) + self.blocks = nn.ModuleList([ + FlowDCNBlock(self.hidden_size, self.num_groups, kernels=9, deformable_biass=deformable_biass) for _ in range(self.num_blocks) + ]) + self.x_embedder = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, bias=True) + self.initialize_weights() + + def forward(self, x, t, y): + batch_size, _, height, width = x.shape[0] + x = self.x_embedder(x) # (N, D, h, w) + x = x.permute(0, 2, 3, 1).reshape(batch_size, height*width//self.patch_size**2, -1) + t = self.t_embedder(t) # (N, D) + y = self.y_embedder(y, self.training) # (N, D) + c = t + y # (N, D) + B, L, C = x.shape + x = x.view(B, height//self.patch_size, width//self.patch_size, C) + for block in self.blocks: + x = block(x, c) # (N, T, D) + x = x.view(B, L, C) + x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) + x = torch.nn.functional.fold(x.transpose(1, 2), (height, width), kernel_size=self.patch_size, stride=self.patch_size) + if self.learn_sigma: + x, _ = torch.split(x, self.out_channels // 2, dim=1) + return x \ No newline at end of file diff --git a/src/models/encoder.py b/src/models/encoder.py new file mode 100644 index 0000000..8b7f96a --- /dev/null +++ b/src/models/encoder.py @@ -0,0 +1,132 @@ +import torch +import copy +import os +import timm +import transformers +import torch.nn as nn +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from torchvision.transforms import Normalize + +class RandViT(nn.Module): + def __init__(self, model_id, weight_path:str=None): + super(RandViT, self).__init__() + self.encoder = timm.create_model( + model_id, + num_classes=0, + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.shifts = nn.Parameter(torch.tensor([0.0 + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([1.0 + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + +class MAE(nn.Module): + def __init__(self, model_id, weight_path:str): + super(MAE, self).__init__() + if os.path.isdir(weight_path): + weight_path = os.path.join(weight_path, "pytorch_model.bin") + self.encoder = timm.create_model( + model_id, + checkpoint_path=weight_path, + num_classes=0, + ) + 
self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.shifts = nn.Parameter(torch.tensor([0.0 + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([1.0 + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + +class DINO(nn.Module): + def __init__(self, model_id, weight_path:str): + super(DINO, self).__init__() + if os.path.isdir(weight_path): + weight_path = os.path.join(weight_path, "pytorch_model.bin") + self.encoder = timm.create_model( + model_id, + checkpoint_path=weight_path, + num_classes=0, + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.shifts = nn.Parameter(torch.tensor([ 0.0, + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([ 1.0, + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + +class CLIP(nn.Module): + def __init__(self, model_id, weight_path:str): + super(CLIP, self).__init__() + self.encoder = transformers.CLIPVisionModel.from_pretrained(weight_path) + self.patch_size = self.encoder.vision_model.embeddings.patch_embedding.kernel_size + self.shifts = nn.Parameter(torch.tensor([0.0, + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([1.0, + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder(x)['last_hidden_state'][:, 1:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + + + +class DINOv2(nn.Module): + def __init__(self, model_id, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = transformers.Dinov2Model.from_pretrained(weight_path) + self.patch_size = self.encoder.embeddings.patch_embeddings.projection.kernel_size + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward(x)['last_hidden_state'][:, 1:] + feature = feature.transpose(1, 
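Each encoder wrapper above normalizes, resizes to 224, and returns a (B, C, h/patch, w/patch) spatial feature map. A hypothetical usage sketch; model_id and weight_path are placeholders, not values taken from this patch:

import torch

encoder = DINO(model_id="vit_base_patch16_224.dino", weight_path="/path/to/dino_weights")
encoder.eval()
images = torch.rand(2, 3, 256, 256)        # expected in [0, 1] before ImageNet normalization
with torch.no_grad():
    feats = encoder(images)                # (2, C, 14, 14) for a 16x16-patch ViT at 224x224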
2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + return feature \ No newline at end of file diff --git a/src/models/vae.py b/src/models/vae.py new file mode 100644 index 0000000..c47b087 --- /dev/null +++ b/src/models/vae.py @@ -0,0 +1,81 @@ +import torch +import subprocess +import lightning.pytorch as pl + +import logging + + +logger = logging.getLogger(__name__) +def class_fn_from_str(class_str): + class_module, from_class = class_str.rsplit(".", 1) + class_module = __import__(class_module, fromlist=[from_class]) + return getattr(class_module, from_class) + + +class BaseVAE(torch.nn.Module): + def __init__(self, scale=1.0, shift=0.0): + super().__init__() + self.model = torch.nn.Identity() + self.scale = scale + self.shift = shift + + def encode(self, x): + return x/self.scale+self.shift + + def decode(self, x): + return (x-self.shift)*self.scale + + +# very bad bugs with nearest sampling +class DownSampleVAE(BaseVAE): + def __init__(self, down_ratio, scale=1.0, shift=0.0): + super().__init__() + self.model = torch.nn.Identity() + self.scale = scale + self.shift = shift + self.down_ratio = down_ratio + + def encode(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=1/self.down_ratio, mode='bicubic', align_corners=False) + return x/self.scale+self.shift + + def decode(self, x): + x = (x-self.shift)*self.scale + x = torch.nn.functional.interpolate(x, scale_factor=self.down_ratio, mode='bicubic', align_corners=False) + return x + + + +class LatentVAE(BaseVAE): + def __init__(self, precompute=False, weight_path:str=None): + super().__init__() + self.precompute = precompute + self.model = None + self.weight_path = weight_path + + from diffusers.models import AutoencoderKL + setattr(self, "model", AutoencoderKL.from_pretrained(self.weight_path)) + self.scaling_factor = self.model.config.scaling_factor + + @torch.no_grad() + def encode(self, x): + assert self.model is not None + if self.precompute: + return x.mul_(self.scaling_factor) + return self.model.encode(x).latent_dist.sample().mul_(self.scaling_factor) + + @torch.no_grad() + def decode(self, x): + assert self.model is not None + return self.model.decode(x.div_(self.scaling_factor)).sample + + +def uint82fp(x): + x = x.to(torch.float32) + x = (x - 127.5) / 127.5 + return x + +def fp2uint8(x): + x = torch.clip_((x + 1) * 127.5 + 0.5, 0, 255).to(torch.uint8) + return x + diff --git a/src/ops/cuda_kernels/backward.cu b/src/ops/cuda_kernels/backward.cu new file mode 100644 index 0000000..2e85d86 --- /dev/null +++ b/src/ops/cuda_kernels/backward.cu @@ -0,0 +1,346 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace cg = cooperative_groups; + +template +__device__ __always_inline int toInt(scalar_t val); + +template<> +__device__ __always_inline int toInt(float val){ + return static_cast(val); +} +template<> +__device__ __always_inline int toInt(half val){ + return __half2int_rz(val); +} + +template +__device__ __always_inline scalar_t fromInt(int val); + +template<> +__device__ __always_inline float fromInt(int val){ + return static_cast(val); +} + +template<> +__device__ __always_inline half fromInt(int val){ + return __int2half_rz(val); +} + +template +__device__ __always_inline scalar_t constVal(float val); + +template<> +__device__ __always_inline float constVal(float val) { + return (float)val; +} + +template<> +__device__ __always_inline half constVal(float val) { + return 
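A small end-to-end sketch of the VAE helpers above: images arrive as uint8, are mapped to [-1, 1], encoded to scaled latents, and decoded back. The weight_path is a placeholder; LatentVAE wraps diffusers' AutoencoderKL:

import torch

vae = LatentVAE(precompute=False, weight_path="/path/to/sd-vae-ft-ema")   # hypothetical local path
pixels_uint8 = torch.randint(0, 256, (1, 3, 256, 256), dtype=torch.uint8)
pixels = uint82fp(pixels_uint8)        # [-1, 1] float
latents = vae.encode(pixels)           # (1, 4, 32, 32), already multiplied by scaling_factor
recon = vae.decode(latents)            # back to roughly [-1, 1]
recon_uint8 = fp2uint8(recon)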
__float2half(val); // Using float to half conversion +} +template<> +__device__ __always_inline nv_bfloat16 constVal(float val){ + return __float2bfloat16(val); +} + + + + + +// B, H, W, C, BLOCK_DIM must be multiple of C +template +__global__ void dcn_backward_pipeline_kernel( + const int H, + const int W, + const int G, + const int K, + const int C, + scalar_t* ptr_values, + scalar_t* ptr_deformables, + scalar_t* ptr_weights, + scalar_t* ptr_grad_out, + scalar_t* ptr_grad_values, + scalar_t* ptr_grad_deformables, + scalar_t* ptr_grad_weights +) { + auto block = cg::this_thread_block(); + auto self_thread = cg::this_thread(); + auto tile_threads = cg::tiled_partition(block); + int local_thread_id = block.thread_rank(); + int local_tile_id = tile_threads.meta_group_rank(); + int num_local_tiles = tile_threads.meta_group_size(); + int global_tile_id = block.group_index().x*num_local_tiles + local_tile_id; + + extern __shared__ int shm[]; + auto GradBuffer = reinterpret_cast(shm); + scalar_t* Buffer = reinterpret_cast(shm) + num_local_tiles*C; + if(global_tile_id >= H*W*G) return; + + int bid = block.group_index().y; + int gid = global_tile_id % G; + int wid = global_tile_id / G % W; + int hid = global_tile_id / G / W; + int globale_offset = bid*H*W*G*C + global_tile_id*C; + cg::memcpy_async(tile_threads, GradBuffer+local_tile_id*C, ptr_grad_out+globale_offset, sizeof(scalar_t)*C); + + int shared_offset[pipeline_stages]; + for (int s = 0; s < pipeline_stages; ++s) { + shared_offset[s] = (s+pipeline_stages*local_thread_id)*(TILE_C*4); + } + + auto pipeline = cuda::make_pipeline(); + const int num_tiles_per_thread = C/TILE_C/TILE_THREADS; + + for(int k=0; k(wid); + y = ptr_deformables[offset*2 + 1] + fromInt(hid); +// x = fromInt(wid); +// y = fromInt(hid); + weight = ptr_weights[offset]; + } + tile_threads.sync(); + x = tile_threads.shfl(x, 0); + y = tile_threads.shfl(y, 0); + weight = tile_threads.shfl(weight, 0); + + int floor_x = toInt(x); + int floor_y = toInt(y); + int ceil_x = floor_x + 1; + int ceil_y = floor_y + 1; + + + scalar_t dodx = constVal(0.0f); + scalar_t dody = constVal(0.0f); + scalar_t dodw = constVal(0.0f); + + int start_c = tile_threads.thread_rank() * (C / TILE_THREADS); + + bool tl_flag = (floor_x >=0) and (floor_x =0) and (floor_y=0) and (ceil_x =0) and (floor_y=0) and (floor_x =0) and (ceil_y=0) and (ceil_x =0) and (ceil_y(ceil_x) - x) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j]; + dodx = dodx + -weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j]; + dody = dody + -weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+j] * GradBuffer[gbuffer_offset + j]; + dodw = dodw + (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1]; + dodx = dodx + -weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+j+ 1] * GradBuffer[gbuffer_offset + j + 1]; + dody = dody + -weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+j + 1] * GradBuffer[gbuffer_offset + j + 1]; + { + vec2_t vtl_di; + vtl_di.x = weight* (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j]; + vtl_di.y = weight* (fromInt(ceil_x) - x) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j + 1]; + atomicAdd((vec2_t*)(ptr_grad_values + tl_global_base + compute_n * TILE_C + j), vtl_di); + } + } + + + if(tr_flag){ + // tr + dodw = dodw + (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j]; + dodx = 
dodx + weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j]; + dody = dody + -weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C+j] * GradBuffer[gbuffer_offset + j]; + dodw = dodw + (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j+1] * GradBuffer[gbuffer_offset + j+1]; + dodx = dodx + weight*(fromInt(ceil_y) - y) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+ 1]; + dody = dody + -weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C+j + 1] * GradBuffer[gbuffer_offset + j+1]; + { + vec2_t vtr_di; + vtr_di.x = weight* (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j]; + vtr_di.y = weight* (x - fromInt(floor_x)) * (fromInt(ceil_y) - y) * GradBuffer[gbuffer_offset + j+1]; + atomicAdd((vec2_t*)(ptr_grad_values + tr_global_base + compute_n * TILE_C + j), vtr_di); + } + } + + if(bl_flag){ + // bl + dodw = dodw + (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j]; + dodx = dodx + -weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j]; + dody = dody + weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j] * GradBuffer[gbuffer_offset + j]; + dodw = dodw + (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1]; + dodx = dodx + -weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1]; + dody = dody + weight*(fromInt(ceil_x) - x) * Buffer[ibuffer_offset+TILE_C*2+j+1] * GradBuffer[gbuffer_offset + j+1]; + { + vec2_t vbl_di; + vbl_di.x = weight* (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j]; + vbl_di.y = weight* (fromInt(ceil_x) - x) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j+1]; + atomicAdd((vec2_t*)(ptr_grad_values + bl_global_base + compute_n * TILE_C + j), vbl_di); + } + } + + + if(br_flag){ + // tr + dodw = dodw + (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j]; + dodx = dodx + weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j]; + dody = dody + weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j] * GradBuffer[gbuffer_offset + j]; + dodw = dodw + (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1]; + dodx = dodx + weight*(y - fromInt(floor_y)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1]; + dody = dody + weight*(x - fromInt(floor_x)) * Buffer[ibuffer_offset+TILE_C*3+j+1] * GradBuffer[gbuffer_offset + j+1]; + { + vec2_t vbr_di; + vbr_di.x = weight* (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j]; + vbr_di.y = weight* (x - fromInt(floor_x)) * (y - fromInt(floor_y)) * GradBuffer[gbuffer_offset + j+1]; + atomicAdd((vec2_t*)(ptr_grad_values + br_global_base + compute_n * TILE_C + j), vbr_di); + } + } + } + pipeline.consumer_release(); + } + for (int i = TILE_THREADS>>1; i > 0; i/=2) { + dodx = dodx + tile_threads.shfl_down(dodx, i); + dody = dody + tile_threads.shfl_down(dody, i); + dodw = dodw + tile_threads.shfl_down(dodw, i); + } + if (tile_threads.thread_rank() == 0) { + cuda::memcpy_async(ptr_grad_deformables + offset * 2, &dodx, sizeof(scalar_t), pipeline); + cuda::memcpy_async(ptr_grad_deformables + offset * 2 + 1, &dody, 
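The dodx/dody/dodw accumulations above are the analytic derivatives of bilinear interpolation; stated compactly (as comments, in the kernel's notation):

// For a sample at (x, y) with corners x1 = floor(x), x2 = x1 + 1, y1 = floor(y), y2 = y1 + 1,
// one kernel point contributes, per channel,
//   o = w * [ (x2-x)(y2-y) V_tl + (x-x1)(y2-y) V_tr + (x2-x)(y-y1) V_bl + (x-x1)(y-y1) V_br ]
// so
//   do/dw =       (x2-x)(y2-y) V_tl + (x-x1)(y2-y) V_tr + (x2-x)(y-y1) V_bl + (x-x1)(y-y1) V_br
//   do/dx = w * [ -(y2-y) V_tl + (y2-y) V_tr - (y-y1) V_bl + (y-y1) V_br ]
//   do/dy = w * [ -(x2-x) V_tl - (x-x1) V_tr + (x2-x) V_bl + (x-x1) V_br ]
//   do/dV_corner = w * (that corner's bilinear weight)
// Each term is multiplied by the incoming grad_out and summed over channels, which is exactly what
// the tl/tr/bl/br branches accumulate into dodw, dodx, dody and atomically into grad_values.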
sizeof(scalar_t), pipeline); + cuda::memcpy_async(ptr_grad_weights + offset, &dodw, sizeof(scalar_t), pipeline); + } + } +} + + +using namespace torch; +template +void backward(const int B, + const int H, + const int W, + const int G, + const int K, + const int C, + torch::Tensor values, + torch::Tensor deformables, + torch::Tensor weights, + torch::Tensor grad_out, + torch::Tensor grad_values, + torch::Tensor grad_deformables, + torch::Tensor grad_weights +) { + int num_local_tiles =(THREADS/TILE_THREADS); + int num_global_tiles = (H*W*G+num_local_tiles-1)/num_local_tiles; + dim3 launch_threads_per_block(THREADS); + dim3 launch_blocks(num_global_tiles, B); + + int deformable_shm_size = 0; + int grad_out_shm_size = num_local_tiles*C; + int pipeline_shm_size = (pipeline_stages*TILE_C*4*THREADS); + + int shm_size = deformable_shm_size+grad_out_shm_size+pipeline_shm_size; +// printf("shm_size: %d\n", shm_size/512); +// printf("pipeline_size: %d\n", pipeline_shm_size/512); +// printf("grad_out_size: %d\n", grad_out_shm_size/512); + + + switch (values.type().scalarType()) { + case at::ScalarType::Half: + return dcn_backward_pipeline_kernel<<>>( + H, W, G, K, C, + reinterpret_cast(values.data_ptr()), + reinterpret_cast(deformables.data_ptr()), + reinterpret_cast(weights.data_ptr()), + reinterpret_cast(grad_out.data_ptr()), + reinterpret_cast(grad_values.data_ptr()), + reinterpret_cast(grad_deformables.data_ptr()), + reinterpret_cast(grad_weights.data_ptr()) + ); +// case at::ScalarType::BFloat16: +// return dcn_backward_pipeline_kernel<<>>( +// H, W, G, K, C, +// reinterpret_cast(values.data_ptr()), +// reinterpret_cast(deformables.data_ptr()), +// reinterpret_cast(weights.data_ptr()), +// reinterpret_cast(grad_out.data_ptr()), +// reinterpret_cast(grad_values.data_ptr()), +// reinterpret_cast(grad_deformables.data_ptr()), +// reinterpret_cast(grad_weights.data_ptr()) +// ); + default: + printf("running error"); + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("backward_p1_c2_tile16_thread128", &backward<1, 2, 16, 128>, ""); + m.def("backward_p2_c2_tile16_thread128", &backward<2, 2, 16, 128>, ""); + m.def("backward_p1_c4_tile16_thread128", &backward<1, 4, 16, 128>, ""); + m.def("backward_p1_c2_tile16_thread256", &backward<1, 2, 16, 256>, ""); + m.def("backward_p2_c2_tile16_thread256", &backward<2, 2, 16, 256>, ""); + m.def("backward_p1_c4_tile16_thread256", &backward<1, 4, 16, 256>, ""); + m.def("backward_p1_c2_tile16_thread384", &backward<1, 2, 16, 384>, ""); + m.def("backward_p2_c2_tile16_thread384", &backward<2, 2, 16, 384>, ""); + m.def("backward_p1_c4_tile16_thread384", &backward<1, 4, 16, 384>, ""); + m.def("backward_p1_c2_tile16_thread512", &backward<1, 2, 16, 512>, ""); + m.def("backward_p2_c2_tile16_thread512", &backward<2, 2, 16, 512>, ""); + m.def("backward_p1_c4_tile16_thread512", &backward<1, 4, 16, 512>, ""); + m.def("backward_p1_c2_tile16_thread768", &backward<1, 2, 16, 768>, ""); + m.def("backward_p2_c2_tile16_thread768", &backward<2, 2, 16, 768>, ""); + m.def("backward_p1_c4_tile16_thread768", &backward<1, 4, 16, 768>, ""); +// m.def("backward_p1_c2_tile16_thread1024", &backward<1, 2, 16, 1024>, ""); +// m.def("backward_p2_c2_tile16_thread1024", &backward<2, 2, 16, 1024>, ""); +// m.def("backward_p1_c4_tile16_thread1024", &backward<1, 4, 16, 1024>, ""); + + m.def("backward_p1_c2_tile32_thread128", &backward<1, 2, 32, 128>, ""); + m.def("backward_p1_c2_tile32_thread256", &backward<1, 2, 32, 256>, ""); + m.def("backward_p1_c2_tile32_thread384", &backward<1, 2, 32, 
384>, ""); + m.def("backward_p1_c2_tile32_thread512", &backward<1, 2, 32, 512>, ""); +} diff --git a/src/ops/cuda_kernels/bak_forward.cu b/src/ops/cuda_kernels/bak_forward.cu new file mode 100644 index 0000000..00569f8 --- /dev/null +++ b/src/ops/cuda_kernels/bak_forward.cu @@ -0,0 +1,289 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +template +__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){ + #pragma unroll + for(int i=0; i +__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){ + #pragma unroll + for(int i=0; i +__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){ +#pragma unroll + for(int i=0; i +__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){ +#pragma unroll + for(int i=0; i +__global__ void dcn_forward_kernel(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){ + int work_id = threadIdx.x; + int bid = blockIdx.z; + int gid = blockIdx.y; + int G = gridDim.y; + int c_blockid = blockIdx.x; + int work_load = (H*W/blockDim.x); + + __shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM + // __shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM + math_t register_bufferA[BLOCK_DIM] = {0}; + int base_c = c_blockid*BLOCK_DIM; + + int num_transfers = BLOCK_DIM; +#pragma unroll + for(int i=0; i(register_bufferA, 1, BLOCK_DIM); + // loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM); +#pragma unroll + for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM); + } + // load top right + math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight; + if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM); + } + + // load bottom left + math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight; + if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM); + } + // load bottom right + math_t br_weight = (x - floor_x)*(y - floor_y)*weight; + if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM); + } + + } + // loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM); + int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c); +#pragma unroll + for(int j=0; j +__global__ void dcn_forward_kernel_16(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){ + int work_id = threadIdx.x; + int bid = blockIdx.z; + int gid = blockIdx.y; + int G = gridDim.y; + int c_blockid = blockIdx.x; + int work_load = (H*W/blockDim.x); + + __shared__ math_t math_buffer[L][BLOCK_DIM]; //[BLOCK_DIM*H*W]; // H, W, BLOCK_DIM + __shared__ scalar_t io_buffer[L][BLOCK_DIM]; // H, W, BLOCK_DIM + math_t register_bufferA[BLOCK_DIM] = {0}; + int base_c = c_blockid*BLOCK_DIM; + + int num_transfers = BLOCK_DIM/transfer_length; +#pragma unroll + for(int i=0; i(register_bufferA, 1, BLOCK_DIM); + loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, 
BLOCK_DIM); +#pragma unroll + for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM); + } + // load top right + math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight; + if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM); + } + + // load bottom left + math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight; + if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM); + } + // load bottom right + math_t br_weight = (x - floor_x)*(y - floor_y)*weight; + if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM); + } + + } + loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM); + + } + + __syncthreads(); + +#pragma unroll + for(int i=0; i +void dcn_forward(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) { + + int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM; + dim3 launch_threads_per_block(THREADS); + dim3 launch_blocks(NUM_C_BLOCK, G, B); + + switch (value.type().scalarType()) { + case at::ScalarType::Half: + return dcn_forward_kernel_16<<>>( + H, W, C, + value.data_ptr(), + deformables.data_ptr(), + weights.data_ptr(), + out.data_ptr()); + case at::ScalarType::BFloat16: + return dcn_forward_kernel_16<<>>( + H, W, C, + value.data_ptr(), + deformables.data_ptr(), + weights.data_ptr(), + out.data_ptr()); + case at::ScalarType::Float: + return dcn_forward_kernel<<>>( + H, W, C, + value.data_ptr(), + deformables.data_ptr(), + weights.data_ptr(), + out.data_ptr()); + default: + printf("running error"); + } +} + + +// PyBind11 bindings +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward"); +//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward"); +m.def("dcn_forward_l256_c4", &dcn_forward<256, 4, 256>, "CUDA dcn forward"); +m.def("dcn_forward_l256_c8", &dcn_forward<256, 8, 256>, "CUDA dcn forward"); +m.def("dcn_forward_l256_c16", &dcn_forward<256, 16, 256>, "CUDA dcn forward"); +// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward"); +m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward"); +m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward"); +m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward"); +// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward"); +} diff --git a/src/ops/cuda_kernels/forward.cu b/src/ops/cuda_kernels/forward.cu new file mode 100644 index 0000000..ac18308 --- /dev/null +++ b/src/ops/cuda_kernels/forward.cu @@ -0,0 +1,309 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +template +__device__ __always_inline void loop_mul_add(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){ + #pragma unroll + for(int i=0; i +__device__ __always_inline void loop_mul_load(TA* ptr_a, TB* ptr_b, TB weight, int stride_a, int stride_b, int n){ + #pragma unroll + for(int i=0; i +__device__ __always_inline void loop_load(TA* ptr_a, TB* ptr_b, int stride_a, int stride_b, int n){ +#pragma unroll + 
for(int i=0; i +__device__ __always_inline void loop_reset(TA* ptr_a, int stride, int n){ +#pragma unroll + for(int i=0; i +__global__ void dcn_forward_kernel_register(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){ + int work_id = threadIdx.x; + int bid = blockIdx.z; + int gid = blockIdx.y; + int G = gridDim.y; + int c_blockid = blockIdx.x; + int work_load = (H*W/blockDim.x); + + extern __shared__ int shm[]; + math_t* math_buffer = reinterpret_cast(shm); + + math_t register_bufferA[BLOCK_DIM] = {0}; + int base_c = c_blockid*BLOCK_DIM; + +#pragma unroll + for(int i=0; i(register_bufferA, 1, BLOCK_DIM); +#pragma unroll + for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM); + } + // load top right + math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight; + if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM); + } + + // load bottom left + math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight; + if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM); + } + // load bottom right + math_t br_weight = (x - floor_x)*(y - floor_y)*weight; + if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM); + } + + } + int offset2 = (bid*H*W*G*C + job_id*C*G + gid*C + base_c); +#pragma unroll + for(int j=0; j +__global__ void dcn_forward_kernel_pipeline(const int H, const int W, const int C, scalar_t* ptr_value, scalar_t* ptr_deformables, scalar_t* ptr_weights, scalar_t* ptr_out){ + int work_id = threadIdx.x; + int bid = blockIdx.z; + int gid = blockIdx.y; + int G = gridDim.y; + int c_blockid = blockIdx.x; + int work_load = (H*W/blockDim.x); + + extern __shared__ int shm[]; + math_t* math_buffer = reinterpret_cast(shm); + scalar_t* io_buffer = reinterpret_cast(shm) + H*W*BLOCK_DIM*sizeof(math_t)/sizeof(scalar_t); + math_t register_bufferA[BLOCK_DIM] = {0}; + int base_c = c_blockid*BLOCK_DIM; + + int num_transfers = BLOCK_DIM/transfer_length; +#pragma unroll + for(int i=0; i(register_bufferA, 1, BLOCK_DIM); + loop_reset((scalar_t*)&io_buffer[hid*W+wid], 1, BLOCK_DIM); +#pragma unroll + for(int k=0; k(register_bufferA, (math_t*)&math_buffer[floor_y*W+floor_x], tl_weight, 1, 1, BLOCK_DIM); + } + // load top right + math_t tr_weight = (x - floor_x)*(ceil_y - y)*weight; + if((0<= floor_y) and (floor_y < H) and (0<= ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[floor_y*W + ceil_x], tr_weight, 1, 1, BLOCK_DIM); + } + + // load bottom left + math_t bl_weight = (ceil_x - x)*(y - floor_y)*weight; + if((0<= ceil_y) and (ceil_y < H) and (0<= floor_x) and (floor_x < W) ){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+floor_x], bl_weight, 1, 1, BLOCK_DIM); + } + // load bottom right + math_t br_weight = (x - floor_x)*(y - floor_y)*weight; + if((0<=ceil_y) and (ceil_y < H) and (0<=ceil_x) and (ceil_x < W)){ + loop_mul_add(register_bufferA, (math_t*)&math_buffer[ceil_y*W+ceil_x], br_weight, 1, 1, BLOCK_DIM); + } + + } + loop_load((scalar_t*)&io_buffer[hid*W+wid], register_bufferA, 1, 1, BLOCK_DIM); + + } + + __syncthreads(); + +#pragma unroll + for(int i=0; i +void dcn_forward(const int B, const 
int G, const int C, const int H, const int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) { + + int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM; + dim3 launch_threads_per_block(THREADS); + dim3 launch_blocks(NUM_C_BLOCK, G, B); + int shm_size = H*W*C_BLOCK_DIM*sizeof(at::Half); + switch (value.type().scalarType()) { + case at::ScalarType::Half: + return dcn_forward_kernel_register<<>>( + H, W, C, + value.data_ptr(), + deformables.data_ptr(), + weights.data_ptr(), + out.data_ptr()); + case at::ScalarType::Float: + return dcn_forward_kernel_register<<>>( + H, W, C, + value.data_ptr(), + deformables.data_ptr(), + weights.data_ptr(), + out.data_ptr()); + default: + printf("running error"); + } +} + +template +void dcn_forward_pipeline(int B, int G, int C, int H, int W, torch::Tensor value, torch::Tensor deformables, torch::Tensor weights, torch::Tensor out) { + + int NUM_C_BLOCK = (C+C_BLOCK_DIM-1)/C_BLOCK_DIM; + dim3 launch_threads_per_block(THREADS); + dim3 launch_blocks(NUM_C_BLOCK, G, B); + int shm_size = 2*H*W*C_BLOCK_DIM*sizeof(at::Half); + switch (value.type().scalarType()) { + case at::ScalarType::Half: + return dcn_forward_kernel_pipeline<<>>( + H, W, C, + value.data_ptr(), + deformables.data_ptr(), + weights.data_ptr(), + out.data_ptr()); + case at::ScalarType::BFloat16: + return dcn_forward_kernel_pipeline<<>>( + H, W, C, + value.data_ptr(), + deformables.data_ptr(), + weights.data_ptr(), + out.data_ptr()); + default: + printf("running error"); + } +} + +// PyBind11 bindings +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +//m.def("dcn_forward_c1_f4", &dcn_forward<1, 4>, "CUDA dcn forward"); +//m.def("dcn_forward_c2_f4", &dcn_forward<2, 4>, "CUDA dcn forward"); +m.def("dcn_forward_l256_c4", &dcn_forward<4, 256>, "CUDA dcn forward"); +m.def("dcn_forward_l256_c8", &dcn_forward<8, 256>, "CUDA dcn forward"); +m.def("dcn_forward_l256_c16", &dcn_forward<16, 256>, "CUDA dcn forward"); +m.def("dcn_forward_pipeline_l256_c4", &dcn_forward_pipeline<4, 256>, "CUDA dcn forward"); +m.def("dcn_forward_pipeline_l256_c8", &dcn_forward_pipeline<8, 256>, "CUDA dcn forward"); +m.def("dcn_forward_pipeline_l256_c16", &dcn_forward_pipeline<16, 256>, "CUDA dcn forward"); +// m.def("dcn_forward_l256_c32", &dcn_forward<256, 32, 256>, "CUDA dcn forward"); +// m.def("dcn_forward_l1024_c2", &dcn_forward<1024, 2, 256>, "CUDA dcn forward"); +// m.def("dcn_forward_l1024_c4", &dcn_forward<1024, 4, 256>, "CUDA dcn forward"); +// m.def("dcn_forward_l1024_c8", &dcn_forward<1024, 8, 256>, "CUDA dcn forward"); +// m.def("dcn_forward_l1024_c12", &dcn_forward<1024, 12, 256>, "CUDA dcn forward"); +} diff --git a/src/ops/cuda_kernels/forward.py b/src/ops/cuda_kernels/forward.py new file mode 100644 index 0000000..4ea9c5e --- /dev/null +++ b/src/ops/cuda_kernels/forward.py @@ -0,0 +1,95 @@ +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE': 32,}, num_stages=1, num_warps=1), + triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), + triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1), + ], + key=['B', 'H', 'W', 'G', 'C', 'K'], +) +@triton.jit +def forward_kernel( + B: tl.constexpr, + H: tl.constexpr, # image_size_h + W: tl.constexpr, # image_size_w + G: tl.constexpr, # num_channels_per_group + C: tl.constexpr, # num_groups + K: tl.constexpr, # kernel size + input_ptr, # input features [B, H, W, G, C] + deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K] + weights_ptr, # weights [B, H, W, G, K] + 
out_ptr, # out [B, H, W, G, C] + BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group +): + pid = tl.program_id(0) + wid = pid % W + hid = pid // W % H + gid = pid // (W * H) % G + bid = pid // (W * H * G) + + id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B) + common_offset = bid*H*W*G + hid*W*G + wid*G + gid + batch_base = bid * H * W * G * C + + for block_base in tl.static_range(0, C, BLOCK_SIZE): + buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + block_offset = tl.arange(0, BLOCK_SIZE) + block_base + block_mask = (block_offset < C) & id_mask + for k in tl.static_range(K): + deformable_offset = (common_offset * K + k) * 2 + + x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid + y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid + + floor_x = x.to(tl.int32) + floor_y = y.to(tl.int32) + ceil_x = floor_x + 1 + ceil_y = floor_y + 1 + + # load top left + tl_weight = (ceil_x - x) * (ceil_y - y) + tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE + tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H) + + # load top right + tr_weight = (x - floor_x) * (ceil_y - y) + tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE + tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0) + # load bottom left + bl_weight = (ceil_x - x) * (y - floor_y) + bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE + bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0) + # load bottom right + br_weight = (x - floor_x) * (y - floor_y) + br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE + br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0) + + # load dynamic weight and mask + weights_offset = common_offset*K + k + weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) + + + + tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0) + tl_block_input = tl_block_input * tl_weight + + # load top right + tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & block_mask, other=0.0) + tr_block_input = tr_block_input * tr_weight + # load bottom left + bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0) + bl_block_input = bl_block_input * bl_weight + # load bottom right + br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0) + br_block_input = br_block_input * br_weight + + # sampled + sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input + + weighted_sampled_input = sampled_input * weight + buffer = buffer + weighted_sampled_input + # store to out_ptr + tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask) + diff --git a/src/ops/cuda_kernels/function.py b/src/ops/cuda_kernels/function.py new file mode 100644 index 0000000..9d4bfad --- /dev/null +++ b/src/ops/cuda_kernels/function.py @@ -0,0 +1,126 @@ +import time +import dcn_cuda_backward +import dcn_cuda_forward + +import math +import torch +from typing import Any +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_fwd, custom_bwd +from .forward import forward_kernel 
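# Illustrative reference (not part of the patch): a naive PyTorch re-implementation of the
# deformable sampling that the CUDA/Triton forward kernels above compute, assuming the
# layouts used throughout this file:
#   values      [B, H, W, G, C]
#   deformables [B, H, W, G, K, 2]   (x, y) offsets added to the query location
#   weights     [B, H, W, G, K]      per-sample dynamic weights
# Out-of-range corners contribute zero. Python loops make this far too slow for real use;
# it is only a ground-truth sketch for checking the kernels on tiny inputs.
def dcn_forward_reference(values, deformables, weights):
    B, H, W, G, C = values.shape
    K = weights.shape[-1]
    out = torch.zeros_like(values)
    for b in range(B):
        for h in range(H):
            for w in range(W):
                for g in range(G):
                    # accumulate in float32, mirroring the kernels' float math buffer
                    acc = torch.zeros(C, dtype=torch.float32, device=values.device)
                    for k in range(K):
                        x = deformables[b, h, w, g, k, 0].float() + w
                        y = deformables[b, h, w, g, k, 1].float() + h
                        x0, y0 = int(x), int(y)  # int cast, as in the kernels
                        x1, y1 = x0 + 1, y0 + 1
                        corners = (
                            (y0, x0, (x1 - x) * (y1 - y)),  # top left
                            (y0, x1, (x - x0) * (y1 - y)),  # top right
                            (y1, x0, (x1 - x) * (y - y0)),  # bottom left
                            (y1, x1, (x - x0) * (y - y0)),  # bottom right
                        )
                        for yy, xx, bilinear in corners:
                            if 0 <= yy < H and 0 <= xx < W:
                                acc += (bilinear * weights[b, h, w, g, k].float()
                                        * values[b, yy, xx, g].float())
                    out[b, h, w, g] = acc.to(values.dtype)
    return out
# Usage sketch: compare DCNFunction.apply(v, d, w) against dcn_forward_reference(v, d, w)
# on small random tensors to validate a newly built kernel.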
+ + +class DCNFunction(Function): + BP_FUNCS = [ + dcn_cuda_backward.backward_p1_c2_tile16_thread128, + dcn_cuda_backward.backward_p1_c4_tile16_thread128, + dcn_cuda_backward.backward_p2_c2_tile16_thread128, + dcn_cuda_backward.backward_p1_c2_tile16_thread256, + dcn_cuda_backward.backward_p1_c4_tile16_thread256, + dcn_cuda_backward.backward_p2_c2_tile16_thread256, + dcn_cuda_backward.backward_p1_c2_tile16_thread384, + dcn_cuda_backward.backward_p1_c4_tile16_thread384, + dcn_cuda_backward.backward_p2_c2_tile16_thread384, + dcn_cuda_backward.backward_p1_c2_tile16_thread512, + dcn_cuda_backward.backward_p1_c4_tile16_thread512, + dcn_cuda_backward.backward_p2_c2_tile16_thread512, + dcn_cuda_backward.backward_p1_c2_tile16_thread768, + dcn_cuda_backward.backward_p1_c4_tile16_thread768, + dcn_cuda_backward.backward_p2_c2_tile16_thread768, + dcn_cuda_backward.backward_p1_c2_tile32_thread128, + dcn_cuda_backward.backward_p1_c2_tile32_thread256, + dcn_cuda_backward.backward_p1_c2_tile32_thread384, + dcn_cuda_backward.backward_p1_c2_tile32_thread512, + ] + FW_FUNCS = [ + dcn_cuda_forward.dcn_forward_l256_c4, + dcn_cuda_forward.dcn_forward_l256_c8, + dcn_cuda_forward.dcn_forward_l256_c16, + ] + BP_TABLES = dict() + FW_TABLES = dict() + + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx: Any, values, deformables, weights) -> Any: + B, H, W, G, C = values.shape + func = DCNFunction.find_fw_funcs(values, deformables, weights) + out = torch.zeros_like(values) + func(B, G, C, H, W, values, deformables, weights, out) + return out + + @staticmethod + def find_fw_funcs(values, deformables, weights): + B, H, W, G, C = values.shape + B, H, W, G, K = weights.shape + hash_value = 10000 * B + 100 * H + W + 1000 * G + if hash_value in DCNFunction.FW_TABLES.keys(): + return DCNFunction.FW_TABLES[hash_value] + print("missing") + candicate_func = None + min_t = 999.0 + outs = torch.zeros_like(values) + for func in DCNFunction.FW_FUNCS: + t = [] + for i in range(100): + torch.cuda.synchronize() + start_t = time.time() + func(B, G, C, H, W, values, deformables, weights, outs) + torch.cuda.synchronize() + t.append(time.time() - start_t) + t = t[-50:] + t = sum(t) / len(t) + if t < min_t: + min_t = t + DCNFunction.FW_TABLES[hash_value] = func + candicate_func = func + assert candicate_func is not None + print(candicate_func) + return candicate_func + @staticmethod + def find_bp_funcs(values, deformables, weights, grad_out): + B, H, W, G, C = values.shape + B, H, W, G, K = weights.shape + hash_value = 10000 * B + 100 * H + W + 1000 * G + if hash_value in DCNFunction.BP_TABLES.keys(): + return DCNFunction.BP_TABLES[hash_value] + print("missing") + candicate_func = None + min_t = 999.0 + grad_values = torch.zeros_like(values) + grad_deformables = torch.zeros_like(deformables) + grad_weights = torch.zeros_like(weights) + for func in DCNFunction.BP_FUNCS: + t = [] + for i in range(100): + torch.cuda.synchronize() + start_t = time.time() + func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights) + torch.cuda.synchronize() + t.append(time.time() - start_t) + t = t[-50:] + t = sum(t) / len(t) + if t < min_t: + min_t = t + DCNFunction.BP_TABLES[hash_value] = func + candicate_func = func + assert candicate_func is not None + print(candicate_func) + return candicate_func + + @staticmethod + @once_differentiable + @custom_bwd + def backward(ctx: Any, *grad_outputs: Any) -> Any: + grad_out = grad_outputs[0] + values, deformables, weights = ctx.saved_tensors + 
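# Shapes follow the forward layout: values [B, H, W, G, C], weights [B, H, W, G, K].
# find_bp_funcs benchmarks every registered backward kernel once per (B, H, W, G)
# configuration, caches the fastest one in BP_TABLES, and reuses it on later calls.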
B, H, W, G, C = values.shape + B, H, W, G, K = weights.shape + func = DCNFunction.find_bp_funcs(values, deformables, weights, grad_out) + grad_values = torch.zeros_like(values) + grad_deformables = torch.zeros_like(deformables) + grad_weights = torch.zeros_like(weights) + func(B, H, W, G, K, C, values, deformables, weights, grad_out, grad_values, grad_deformables, grad_weights) + return grad_values, grad_deformables, grad_weights \ No newline at end of file diff --git a/src/ops/cuda_kernels/setup.py b/src/ops/cuda_kernels/setup.py new file mode 100644 index 0000000..34079d4 --- /dev/null +++ b/src/ops/cuda_kernels/setup.py @@ -0,0 +1,59 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='dcn_cuda_forward', + ext_modules=[ + CUDAExtension('dcn_cuda_forward', ['./forward.cu',], + extra_compile_args={'cxx': [], 'nvcc': [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + "--use_fast_math", + "-O3", + ]} + ), + ], + cmdclass={ + 'build_ext': BuildExtension + } +) + +setup( + name='dcn_cuda_backward', + ext_modules=[ + CUDAExtension('dcn_cuda_backward', ['./backward.cu',], + extra_compile_args={'cxx': [], 'nvcc': [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + "--use_fast_math", + "-O3", + ]} + ), + ], + cmdclass={ + 'build_ext': BuildExtension + } +) + + +# setup( +# name='mycuda', +# ext_modules=[ +# CUDAExtension('mycuda', ['./backward.cu',], +# extra_compile_args={'cxx': [], 'nvcc': [ +# "-O3", +# "-DCUDA_HAS_FP16=1", +# "-D__CUDA_NO_HALF_OPERATORS__", +# "-D__CUDA_NO_HALF_CONVERSIONS__", +# "-D__CUDA_NO_HALF2_OPERATORS__", +# ]} +# ), +# ], +# cmdclass={ +# 'build_ext': BuildExtension +# } +# ) \ No newline at end of file diff --git a/src/ops/triton_kernels/__init__.py b/src/ops/triton_kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ops/triton_kernels/backward.py b/src/ops/triton_kernels/backward.py new file mode 100644 index 0000000..e886aa2 --- /dev/null +++ b/src/ops/triton_kernels/backward.py @@ -0,0 +1,124 @@ +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=1), + triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), + ], + key=['B', 'H', 'W', 'G', 'C', 'K'], +) +@triton.jit +def backward_kernel( + B: tl.constexpr, + H: tl.constexpr, # image_size_h + W: tl.constexpr, # image_size_w + G: tl.constexpr, # num_groups + C: tl.constexpr, # num_channels_per_group + K: tl.constexpr, # kernel size + input_ptr, # input features [B, H, W, G, C] + deformable_ptr, # deformable offsets [B, H, W, G, K, 2] + weights_ptr, # weights [B, H, W, G, K] + grad_ptr, # out [B, H, W, G, C] + grad_input_ptr, # input features [B, H, W, G, C] + grad_deformable_ptr, # deformable offsets [B, H, W, G, K, 2] + grad_weights_ptr, # weights [B, H, W, G, K] + BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group +): + + pid = tl.program_id(0) + wid = pid % W + hid = pid // W % H + gid = pid // (W * H) % G + bid = pid // (W * H * G) + + id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B) + + common_offset = bid*H*W*G + hid*W*G + wid*G + gid + batch_base = bid * H * W * G * C + for k in tl.static_range(K): + # load dynamic weight and mask + weights_offset = common_offset*K + k + weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) + 
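# Bilinear weights for the four corners are
#   w_tl = (x1-x)(y1-y), w_tr = (x-x0)(y1-y), w_bl = (x1-x)(y-y0), w_br = (x-x0)(y-y0),
# so their x-derivatives are -(y1-y), +(y1-y), -(y-y0), +(y-y0) and their y-derivatives
# are -(x1-x), -(x-x0), +(x1-x), +(x-x0).  The dodx/dody accumulators below sum the
# per-block dot product of each corner's input values with the incoming gradient times
# these terms (scaled by the dynamic weight at the end), dodw sums that dot product times
# the bilinear weight, and grad_input is scattered back through atomic adds.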
dodx = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty) + dody = tl.zeros((1,), dtype=grad_deformable_ptr.type.element_ty) + dodw = tl.zeros((1,), dtype=grad_weights_ptr.type.element_ty) + deformable_offset = (common_offset * K + k)*2 + x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid + y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid + for block_base in tl.static_range(0, C, BLOCK_SIZE): + block_offset = tl.arange(0, BLOCK_SIZE) + block_base + block_mask = (block_offset < C) & id_mask + grad = tl.load(grad_ptr+common_offset*C + block_offset, mask=block_mask, other=0.0) + dods = weight*grad + + floor_x = x.to(tl.int32) + floor_y = y.to(tl.int32) + ceil_x = floor_x + 1 + ceil_y = floor_y + 1 + + # load top left + tl_weight = (ceil_x - x) * (ceil_y - y) + tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) + block_offset + tl_block_mask = ((floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H)) + tl_block_input = tl.load(input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, other=0.0) + tl_block_input_dot_grad = tl.sum(tl_block_input*grad, axis=0) + dodx = dodx + -1 * tl_block_input_dot_grad * (ceil_y - y) + dody = dody + -1 * tl_block_input_dot_grad * (ceil_x - x) + dodw = dodw + tl_block_input_dot_grad * tl_weight + + dodtl = dods * tl_weight + tl.atomic_add(grad_input_ptr + tl_block_offset, mask=tl_block_mask & block_mask, val=dodtl) + + + # load top right + tr_weight = (x - floor_x) * (ceil_y - y) + tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) + block_offset + tr_block_mask = ((floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0)) + tr_block_input = tl.load(input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, other=0.0) + tr_block_input_dot_grad = tl.sum(tr_block_input*grad, axis=0) + dodx = dodx + 1 * tr_block_input_dot_grad * (ceil_y - y) + dody = dody + -1 * tr_block_input_dot_grad * (x - floor_x) + dodw = dodw + tr_block_input_dot_grad*tr_weight + + dodtr = dods * tr_weight + tl.atomic_add(grad_input_ptr + tr_block_offset, mask=tr_block_mask & block_mask, val=dodtr) + + + # load bottom left + bl_weight = (ceil_x - x) * (y - floor_y) + bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) + block_offset + bl_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0)) + bl_block_input = tl.load(input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, other=0.0) + bl_block_input_dot_grad = tl.sum(bl_block_input*grad, axis=0) + dodx = dodx + -1 * bl_block_input_dot_grad * (y - floor_y) + dody = dody + 1 * bl_block_input_dot_grad * (ceil_x - x) + dodw = dodw + bl_block_input_dot_grad*bl_weight + + dodbl = dods * bl_weight + tl.atomic_add(grad_input_ptr + bl_block_offset, mask=bl_block_mask & block_mask, val=dodbl) + + + # load bottom right + br_weight = (x - floor_x) * (y - floor_y) + br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) + block_offset + br_block_mask = ((ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0)) + br_block_input = tl.load(input_ptr + br_block_offset, mask=br_block_mask & block_mask, other=0.0) + br_block_input_dot_grad = tl.sum(br_block_input*grad, axis=0)*br_block_mask + + dodx = dodx + 1 * br_block_input_dot_grad * (y - floor_y) + dody = dody + 1 * br_block_input_dot_grad * (x - floor_x) + dodw = dodw + br_block_input_dot_grad*br_weight + + dodbr = dods * br_weight + tl.atomic_add(grad_input_ptr + 
br_block_offset, mask=br_block_mask & block_mask, val=dodbr) + dodx = dodx * weight + dody = dody * weight + tl.store(grad_weights_ptr + weights_offset + tl.arange(0, 1), dodw, mask=id_mask) + tl.store(grad_deformable_ptr + deformable_offset + tl.arange(0, 1), dodx, mask=id_mask) + tl.store(grad_deformable_ptr + deformable_offset + 1 + tl.arange(0, 1), dody, mask=id_mask) + + + + + diff --git a/src/ops/triton_kernels/forward.py b/src/ops/triton_kernels/forward.py new file mode 100644 index 0000000..cf7c243 --- /dev/null +++ b/src/ops/triton_kernels/forward.py @@ -0,0 +1,94 @@ +import triton +import triton.language as tl + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 64,}, num_stages=1, num_warps=2), + # triton.Config({'BLOCK_SIZE': 64, }, num_stages=1, num_warps=1), + ], + key=['B', 'H', 'W', 'G', 'C', 'K'], +) +@triton.jit +def forward_kernel( + B: tl.constexpr, + H: tl.constexpr, # image_size_h + W: tl.constexpr, # image_size_w + G: tl.constexpr, # num_channels_per_group + C: tl.constexpr, # num_groups + K: tl.constexpr, # kernel size + input_ptr, # input features [B, H, W, G, C] + deformable_ptr, # deformable offsets [B, H, W, G, 2*K + K] + weights_ptr, # weights [B, H, W, G, K] + out_ptr, # out [B, H, W, G, C] + BLOCK_SIZE: tl.constexpr, # a micro block to process in the Group +): + pid = tl.program_id(0) + wid = pid % W + hid = pid // W % H + gid = pid // (W * H) % G + bid = pid // (W * H * G) + + id_mask = (hid < H) & (wid < W) & (gid < G) & (bid < B) + common_offset = bid*H*W*G + hid*W*G + wid*G + gid + batch_base = bid * H * W * G * C + + for block_base in tl.static_range(0, C, BLOCK_SIZE): + buffer = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) + block_offset = tl.arange(0, BLOCK_SIZE) + block_base + block_mask = (block_offset < C) & id_mask + for k in tl.static_range(K): + deformable_offset = (common_offset * K + k) * 2 + + x = tl.load(deformable_ptr + deformable_offset, mask=id_mask, other=0.0) + wid + y = tl.load(deformable_ptr + deformable_offset + 1, mask=id_mask, other=0.0) + hid + + floor_x = x.to(tl.int32) + floor_y = y.to(tl.int32) + ceil_x = floor_x + 1 + ceil_y = floor_y + 1 + + # load top left + tl_weight = (ceil_x - x) * (ceil_y - y) + tl_block_offset = (batch_base + floor_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE + tl_block_mask = (floor_y >= 0) & (floor_x >= 0) & (floor_x < W) & (floor_y < H) + + # load top right + tr_weight = (x - floor_x) * (ceil_y - y) + tr_block_offset = (batch_base + floor_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE + tr_block_mask = (floor_y >= 0) & (ceil_x < W) & (floor_y < H) & (ceil_x >= 0) + # load bottom left + bl_weight = (ceil_x - x) * (y - floor_y) + bl_block_offset = (batch_base + ceil_y * W * G * C + floor_x * G * C + gid * C) #+ k * BLOCK_SIZE + bl_block_mask = (ceil_y < H) & (ceil_y >= 0) & (floor_x < W) & (floor_x >= 0) + # load bottom right + br_weight = (x - floor_x) * (y - floor_y) + br_block_offset = (batch_base + ceil_y * W * G * C + ceil_x * G * C + gid * C) #+ k * BLOCK_SIZE + br_block_mask = (ceil_y < H) & (ceil_y >= 0) & (ceil_x < W) & (ceil_x >= 0) + + # load dynamic weight and mask + weights_offset = common_offset*K + k + weight = tl.load(weights_ptr + weights_offset, mask=id_mask, other=0.0) + + + + tl_block_input = tl.load(input_ptr + tl_block_offset + block_offset, mask=tl_block_mask & block_mask, other=0.0) + tl_block_input = tl_block_input * tl_weight + + # load top right + tr_block_input = tl.load(input_ptr + tr_block_offset + block_offset, mask=tr_block_mask & 
block_mask, other=0.0) + tr_block_input = tr_block_input * tr_weight + # load bottom left + bl_block_input = tl.load(input_ptr + bl_block_offset + block_offset, mask=bl_block_mask & block_mask, other=0.0) + bl_block_input = bl_block_input * bl_weight + # load bottom right + br_block_input = tl.load(input_ptr + br_block_offset + block_offset, mask=br_block_mask & block_mask, other=0.0) + br_block_input = br_block_input * br_weight + + # sampled + sampled_input = tl_block_input + tr_block_input + bl_block_input + br_block_input + + weighted_sampled_input = sampled_input * weight + buffer = buffer + weighted_sampled_input + # store to out_ptr + tl.store(out_ptr + common_offset*C + block_offset, buffer, mask=block_mask) + diff --git a/src/ops/triton_kernels/function.py b/src/ops/triton_kernels/function.py new file mode 100644 index 0000000..84987a1 --- /dev/null +++ b/src/ops/triton_kernels/function.py @@ -0,0 +1,48 @@ +import torch +import triton +from typing import Any +from torch.autograd import Function +from torch.cuda.amp.autocast_mode import custom_bwd, custom_fwd +from .forward import forward_kernel +from .backward import backward_kernel + + + +class DCNFunction(Function): + + @staticmethod + @custom_fwd + def forward(ctx: Any, inputs, deformables, weights) -> Any: + B, H, W, G, C = inputs.shape + _, _, _, _, K, _ = deformables.shape + out = torch.zeros_like(inputs) + grid = lambda META: (B * H * W * G,) + + forward_kernel[grid](B, H, W, G, C, K, inputs, deformables, weights, out) + ctx.save_for_backward(inputs, deformables, weights) + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, *grad_outputs: Any) -> Any: + grad_output = grad_outputs[0].contiguous() + + inputs, deformables, weights = ctx.saved_tensors + B, H, W, G, C = inputs.shape + _, _, _, _, K, _ = deformables.shape + + grad_inputs = torch.zeros_like(inputs) + grad_deformables = torch.zeros_like(deformables) + grad_weights = torch.zeros_like(weights) + grid = lambda META: (B * H * W * G,) + backward_kernel[grid]( + B, H, W, G, C, K, + inputs, + deformables, + weights, + grad_output, + grad_inputs, + grad_deformables, + grad_weights, + ) + return (grad_inputs, grad_deformables, grad_weights) \ No newline at end of file diff --git a/src/plugins/__init__.py b/src/plugins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/plugins/bd_env.py b/src/plugins/bd_env.py new file mode 100644 index 0000000..c1900e9 --- /dev/null +++ b/src/plugins/bd_env.py @@ -0,0 +1,70 @@ +import torch +import os +import socket +from typing_extensions import override +from lightning.fabric.utilities.rank_zero import rank_zero_only +from lightning.fabric.plugins.environments.lightning import LightningEnvironment + + +class BDEnvironment(LightningEnvironment): + pass + # def __init__(self) -> None: + # super().__init__() + # self._global_rank: int = 0 + # self._world_size: int = 1 + # + # @property + # @override + # def creates_processes_externally(self) -> bool: + # """Returns whether the cluster creates the processes or not. + # + # If at least :code:`LOCAL_RANK` is available as environment variable, Lightning assumes the user acts as the + # process launcher/job scheduler and Lightning will not launch new processes. 
+ # + # """ + # return "LOCAL_RANK" in os.environ + # + # @staticmethod + # @override + # def detect() -> bool: + # assert "ARNOLD_WORKER_0_HOST" in os.environ.keys() + # assert "ARNOLD_WORKER_0_PORT" in os.environ.keys() + # return True + # + # @override + # def world_size(self) -> int: + # return self._world_size + # + # @override + # def set_world_size(self, size: int) -> None: + # self._world_size = size + # + # @override + # def global_rank(self) -> int: + # return self._global_rank + # + # @override + # def set_global_rank(self, rank: int) -> None: + # self._global_rank = rank + # rank_zero_only.rank = rank + # + # @override + # def local_rank(self) -> int: + # return int(os.environ.get("LOCAL_RANK", 0)) + # + # @override + # def node_rank(self) -> int: + # return int(os.environ.get("ARNOLD_ID")) + # + # @override + # def teardown(self) -> None: + # if "WORLD_SIZE" in os.environ: + # del os.environ["WORLD_SIZE"] + # + # @property + # def main_address(self) -> str: + # return os.environ.get("ARNOLD_WORKER_0_HOST") + # + # @property + # def main_port(self) -> int: + # return int(os.environ.get("ARNOLD_WORKER_0_PORT")) diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/copy.py b/src/utils/copy.py new file mode 100644 index 0000000..62cd89d --- /dev/null +++ b/src/utils/copy.py @@ -0,0 +1,13 @@ +import torch + +@torch.no_grad() +def copy_params(src_model, dst_model): + for src_param, dst_param in zip(src_model.parameters(), dst_model.parameters()): + dst_param.data.copy_(src_param.data) + +@torch.no_grad() +def swap_tensors(tensor1, tensor2): + tmp = torch.empty_like(tensor1) + tmp.copy_(tensor1) + tensor1.copy_(tensor2) + tensor2.copy_(tmp) \ No newline at end of file diff --git a/src/utils/model_loader.py b/src/utils/model_loader.py new file mode 100644 index 0000000..7d99166 --- /dev/null +++ b/src/utils/model_loader.py @@ -0,0 +1,29 @@ +from typing import Dict, Any, Optional + +import torch +import torch.nn as nn +from lightning.fabric.utilities.types import _PATH + + +import logging +logger = logging.getLogger(__name__) + +class ModelLoader: + def __init__(self,): + super().__init__() + + def load(self, denoiser, prefix=""): + if denoiser.weight_path: + weight = torch.load(denoiser.weight_path, map_location=torch.device('cpu')) + + if denoiser.load_ema: + prefix = "ema_denoiser." + prefix + else: + prefix = "denoiser." 
+ prefix + + for k, v in denoiser.state_dict().items(): + try: + v.copy_(weight["state_dict"][prefix+k]) + except: + logger.warning(f"Failed to copy {prefix+k} to denoiser weight") + return denoiser \ No newline at end of file diff --git a/src/utils/no_grad.py b/src/utils/no_grad.py new file mode 100644 index 0000000..2fd71de --- /dev/null +++ b/src/utils/no_grad.py @@ -0,0 +1,16 @@ +import torch + +@torch.no_grad() +def no_grad(net): + for param in net.parameters(): + param.requires_grad = False + net.eval() + return net + +@torch.no_grad() +def filter_nograd_tensors(params_list): + filtered_params_list = [] + for param in params_list: + if param.requires_grad: + filtered_params_list.append(param) + return filtered_params_list \ No newline at end of file diff --git a/src/utils/patch_bugs.py b/src/utils/patch_bugs.py new file mode 100644 index 0000000..db9a174 --- /dev/null +++ b/src/utils/patch_bugs.py @@ -0,0 +1,17 @@ +import torch +import lightning.pytorch.loggers.wandb as wandb + +setattr(wandb, '_WANDB_AVAILABLE', True) +torch.set_float32_matmul_precision('medium') + +import logging +logger = logging.getLogger("wandb") +logger.setLevel(logging.WARNING) + +import os +os.environ["NCCL_DEBUG"] = "WARN" +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) +warnings.simplefilter(action='ignore', category=UserWarning) diff --git a/tools/cache_imlatent3.py b/tools/cache_imlatent3.py new file mode 100644 index 0000000..640cdb0 --- /dev/null +++ b/tools/cache_imlatent3.py @@ -0,0 +1,117 @@ +from diffusers import AutoencoderKL + +import torch +from typing import Callable +from torchvision.datasets import ImageFolder, ImageNet + +import cv2 +import os +import torch +import numpy as np +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor + +from PIL import Image +import pathlib + +import torch +import random +from torchvision.io.image import read_image +import torchvision.transforms as tvtf +from torch.utils.data import Dataset +from torchvision.datasets import ImageNet + +IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +class NewImageFolder(ImageFolder): + def __getitem__(self, item): + path, target = self.samples[item] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + return sample, target, path + + +import time +class CenterCrop: + def __init__(self, size): + self.size = size + def __call__(self, image): + def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. 
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + return center_crop_arr(image, self.size) + +def save(obj, path): + dirname = os.path.dirname(path) + if not os.path.exists(dirname): + os.makedirs(dirname) + torch.save(obj, f'{path}') + +if __name__ == "__main__": + writer_pool = ThreadPoolExecutor(8) + for split in ['train']: + train = split == 'train' + transforms = tvtf.Compose([ + CenterCrop(256), + # tvtf.RandomHorizontalFlip(p=1), + tvtf.ToTensor(), + tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ]) + dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,) + # dataset = ImageNet(root='/tmp', split="train", transform=transforms, ) + B = 256 + dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=False, prefetch_factor=32, num_workers=16) + vae = AutoencoderKL.from_pretrained("/mnt/bn/wangshuai6/models/sd-vae-ft-ema")#.to('cuda') + + from accelerate import Accelerator + + accelerator = Accelerator() + + vae, dataloader = accelerator.prepare(vae, dataloader) + rank = accelerator.process_index + with torch.no_grad(): + for i, (image, label, path_list) in enumerate(dataloader): + # if i >= 128: break + new_path_list = [] + for p in path_list: + p = p + ".pt" + p = p.replace("/mnt/bn/wangshuai6/data/ImageNet/train", + "/mnt/bn/wangshuai6/data/ImageNet/train_256latent") + new_path_list.append(p) + + image = image.to("cuda") + distribution = vae.module.encode(image).latent_dist + mean = distribution.mean + logvar = distribution.logvar + for j in range(B): + out = dict( + mean=mean[j].cpu(), + logvar=logvar[j].cpu(), + ) + writer_pool.submit(save, out, new_path_list[j]) + writer_pool.shutdown(wait=True) + accelerator.wait_for_everyone() \ No newline at end of file diff --git a/tools/cache_imlatent4.py b/tools/cache_imlatent4.py new file mode 100644 index 0000000..fc33fa7 --- /dev/null +++ b/tools/cache_imlatent4.py @@ -0,0 +1,123 @@ +from diffusers import AutoencoderKL + +import torch +from typing import Callable +from torchvision.datasets import ImageFolder, ImageNet + +import cv2 +import os +import torch +import numpy as np +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor + +from PIL import Image +import pathlib + +import torch +import random +from torchvision.io.image import read_image +import torchvision.transforms as tvtf +from torch.utils.data import Dataset +from torchvision.datasets import ImageNet + +IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +class NewImageFolder(ImageFolder): + def __getitem__(self, item): + path, target = self.samples[item] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + return sample, target, path + + +import time +class CenterCrop: + def __init__(self, size): + self.size = size + def __call__(self, image): + def center_crop_arr(pil_image, 
image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + return center_crop_arr(image, self.size) + +def save(obj, path): + dirname = os.path.dirname(path) + if not os.path.exists(dirname): + os.makedirs(dirname) + torch.save(obj, f'{path}') + +if __name__ == "__main__": + writer_pool = ThreadPoolExecutor(8) + for split in ['train']: + train = split == 'train' + transforms = tvtf.Compose([ + CenterCrop(512), + # tvtf.RandomHorizontalFlip(p=1), + tvtf.ToTensor(), + tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ]) + dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,) + B = 8 + dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=False, prefetch_factor=16, num_workers=16) + vae = AutoencoderKL.from_pretrained("/mnt/bn/wangshuai6/models/sd-vae-ft-ema")#.to('cuda') + vae = vae.to(torch.float16) + from accelerate import Accelerator + + accelerator = Accelerator() + + vae, dataloader = accelerator.prepare(vae, dataloader) + rank = accelerator.process_index + with torch.no_grad(): + for i, (image, label, path_list) in enumerate(dataloader): + print(i/len(dataloader)) + flag = False + new_path_list = [] + for p in path_list: + p = p + ".pt" + p = p.replace("/mnt/bn/wangshuai6/data/ImageNet/train", + "/mnt/bn/wangshuai6/data/ImageNet/train_512_latent") + new_path_list.append(p) + if not os.path.exists(p): + print(p) + flag = True + + if flag: + image = image.to("cuda") + image = image.to(torch.float16) + distribution = vae.module.encode(image).latent_dist + mean = distribution.mean + logvar = distribution.logvar + + for j in range(len(path_list)): + out = dict( + mean=mean[j].cpu(), + logvar=logvar[j].cpu(), + ) + writer_pool.submit(save, out, new_path_list[j]) + writer_pool.shutdown(wait=True) + accelerator.wait_for_everyone() \ No newline at end of file diff --git a/tools/cat_images.py b/tools/cat_images.py new file mode 100644 index 0000000..21734b0 --- /dev/null +++ b/tools/cat_images.py @@ -0,0 +1,43 @@ +import cv2 +import numpy as np +import os +import pathlib +import argparse + +def group_images(path_list): + sorted(path_list) + class_id_dict = {} + for path in path_list: + class_id = str(path.name).split('_')[0] + if class_id not in class_id_dict: + class_id_dict[class_id] = [] + class_id_dict[class_id].append(path) + return class_id_dict + +def cat_images(path_list): + imgs = [] + for path in path_list: + img = cv2.imread(str(path)) + os.remove(path) + imgs.append(img) + row_cat_images = [] + row_length = int(len(imgs)**0.5) + for i in range(len(imgs)//row_length): + row_cat_images.append(np.concatenate(imgs[i*row_length:(i+1)*row_length], axis=1)) + cat_image = np.concatenate(row_cat_images, axis=0) + return cat_image + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--src_dir', type=str, default=None) + + args = parser.parse_args() + src_dir = 
args.src_dir + path_list = list(pathlib.Path(src_dir).glob('*.png')) + class_id_dict = group_images(path_list) + for class_id, path_list in class_id_dict.items(): + cat_image = cat_images(path_list) + cat_path = os.path.join(src_dir, f'cat_{class_id}.jpg') + # cat_path = "cat_{}.png".format(class_id) + cv2.imwrite(cat_path, cat_image) + diff --git a/tools/classifer_training.py b/tools/classifer_training.py new file mode 100644 index 0000000..b00b435 --- /dev/null +++ b/tools/classifer_training.py @@ -0,0 +1,353 @@ +import torch +import torch.nn as nn +import timm +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +import copy + +# NORMALIZE_DATA = dict( +# dinov2_vits14a = dict( +# mean=[-0.28,0.72,-0.64,-2.31,0.54,-0.29,-1.09,0.83,0.86,1.11,0.34,0.29,-0.32,1.02,0.58,-1.27,-1.19,-0.89,0.79, +# -0.58,0.23,-0.19,1.31,-0.34,0.02,-0.18,-0.64,0.04,1.63,-0.58,-0.89,0.09,1.09,1.12,0.32,-0.41,0.04,0.49, +# 0.11,1.97,1.06,0.05,-1.15,0.30,0.58,-0.14,-0.26,1.32,2.04,0.50,-0.64,1.18,0.39,0.39,-1.80,0.39,-0.67, +# 0.55,-0.35,-0.41,2.23,-1.16,-0.57,0.58,-1.29,2.07,0.18,0.62,5.72,-0.55,-0.54,0.17,-0.64,-0.78,-0.25,0.12, +# -0.58,0.36,-2.03,-2.45,-0.22,0.36,-1.02,-0.19,-0.92,-0.26,-0.27,-0.77,-1.47,-0.64,1.76,-0.03,-0.44,1.43, +# 1.14,0.67,1.27,1.54,0.88,-1.42,-0.44,3.32,0.21,1.22,1.17,1.15,-0.53,0.04,0.87,-0.76,0.94,-0.11,0.69,-0.61, +# 0.64,-1.21,-0.82,0.22,-1.12,-0.03,0.68,1.05,0.57,1.13,0.03,0.05,0.42,-0.12,-0.37,-0.76,-0.56,-0.76,-0.23, +# 1.59,0.54,0.63,-0.43,0.38,1.07,0.04,-1.87,-1.92,-0.06,0.87,-0.69,-1.09,-0.30,0.33,-0.28,0.14,2.65,-0.57, +# -0.04,0.12,-0.49,-1.60,0.39,0.05,0.12,0.66,-0.70,-0.69,0.47,-0.67,-0.59,-1.30,-0.28,-0.52,-0.98,0.67,1.65, +# 0.72,0.55,0.05,-0.27,1.67,0.17,-0.31,-1.73,2.04,0.49,1.08,-0.37,1.75,1.31,1.03,0.65,0.43,-0.19,0.00,-1.13, +# -0.29,-0.38,0.09,-0.24,1.49,1.01,-0.25,-0.94,0.74,0.24,-1.06,1.58,1.08,0.76,0.64,1.34,-1.09,1.54,-0.27, +# 0.77,0.19,-0.97,0.46,0.20,-0.60,1.48,-2.33,0.43,2.32,1.89,-0.31,-0.48,-0.54,1.52,1.33,0.95,0.12,-0.33,-0.94, +# -0.67,0.16,1.49,-0.17,-0.42,-0.02,-0.32,0.49,-1.19,0.06,0.19,-0.79,-0.21,-0.38,-0.69,0.52,0.74,0.41,-2.07, +# -1.01,0.85,-1.41,-0.17,1.11,0.53,1.47,0.66,-0.22,0.93,-0.69,-0.42,0.06,0.11,-0.87,1.58,-0.27,-1.57,-0.56, +# 0.98,-0.50,0.27,0.38,-1.06,-1.77,0.20,-0.33,-0.95,-0.62,-3.44,-0.67,-0.62,-1.20,0.04,-0.02,-1.15,0.56,-0.50, +# 0.83,-1.69,0.01,-0.42,1.15,0.22,1.55,-3.02,1.24,0.28,0.40,0.69,-0.35,2.04,0.33,0.10,-1.09,0.50,0.59,1.29, +# 0.79,0.02,-0.02,-0.49,0.07,0.84,0.55,-0.79,-0.26,-0.06,-0.91,-1.28,0.65,0.30,1.00,-0.09,0.66,-2.51,-0.78, +# 2.94,0.18,0.24,-0.08,0.76,0.06,0.26,-0.74,0.16,-0.72,0.17,0.21,0.98,0.67,0.14,0.05,0.48,0.54,2.05,1.21,-0.03, +# -0.85,0.38,-0.11,-0.38,-0.86,0.49,-0.87,-0.29,0.23,0.79,0.05,-0.05,-0.07,0.22,0.03,0.85,-0.63,-0.44,0.02, +# -0.10,-0.01,0.51,-1.84,0.11,1.06,0.00,1.10,-0.56,0.21,0.44,-0.65,-0.97,-1.03,-0.50,-0.67,-0.27,-1.25, +# ], +# std=[1.78,3.78,1.92,2.28,1.97,2.82,1.92,2.55,1.87,1.95,1.90,1.83,1.89,2.00,1.85,1.88,1.78,1.81,3.02,1.94,1.92, +# 2.26,2.17,6.16,1.84,2.00,1.85,1.88,3.30,2.14,1.85,2.87,3.01,2.05,1.80,1.84,2.20,2.00,1.97,2.02,1.94,1.90, +# 1.98,2.25,1.97,2.01,2.01,1.95,2.26,2.47,1.95,1.75,1.84,3.02,2.65,2.15,2.01,1.80,2.65,2.37,2.04,2.09,2.03, +# 1.94,1.84,2.19,1.98,1.97,4.52,2.76,2.18,2.59,1.94,2.07,1.96,1.91,2.13,3.16,1.95,2.43,1.84,2.16,2.33,2.21, +# 2.10,1.98,1.90,1.90,1.88,1.89,2.15,1.75,1.83,2.36,2.40,2.42,1.89,2.03,1.89,2.00,1.91,2.88,2.10,2.63,2.04, +# 
1.88,1.93,1.74,2.02,1.84,1.96,1.98,1.90,1.80,1.86,2.05,2.21,1.97,1.99,1.77,2.04,2.59,1.85,2.14,1.91,1.68, +# 1.95,1.86,1.99,2.18,2.76,2.03,1.88,2.47,1.92,3.04,2.02,1.74,2.94,1.92,2.12,1.92,2.17,2.15,1.74,2.26,1.71, +# 2.03,2.05,1.85,3.43,1.77,1.96,1.88,1.99,2.14,2.30,2.00,1.90,2.01,1.78,1.72,2.42,1.66,1.86,2.08,2.04,1.88, +# 2.55,2.02,1.83,1.86,1.69,2.06,1.92,2.25,1.74,1.69,2.02,3.88,1.86,2.94,1.82,2.27,2.73,2.05,1.91,1.94,1.86, +# 1.77,2.16,2.16,1.86,1.88,2.08,2.19,1.94,1.90,2.09,2.57,1.75,1.90,2.05,2.13,1.74,1.99,1.83,2.35,4.48,2.44, +# 1.88,2.18,2.46,1.84,1.81,2.37,2.45,2.07,1.79,3.65,2.29,2.09,2.09,2.29,1.92,2.34,1.85,2.03,1.72,2.20,2.15, +# 2.04,2.13,2.07,1.82,1.72,2.06,1.87,2.43,1.94,1.93,1.97,1.83,1.96,2.01,1.89,1.73,2.04,2.63,2.10,2.05,2.49, +# 2.10,2.27,1.87,2.16,2.22,2.08,1.87,2.26,1.88,2.28,3.87,1.74,3.71,2.03,2.70,2.11,1.92,2.00,2.04,2.02,1.90, +# 2.61,2.10,2.37,1.96,2.50,1.17,1.95,1.88,2.06,2.22,1.87,1.93,1.88,3.59,1.89,3.66,1.87,1.95,3.13,1.84,2.87, +# 3.96,2.14,2.01,1.89,1.73,1.98,2.42,2.12,2.28,1.92,1.93,2.54,2.06,1.97,2.02,2.19,2.00,2.04,1.75,1.97,1.81, +# 1.93,1.83,2.22,2.52,1.83,1.86,2.16,2.08,2.87,3.21,2.78,2.84,2.85,1.88,1.79,1.95,1.98,1.78,1.78,2.21,1.89, +# 2.57,2.00,2.82,1.90,2.24,2.28,1.91,2.02,2.23,2.62,1.88,2.40,2.40,2.00,1.70,1.82,1.92,1.95,1.99,2.08,1.97, +# 2.12,1.87,3.65,2.26,1.83,1.96,1.83,1.64,2.07,2.04,2.57,1.85,2.21,1.83,1.90,1.97,2.16,2.12,1.80,1.73,1.96, +# 2.62,3.23,2.13,2.29,2.24,2.72 +# ] +# ), +# dinov2_vitb14a = dict( +# mean=[ +# 0.23, 0.44, 0.18, -0.26, -0.08, -0.80, -0.22, -0.09, -0.85, 0.44, 0.07, -0.49, 0.39, -0.12, -0.58, -0.82, +# -0.21, -0.28, -0.40, 0.36, -0.34, 0.08, 0.31, 0.39, -0.22, -1.23, 0.50, 0.81, -0.96, 0.60, -0.45, -0.17, +# -0.53, 0.08, 0.10, -0.32, -0.22, -0.86, 0.01, 0.19, -0.73, -0.44, -0.57, -0.45, -0.20, -0.34, -0.63, -0.31, +# -0.80, 0.43, -0.13, 0.18, -0.11, -0.28, -0.15, 0.11, -0.74, -0.01, -0.34, 0.18, 0.37, 0.07, -0.09, -0.42, 0.15, +# -0.24, 0.68, -0.31, -0.09, -0.62, -0.54, 0.41, -0.42, -0.08, 0.36, -0.14, 0.44, 0.12, 0.49, 0.69, 0.03, +# -0.24, -0.41, -0.36, -0.60, 0.86, -0.76, 0.54, -0.24, 0.57, -0.40, -0.82, 0.07, 0.05, -0.24, 0.07, 0.54, +# 1.04, -0.29, 0.67, -0.36, -0.79, 0.11, -0.12, -0.22, -0.20, -0.46, 0.17, -0.15, -0.38, -0.11, 0.24, -0.43, +# -0.91, 0.04, 0.32, 0.27, -0.58, -0.05, 0.50, -0.47, 0.31, -1.30, 0.07, -0.16, 0.77, 1.07, -0.44, -0.48, 0.26 +# , 0.06, -0.76, -0.27, -0.37, -1.43, -0.50, -0.38, -0.03, -0.43, 0.75, -0.01, -0.16, 0.67, 0.40, 0.33, -0.05, +# -0.94, -0.40, 0.78, 0.29, -0.60, -0.76, 0.08, -0.08, 0.58, -0.91, -1.09, -0.42, -0.42, 0.29, 0.06, -0.19, +# -0.75, -0.07, 0.48, -0.30, -0.44, 0.02, 0.11, 0.23, -0.76, -0.76, -0.51, 0.78, -0.58, 0.02, 0.17, -0.36, +# -0.63, 0.48, 0.09, -0.32, -0.48, -0.09, 0.09, -0.36, 0.11, -0.17, 0.11, -0.80, -0.34, -0.52, 0.10, -0.00, 0.00, +# -0.15, 0.91, -0.48, 0.64, -0.38, 0.28, 0.56, 0.04, -0.30, 0.14, -0.30, -0.82, 0.47, 0.57, -1.00, -0.14, +# 0.00, 0.10, 0.01, 0.57, -0.09, -3.56, -0.22, -0.24, -0.13, 0.36, 0.30, 0.20, 0.09, 0.08, 0.66, 0.62, 0.44, +# 0.38, 0.46, -0.27, 0.21, 0.07, -0.57, 0.93, 0.39, 0.06, -0.47, 0.34, 0.44, -0.00, -0.52, -0.35, 0.23, -0.24, +# -0.01, -0.15, 0.11, 0.53, -0.23, 0.28, -0.22, 0.57, -0.07, 0.49, 0.74, 0.85, -0.31, -0.44, 0.22, -0.02, 0.25, +# -0.01, -0.47, -0.23, 0.03, 0.48, -0.19, 1.55, -0.05, 0.24, 0.26, -0.25, 0.38, -0.44, -0.51, 0.34, -0.12, +# -0.76, -0.13, 0.57, 0.01, 0.63, 0.40, 0.20, -0.33, -0.31, -0.89, 0.65, -0.46, -0.88, -0.22, 0.34, 0.36, +# 0.95, 0.33, 0.62, -0.49, 0.40, -0.12, -0.07, -0.65, -0.05, -0.58, 0.65, 
0.18, -0.81, -0.64, 0.26, -0.10, +# -0.71, 0.47, -0.05, 0.12, -0.18, 0.77, 0.47, 0.50, 0.48, -0.45, 0.03, 0.16, 0.66, -0.42, -0.05, 0.23, -0.22, +# -0.46, 0.25, 0.28, 0.18, -0.20, -0.14, -0.93, -0.27, -0.23, 0.15, -0.10, -0.39, -0.20, -0.05, -0.09, 0.28, +# -0.58, -0.54, 0.09, -0.89, -0.09, 0.03, -0.86, -0.46, -0.70, 0.48, -0.59, -0.56, -0.55, -0.27, -0.50, 0.23, +# 0.63, -1.45, -0.27, -0.04, -0.17, 0.38, -0.02, 0.28, 0.53, -0.81, -0.60, -0.07, 0.22, 0.23, 0.33, -0.62, +# 0.09, -0.19, -0.09, -0.28, -0.13, 0.66, 0.37, -0.17, -0.52, -0.15, -0.60, 0.15, -0.25, 0.42, -0.06, 0.26, +# 0.55, 0.72, 0.48, 0.39, -0.41, -0.76, -0.62, 0.53, 0.18, 0.35, -0.27, -0.20, -0.71, -0.55, 0.16, -0.24, -0.12, +# 0.38, -0.53, -0.43, 0.21, -0.60, -0.24, -0.11, 1.29, 0.02, -0.05, 0.13, 0.48, 0.39, -0.43, -0.05, 0.07, +# -0.92, 0.89, -0.21, 0.30, -0.44, 0.04, -0.30, 0.11, -0.36, -0.46, -0.20, 0.10, 0.88, -0.15, 0.28, 0.57, +# -0.10, 0.48, 0.77, -0.12, 0.17, -0.43, -0.20, 0.22, 0.36, -0.49, -0.54, -0.07, 0.67, 0.40, -0.94, -0.62, +# 0.46, 0.75, -0.16, -0.32, 0.30, 0.41, 0.03, -0.31, -0.17, -0.47, 0.53, 0.24, -0.77, 0.32, 0.58, -0.08, -0.71, 0.10, +# -0.14, 0.39, 0.64, -0.08, -0.38, 0.60, 0.02, 0.61, 0.47, 0.32, 0.35, -0.01, -0.03, -0.15, -0.01, 0.51, +# -0.52, 0.51, -0.82, 0.58, -0.13, 0.07, 0.46, -2.86, 0.36, -0.27, 0.70, 0.54, 0.31, 0.08, -0.67, 0.58, 0.22, +# -0.40, 1.05, 0.02, 0.41, -0.66, -0.29, 0.68, 0.40, 0.53, 0.09, -0.31, -0.28, 0.20, 0.01, -0.07, -0.25, 0.36, +# 0.10, -0.79, 0.27, -0.18, 0.18, -1.13, 0.40, -1.07, 0.84, -0.26, -0.09, -0.99, -0.55, 0.20, -0.11, -0.10, +# 0.49, 0.49, -0.08, -0.13, 1.00, 0.48, -0.17, -0.37, -0.31, -0.24, 0.27, -0.11, 0.21, 0.01, -0.17, -0.02, +# -0.48, 0.25, -0.44, 0.64, 0.53, -1.02, -0.20, -0.13, -0.19, 0.07, -0.17, 0.66, 1.34, -0.40, -1.09, 0.42, +# 0.07, -0.02, 0.50, 0.32, -0.03, 0.30, -0.53, 0.19, 0.01, -0.26, -0.54, -0.04, -0.64, -0.31, 0.85, -0.12, +# -0.07, -0.08, -0.22, 0.27, -0.50, 0.25, 0.40, -0.60, -0.18, 0.36, 0.66, -0.16, 0.91, -0.61, 0.43, 0.31, 0.23, -0.60, +# -0.13, -0.07, -0.44, -0.03, 0.25, 0.41, 0.08, 0.89, -1.09, -0.12, -0.12, -0.09, 0.13, 0.01, -0.55, -0.35, +# -0.44, 0.07, -0.19, 0.35, 0.99, 0.01, 0.11, -0.04, 0.50, -0.10, 0.49, 0.61, 0.23, -0.41, 0.11, -0.36, 0.64, +# -0.97, 0.68, -0.27, 0.30, 0.85, 0.03, 1.84, -0.15, -0.05, 0.46, -0.41, -0.01, 0.03, -0.32, 0.33, 0.14, 0.31 +# , -0.18, -0.30, 0.07, 0.70, -0.64, -0.59, 0.36, 0.39, -0.33, 0.79, 0.47, 0.44, -0.05, -0.03, -0.29, -1.00, +# -0.04, 1.25, 0.74, 0.08, -0.53, -0.65, 0.17, -0.57, -0.39, 0.34, -0.12, -0.04, -0.63, 0.27, -0.25, -0.73, +# -4.08, -0.09, -0.64, 0.38, -0.47, -0.36, -0.34, 0.05, 0.12, 0.37, -0.43, -0.39, 0.11, -0.32, -0.81, -0.05, +# -0.40, -0.31, 2.64, 0.14, -2.08, 0.70, -0.52, -0.55, -0.40, -0.75, -0.20, 0.42, 0.99, -0.27, 0.35, -0.35, +# -0.46, 0.48, 0.03, 0.64, 0.56, -0.77, -0.37, 0.02, 0.02, -0.60, -0.47, -0.49, -0.19, 0.29, 0.05, 0.17, 0.05, +# 1.01, 0.05, 0.06, -0.00, -0.64, 0.72, 1.39, -0.45, -0.46, 0.49, -0.58, 0.36, 0.01, -0.14, -0.01, -0.54, +# -0.46, -1.21, 0.94, -1.31, 0.61, 0.63, -0.53, 0.05, 0.37, -0.18, 1.08, -0.10, -0.80, -0.38, -0.03, +# ], +# std=[ +# 1.48, 1.58, 1.56, 1.49, 1.57, 1.96, 1.50, 1.34, 1.46, 1.66, 1.63, 1.44, 1.48, 1.53, 1.49, 1.39, 1.45, 1.40, +# 1.47, 1.43, 1.65, 1.69, 1.72, 1.56, 1.50, 3.06, 1.48, 1.58, 1.63, 1.41, 1.78, 1.48, 1.64, 1.41, 1.46, 1.39, +# 1.57, 3.80, 0.16, 1.46, 1.49, 1.51, 1.55, 1.57, 1.43, 1.69, 1.50, 1.53, 1.51, 1.49, 1.42, 1.48, 1.62, 1.56, +# 1.52, 1.39, 1.95, 1.47, 1.33, 1.42, 1.96, 1.46, 1.54, 1.47, 1.41, 1.41, 1.50, 1.53, 1.55, 
2.24, 1.52, 1.73, +# 1.54, 1.46, 1.47, 1.55, 1.56, 1.46, 1.40, 1.49, 1.42, 1.54, 1.43, 1.48, 1.41, 1.49, 1.56, 1.59, 1.40, 1.49, +# 1.58, 2.29, 1.58, 1.35, 1.41, 1.45, 1.43, 1.51, 1.48, 1.52, 1.51, 1.52, 1.56, 1.42, 1.44, 1.45, 1.47, 1.42, +# 1.43, 1.49, 1.54, 1.45, 1.66, 1.48, 1.35, 1.53, 1.45, 2.38, 1.38, 1.32, 1.37, 1.49, 2.00, 1.47, 1.45, 1.47, +# 1.63, 1.49, 1.59, 2.58, 1.70, 1.52, 1.40, 1.41, 2.57, 1.61, 1.54, 1.47, 1.62, 1.54, 1.41, 1.45, 1.57, 1.49, +# 1.42, 1.50, 1.67, 1.45, 1.47, 1.43, 1.55, 1.47, 1.53, 1.49, 1.56, 1.58, 2.03, 2.03, 1.57, 1.44, 1.46, 1.05, +# 1.61, 1.39, 1.47, 1.41, 1.43, 1.38, 1.34, 1.42, 1.41, 1.47, 1.79, 1.44, 1.43, 1.38, 1.39, 1.44, 1.38, 1.46, +# 1.45, 1.51, 1.52, 1.49, 5.31, 1.41, 1.45, 1.49, 1.43, 1.94, 1.38, 1.35, 1.56, 1.45, 1.37, 1.47, 1.48, 1.67, +# 1.46, 1.50, 1.40, 1.50, 1.62, 1.48, 1.53, 1.45, 1.51, 1.50, 1.51, 1.52, 1.55, 1.42, 1.84, 1.39, 1.54, 1.42, 4.91, +# 1.42, 1.47, 1.51, 1.57, 1.37, 1.50, 1.39, 2.40, 1.51, 1.59, 1.44, 1.42, 1.59, 1.73, 1.44, 1.53, 1.61, 1.48, +# 1.29, 1.47, 1.39, 1.54, 1.44, 1.43, 1.55, 1.45, 1.31, 1.43, 1.44, 1.41, 1.35, 1.62, 1.49, 1.45, 1.50, 1.76, +# 1.44, 1.80, 1.60, 1.49, 1.43, 1.47, 1.40, 1.40, 1.50, 1.42, 1.51, 1.61, 1.47, 1.45, 1.70, 2.90, 1.51, 1.37, +# 1.50, 1.55, 1.32, 1.42, 1.76, 1.36, 1.41, 1.61, 1.44, 1.44, 1.44, 1.47, 1.48, 1.45, 1.48, 1.56, 1.58, 1.52, +# 1.33, 1.37, 1.64, 1.47, 2.49, 1.51, 1.60, 1.58, 1.45, 1.48, 1.81, 1.38, 1.37, 1.53, 1.72, 1.49, 1.47, 1.49, 1.42, +# 1.44, 1.43, 1.54, 1.59, 1.40, 1.57, 1.45, 1.45, 1.45, 1.55, 1.38, 1.41, 1.46, 2.13, 1.58, 1.46, 1.35, 1.56, +# 1.47, 1.33, 1.53, 1.62, 1.47, 1.44, 1.45, 1.49, 1.82, 1.51, 1.38, 1.54, 1.38, 1.38, 1.40, 1.40, 1.46, 1.43, +# 1.45, 1.42, 1.67, 1.37, 1.50, 1.60, 1.42, 1.46, 1.45, 3.29, 1.45, 1.50, 1.49, 1.38, 1.48, 1.52, 2.45, 1.47, +# 1.50, 1.47, 1.48, 1.44, 1.62, 1.48, 1.52, 1.52, 1.45, 1.51, 1.71, 1.54, 1.59, 1.40, 3.29, 1.45, 1.65, 1.37, 1.54, +# 1.49, 2.38, 1.62, 1.39, 1.38, 1.41, 1.46, 1.57, 1.38, 2.07, 1.54, 1.40, 1.64, 1.46, 1.45, 1.40, 1.57, 1.49, +# 1.39, 1.55, 1.67, 1.54, 1.57, 1.55, 1.41, 1.37, 1.44, 1.40, 1.46, 1.59, 1.56, 1.61, 1.44, 1.35, 1.62, 1.59, +# 1.52, 1.41, 1.44, 1.74, 1.40, 1.40, 1.89, 1.44, 1.46, 1.62, 1.43, 1.42, 1.39, 1.37, 1.43, 1.44, 1.60, 1.52, +# 1.44, 1.41, 1.43, 1.34, 1.54, 1.46, 1.57, 1.53, 1.40, 1.41, 1.36, 1.45, 1.42, 1.37, 1.47, 1.37, 1.40, 1.55, +# 1.48, 1.91, 1.44, 1.54, 1.49, 1.42, 1.48, 1.54, 1.49, 1.39, 1.47, 1.50, 1.43, 1.59, 1.58, 1.78, 1.49, 1.55, +# 1.56, 1.52, 1.56, 1.49, 1.61, 1.51, 1.35, 1.46, 1.69, 1.35, 1.38, 1.48, 1.39, 1.40, 1.35, 1.45, 1.34, 1.38, +# 1.44, 1.46, 1.45, 1.63, 1.52, 1.44, 1.39, 1.46, 1.70, 1.41, 1.49, 1.64, 1.54, 1.33, 1.45, 1.54, 1.49, 1.38, +# 1.42, 1.75, 1.28, 1.52, 1.62, 1.47, 1.66, 1.51, 1.50, 1.51, 1.42, 1.42, 1.60, 1.24, 1.54, 1.42, 1.44, 1.34, 1.53, +# 1.46, 1.46, 1.65, 1.56, 1.52, 2.12, 1.58, 1.44, 1.60, 1.48, 1.51, 1.41, 1.51, 1.68, 2.10, 1.50, 1.39, 1.49, +# 1.43, 1.53, 1.46, 1.53, 1.43, 1.78, 1.32, 1.54, 1.47, 1.55, 1.58, 1.41, 1.57, 1.39, 1.36, 1.74, 1.50, 4.41, +# 1.50, 1.45, 1.34, 1.44, 1.50, 1.50, 1.82, 1.28, 1.76, 1.38, 1.58, 1.56, 3.73, 1.48, 1.53, 1.48, 1.63, 1.43, +# 1.57, 3.43, 1.75, 1.45, 1.45, 1.48, 1.93, 1.47, 1.47, 1.38, 1.42, 1.56, 1.66, 1.39, 1.74, 4.76, 1.53, 1.68, +# 1.55, 1.47, 1.57, 1.53, 1.50, 1.40, 1.57, 1.48, 1.44, 1.36, 1.32, 1.71, 1.44, 1.46, 1.47, 1.54, 1.51, 1.47, +# 1.36, 1.29, 1.44, 1.43, 1.46, 1.40, 1.64, 1.48, 1.42, 1.32, 1.52, 1.49, 3.04, 1.52, 1.38, 1.43, 1.42, 1.43, +# 1.48, 1.49, 1.59, 1.55, 1.62, 2.04, 1.53, 1.42, 1.89, 1.43, 1.41, 
3.84, 1.48, 1.51, 1.48, 1.58, 1.54, 1.54, +# 1.54, 1.55, 1.45, 1.49, 1.46, 2.25, 1.43, 1.62, 1.66, 1.80, 1.37, 1.64, 1.49, 1.50, 1.39, 1.41, 1.41, 1.46, 1.44, +# 1.69, 1.47, 1.56, 1.65, 1.51, 1.52, 1.43, 1.53, 1.51, 1.46, 1.62, 1.46, 1.53, 1.68, 1.61, 1.56, 1.42, 4.69, +# 1.31, 1.48, 1.50, 1.82, 1.45, 1.54, 1.56, 1.53, 1.58, 1.59, 1.82, 1.45, 1.54, 1.58, 1.45, 1.40, 1.49, 2.50, +# 1.52, 2.54, 1.51, 1.41, 1.48, 1.46, 1.55, 1.63, 1.42, 1.53, 1.47, 1.47, 1.62, 1.49, 2.09, 1.42, 1.48, 1.33, +# 1.62, 1.41, 1.41, 1.45, 1.50, 1.78, 1.53, 1.56, 1.49, 1.51, 2.31, 1.40, 1.58, 1.39, 1.49, 1.51, 1.55, 1.58, +# 1.93, 1.47, 1.41, 1.47, 1.52, 1.52, 1.39, 1.48, 1.64, 1.49, 1.47, 1.53, 1.50, 3.58, 1.54, 1.70, 1.50, 1.47, +# 1.35, 1.51, 1.70, 1.59, 1.60, 1.56, 1.29 +# ] +# ) +# ) + +class DINOv2a(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2a, self).__init__() + self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + # self.shifts = nn.Parameter(torch.tensor(NORMALIZE_DATA[weight_path+'a']["mean"]), requires_grad=False) + # self.scales = nn.Parameter(torch.tensor(NORMALIZE_DATA[weight_path+'a']["std"]), requires_grad=False) + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + # feature = (feature - self.shifts.view(1, 1, -1)) / self.scales.view(1, 1, -1) + feature = feature.transpose(1, 2) + feature = torch.nn.functional.fold(feature, (patch_num_h*2, patch_num_w*2), kernel_size=2, stride=2) + return feature + + + +from torchvision.datasets import ImageFolder, ImageNet + +import os +import numpy as np + +from PIL import Image +import torch +import torchvision.transforms as tvtf + +IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +class NewImageFolder(ImageFolder): + def __getitem__(self, item): + path, target = self.samples[item] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + return sample, target, path + + +import time +class CenterCrop: + def __init__(self, size): + self.size = size + def __call__(self, image): + def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. 
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + return center_crop_arr(image, self.size) + +def save(obj, path): + dirname = os.path.dirname(path) + if not os.path.exists(dirname): + os.makedirs(dirname) + torch.save(obj, f'{path}') + +import math +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10): + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half + ) + args = t[..., None].float() * freqs[None, ...] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + +class Classifer(nn.Module): + def __init__(self, in_channels=192, hidden_size=256, num_classes=1000): + super(Classifer, self).__init__() + self.in_channels = in_channels + self.feature_x = nn.Sequential( + nn.Conv2d(kernel_size=2, in_channels=in_channels, out_channels=num_classes, stride=2, padding=0), + nn.AdaptiveAvgPool2d(1), + ) + def forward(self, xt): + xt = xt[:, :self.in_channels] + score = self.feature_x(xt).squeeze(-1).squeeze(-1) + # score = (feature_xt).clamp(-5, 5) + score = torch.softmax(score, dim=1) + return score + + + +if __name__ == "__main__": + torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub' + os.environ["TORCH_HOME"] = torch_hub_dir + torch.hub.set_dir(torch_hub_dir) + + transforms = tvtf.Compose([ + CenterCrop(256), + tvtf.ToTensor(), + ]) + dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,) + B = 64 + dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=4, num_workers=4, drop_last=True) + dino = DINOv2a("dinov2_vitb14") + from accelerate import Accelerator + + accelerator = Accelerator() + + rank = accelerator.process_index + + classifer = Classifer(in_channels=32) + classifer.train() + optimizer = torch.optim.Adam(classifer.parameters(), lr=0.0001) + + dino, dataloader, classifer, optimizer = accelerator.prepare(dino, dataloader, classifer, optimizer) + + # fake_file_dir = "/mnt/bn/wangshuai6/data/gan_guidance" + # fake_file_names = os.listdir(fake_file_dir) + + for epoch in range(100): + for i, (true_images, true_labels, path_list) in enumerate(dataloader): + batch_size = true_images.shape[0] + true_labels = true_labels.to(accelerator.device) + true_labels = torch.nn.functional.one_hot(true_labels, num_classes=1000) + with torch.no_grad(): + 
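# DINOv2 features are extracted under no_grad; only the small Classifer head is optimized.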
true_dino_feature = dino(true_images) + # t = torch.rand((batch_size, 1, 1, 1), device=accelerator.device) + # true_x_t = t * true_dino_feature + (1-t) * noise + + true_x_t = true_dino_feature + true_score = classifer(true_x_t) + + # ind = i % len(fake_file_names) + # fake_file = torch.load(os.path.join(fake_file_dir, fake_file_names[ind])) + # import pdb; pdb.set_trace() + # ind = torch.randint(0, 50, size=(4,)) + # fake_x_t = fake_file['trajs'][ind].view(-1, 196, 32, 32)[:, 4:, :, :] + # fake_labels = fake_file['condition'].repeat(4) + # fake_score = classifer(fake_x_t) + + loss_true = -torch.log(true_score)*true_labels + loss = loss_true.sum()/batch_size + loss.backward() + optimizer.step() + optimizer.zero_grad() + + acc = torch.sum(torch.argmax(true_score, dim=1) == torch.argmax(true_labels, dim=1))/batch_size + if accelerator.is_main_process: + print("epoch:{}".format(epoch), "iter:{}".format(i), "loss:{}".format(loss.item()), "acc:{}".format(acc.item())) + if accelerator.is_main_process: + torch.save(classifer.state_dict(), f'{epoch}.pth') + + accelerator.wait_for_everyone() \ No newline at end of file diff --git a/tools/debug_env.sh b/tools/debug_env.sh new file mode 100644 index 0000000..d29dc46 --- /dev/null +++ b/tools/debug_env.sh @@ -0,0 +1,4 @@ +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/compat +pip3 install -r requirements.txt +git branch --set-upstream-to=origin/master master +git pull \ No newline at end of file diff --git a/tools/dino_scale.py b/tools/dino_scale.py new file mode 100644 index 0000000..439caaf --- /dev/null +++ b/tools/dino_scale.py @@ -0,0 +1,173 @@ +import torch +import torch.nn as nn +import timm +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +import copy + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + return feature + +class MAE(nn.Module): + def __init__(self, model_id, weight_path:str): + super(MAE, self).__init__() + if os.path.isdir(weight_path): + weight_path = os.path.join(weight_path, "pytorch_model.bin") + self.encoder = timm.create_model( + model_id, + checkpoint_path=weight_path, + num_classes=0, + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.shifts = nn.Parameter(torch.tensor([0.0 + ]), 
requires_grad=False) + self.scales = nn.Parameter(torch.tensor([1.0 + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + + +from diffusers import AutoencoderKL + +from torchvision.datasets import ImageFolder, ImageNet + +import cv2 +import os +import numpy as np +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor + +from PIL import Image +import torch +import torchvision.transforms as tvtf + +IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +class NewImageFolder(ImageFolder): + def __getitem__(self, item): + path, target = self.samples[item] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + return sample, target, path + + +import time +class CenterCrop: + def __init__(self, size): + self.size = size + def __call__(self, image): + def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + return center_crop_arr(image, self.size) + +def save(obj, path): + dirname = os.path.dirname(path) + if not os.path.exists(dirname): + os.makedirs(dirname) + torch.save(obj, f'{path}') + +if __name__ == "__main__": + torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub' + os.environ["TORCH_HOME"] = torch_hub_dir + torch.hub.set_dir(torch_hub_dir) + + for split in ['train']: + train = split == 'train' + transforms = tvtf.Compose([ + CenterCrop(256), + tvtf.ToTensor(), + # tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ]) + dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,) + B = 4096 + dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0) + dino = DINOv2("dinov2_vitb14") + # dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae") + # dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai") + from accelerate import Accelerator + + accelerator = Accelerator() + dino, dataloader = accelerator.prepare(dino, dataloader) + rank = accelerator.process_index + + acc_mean = torch.zeros((768, ), device=accelerator.device) + acc_num = 0 + with torch.no_grad(): + for i, (images, labels, path_list) in enumerate(dataloader): + acc_num += len(images) + feature = dino(images) + stds = torch.std(feature, dim=[0, 2, 
3]).tolist() + for std in stds: + print("{:.2f},".format(std), end='') + print() + means = torch.mean(feature, dim=[0, 2, 3]).tolist() + for mean in means: + print("{:.2f},".format(mean), end='') + break + + accelerator.wait_for_everyone() \ No newline at end of file diff --git a/tools/dino_scale2.py b/tools/dino_scale2.py new file mode 100644 index 0000000..336c196 --- /dev/null +++ b/tools/dino_scale2.py @@ -0,0 +1,168 @@ +import torch +import torch.nn as nn +import timm +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torchvision.transforms import Normalize +import copy + +class DINOv2(nn.Module): + def __init__(self, weight_path:str): + super(DINOv2, self).__init__() + self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.precomputed_pos_embed = dict() + + def fetch_pos(self, h, w): + key = (h, w) + if key in self.precomputed_pos_embed: + return self.precomputed_pos_embed[key] + value = timm.layers.pos_embed.resample_abs_pos_embed( + self.pos_embed.data, [h, w], + ) + self.precomputed_pos_embed[key] = value + return value + + def forward(self, x): + b, c, h, w = x.shape + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w) + self.encoder.pos_embed.data = pos_embed_data + feature = self.encoder.forward_features(x)['x_norm_patchtokens'] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + return feature + +class MAE(nn.Module): + def __init__(self, model_id, weight_path:str): + super(MAE, self).__init__() + if os.path.isdir(weight_path): + weight_path = os.path.join(weight_path, "pytorch_model.bin") + self.encoder = timm.create_model( + model_id, + checkpoint_path=weight_path, + num_classes=0, + ) + self.pos_embed = copy.deepcopy(self.encoder.pos_embed) + self.encoder.head = torch.nn.Identity() + self.patch_size = self.encoder.patch_embed.patch_size + self.shifts = nn.Parameter(torch.tensor([0.0 + ]), requires_grad=False) + self.scales = nn.Parameter(torch.tensor([1.0 + ]), requires_grad=False) + + def forward(self, x): + x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x) + x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic') + b, c, h, w = x.shape + patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1] + feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:] + feature = feature.transpose(1, 2) + feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous() + feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1) + return feature + + +from diffusers import AutoencoderKL + +from torchvision.datasets import ImageFolder, ImageNet + +import cv2 +import os +import numpy as np +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor + +from PIL import Image +import torch +import torchvision.transforms as tvtf + +IMG_EXTENSIONS = ( + "*.png", + "*.JPEG", + "*.jpeg", + "*.jpg" +) + +class NewImageFolder(ImageFolder): + def __getitem__(self, item): + path, target = self.samples[item] + sample = self.loader(path) + if self.transform is not None: + sample = 
self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + return sample, target, path + + +import time +class CenterCrop: + def __init__(self, size): + self.size = size + def __call__(self, image): + def center_crop_arr(pil_image, image_size): + """ + Center cropping implementation from ADM. + https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 + """ + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) + + return center_crop_arr(image, self.size) + +def save(obj, path): + dirname = os.path.dirname(path) + if not os.path.exists(dirname): + os.makedirs(dirname) + torch.save(obj, f'{path}') + +if __name__ == "__main__": + torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub' + os.environ["TORCH_HOME"] = torch_hub_dir + torch.hub.set_dir(torch_hub_dir) + + for split in ['train']: + train = split == 'train' + transforms = tvtf.Compose([ + CenterCrop(256), + tvtf.ToTensor(), + tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ]) + dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,) + B = 2048 + dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0) + dino = DINOv2("dinov2_vitb14") + # dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae") + # dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai") + from accelerate import Accelerator + + accelerator = Accelerator() + dino, dataloader = accelerator.prepare(dino, dataloader) + rank = accelerator.process_index + + with torch.no_grad(): + for i, (images, labels, path_list) in enumerate(dataloader): + feature = dino(images) + b, c, h, w = feature.shape + feature = feature.view(b, c, h*w).transpose(1, 2) + feature = feature.reshape(-1, c) + U, S, V = torch.pca_lowrank(feature, 64, ) + import pdb; pdb.set_trace() + feature = torch.matmul(feature, V) + break + accelerator.wait_for_everyone() \ No newline at end of file diff --git a/tools/dp.py b/tools/dp.py new file mode 100644 index 0000000..9f89d75 --- /dev/null +++ b/tools/dp.py @@ -0,0 +1,64 @@ +import matplotlib.pyplot as plt +print(len([0, 3, 6, 9, 12, 16, 20, 24, 28, 33, 38, 43, 48, 53, 57, 62, 67, 72, 78, 83, 87, 91, 95, 98, 102, 106, 110, 115, 120, 125, 130, 135, 141, 146, 152, 158, 164, 171, 179, 185, 191, 197, 203, 209, 216, 223, 229, 234, 240, 245, 250])) +print(len(list(range(0, 251, 5)))) +exit() +plt.plot() +plt.plot() +plt.show() +exit() + + + + +import torch + +num_steps = 10 +num_recompute_timesteps = 4 +sim = torch.randint(0, 100, (num_steps, num_steps)) +sim[:5, :5] = 100 +for i in range(num_steps): + sim[i, i] = 100 + +error_map = (100-sim).tolist() + + +# init +for i in range(1, num_steps): + for j in range(0, i): + error_map[i][j] = error_map[i-1][j] + error_map[i][j] + +C = [[0, ] * (num_steps + 1) for _ in range(num_recompute_timesteps+1)] +P = [[-1, ] * (num_steps + 1) for _ in range(num_recompute_timesteps+1)] + +for i in range(1, num_steps+1): + C[1][i] = 
error_map[i-1][0] + P[1][i] = 0 + + +# dp +for step in range(2, num_recompute_timesteps+1): + for i in range(step, num_steps+1): + min_value = 99999 + min_index = -1 + for j in range(step-1, i): + value = C[step-1][j] + error_map[i-1][j] + if value < min_value: + min_value = value + min_index = j + C[step][i] = min_value + P[step][i] = min_index + +# trace back +tracback_end_index = num_steps +# min_value = 99999 +# for i in range(num_recompute_timesteps-1, num_steps): +# if C[-1][i] < min_value: +# min_value = C[-1][i] +# tracback_end_index = i + +timesteps = [tracback_end_index, ] +for i in range(num_recompute_timesteps, 0, -1): + idx = timesteps[-1] + timesteps.append(P[i][idx]) +timesteps.reverse() +print(timesteps) \ No newline at end of file diff --git a/tools/figures/base++.py b/tools/figures/base++.py new file mode 100644 index 0000000..9715321 --- /dev/null +++ b/tools/figures/base++.py @@ -0,0 +1,64 @@ +import numpy as np +import matplotlib.pyplot as plt + +is_data = { + "4encoder8decoder":[46.01, 61.47, 69.73, 74.26], + "6encoder6decoder":[53.11, 71.04, 79.83, 83.85], + "8encoder4decoder":[54.06, 72.96, 80.49, 85.94], + "10encoder2decoder": [49.25, 67.59, 76.00, 81.12], +} + +fid_data = { + "4encoder8decoder":[31.40, 22.80, 20.13, 18.61], + "6encoder6decoder":[27.61, 20.42, 17.95, 16.86], + "8encoder4decoder":[27.12, 19.90, 17.78, 16.32], + "10encoder2decoder": [29.70, 21.75, 18.95, 17.65], +} + +sfid_data = { + "4encoder8decoder":[6.88, 6.44, 6.56, 6.56], + "6encoder4decoder":[6.83, 6.50, 6.49, 6.63], + "8encoder4decoder":[6.76, 6.70, 6.83, 6.63], + "10encoder2decoder": [6.81, 6.61, 6.53, 6.60], +} + +pr_data = { + "4encoder8decoder":[0.55006, 0.59538, 0.6063, 0.60922], + "6encoder6decoder":[0.56436, 0.60246, 0.61668, 0.61702], + "8encoder4decoder":[0.56636, 0.6038, 0.61832, 0.62132], + "10encoder2decoder": [0.55612, 0.59846, 0.61092, 0.61686], +} + +recall_data = { + "4encoder8decoder":[0.6347, 0.6495, 0.6559, 0.662], + "6encoder6decoder":[0.6477, 0.6497, 0.6594, 0.6589], + "8encoder4decoder":[0.6403, 0.653, 0.6505, 0.6618], + "10encoder2decoder": [0.6342, 0.6492, 0.6536, 0.6569], +} + +x = [100, 200, 300, 400] +# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"] + +colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"] + +metric_data = { + "FID50K" : fid_data, + # "SFID" : sfid_data, + "InceptionScore" : is_data, + "Precision" : pr_data, + "Recall" : recall_data, +} + +for key, data in metric_data.items(): + # plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False}) + for i, (name, v) in enumerate(data.items()): + name = name.replace("encoder", "En") + name = name.replace("decoder", "De") + plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=10) + plt.legend(fontsize="14") + plt.xticks([100, 150, 200, 250, 300, 350, 400]) + plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5) + plt.ylabel(key, weight="bold") + plt.xlabel("Training iterations(K steps)", weight="bold") + plt.savefig("output/base++_{}.pdf".format(key), bbox_inches='tight',) + plt.close() \ No newline at end of file diff --git a/tools/figures/base.py b/tools/figures/base.py new file mode 100644 index 0000000..46acd3c --- /dev/null +++ b/tools/figures/base.py @@ -0,0 +1,57 @@ +import numpy as np +import matplotlib.pyplot as plt + + + + +fid_data = { + "4encoder8decoder":[64.16, 48.04, 39.88, 35.41], + "6encoder4decoder":[67.71, 48.26, 39.30, 34.91], + "8encoder4decoder":[69.4, 49.7, 41.56, 36.76], +} + +sfid_data = { + "4encoder8decoder":[7.86, 
7.48, 7.15, 7.07], + "6encoder4decoder":[8.54, 8.11, 7.40, 7.40], + "8encoder4decoder":[8.42, 8.27, 8.10, 7.69], +} + +is_data = { + "4encoder8decoder":[20.37, 29.41, 36.88, 41.32], + "6encoder4decoder":[20.04, 30.13, 38.17, 43.84], + "8encoder4decoder":[19.98, 29.54, 35.93, 42.025], +} + +pr_data = { + "4encoder8decoder":[0.3935, 0.4687, 0.5047, 0.5271], + "6encoder4decoder":[0.3767, 0.4686, 0.50876, 0.5266], + "8encoder4decoder":[0.37, 0.45676, 0.49602, 0.5162], +} + +recall_data = { + "4encoder8decoder":[0.5604, 0.5941, 0.6244, 0.6338], + "6encoder4decoder":[0.5295, 0.595, 0.6287, 0.6378], + "8encoder4decoder":[0.51, 0.596, 0.6242, 0.6333], +} + +x = [100, 200, 300, 400] +colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"] +metric_data = { + "FID" : fid_data, + # "SFID" : sfid_data, + "InceptionScore" : is_data, + "Precision" : pr_data, + "Recall" : recall_data, +} + +for key, data in metric_data.items(): + for i, (name, v) in enumerate(data.items()): + name = name.replace("encoder", "En") + name = name.replace("decoder", "De") + plt.plot(x, v, label=name, color=colors[i], linewidth=3, marker="o") + plt.legend() + plt.xticks(x) + plt.ylabel(key, weight="bold") + plt.xlabel("Training iterations(K steps)", weight="bold") + plt.savefig("output/base_{}.pdf".format(key), bbox_inches='tight') + plt.close() \ No newline at end of file diff --git a/tools/figures/cfg.py b/tools/figures/cfg.py new file mode 100644 index 0000000..4cd8855 --- /dev/null +++ b/tools/figures/cfg.py @@ -0,0 +1,32 @@ +import numpy as np +import matplotlib.pyplot as plt + +cfg_data = { + "[0, 1]":{ + "cfg":[1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0], + "FID":[9.23, 6.61, 5.08, 4.46, 4.32, 4.52, 4.86, 5.38, 5.97, 6.57, 7.13], + }, + "[0.2, 1]":{ + "cfg": [1.2, 1.4, 1.6, 1.8, 2.0], + "FID": [5.87, 4.44, 3.96, 4.01, 4.26] + }, + "[0.3, 1]":{ + "cfg": [1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4], + "FID": [4.31, 4.11, 3.98, 3.89, 3.87, 3.88, 3.91, 3.96, 4.03] + }, + "[0.35, 1]":{ + "cfg": [1.6, 1.8, 2.0, 2.1, 2.2, 2.3, 2.4, 2.6], + "FID": [4.68, 4.22, 3.98, 3.92, 3.90, 3.88, 3.88, 3.94] + } +} + +colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"] + +for i, (name, data) in enumerate(cfg_data.items()): + plt.plot(data["cfg"], data["FID"], label="Interval: " +name, color=colors[i], linewidth=3.5, marker="o") + +plt.title("Classifer-free guidance with intervals", weight="bold") +plt.ylabel("FID10K", weight="bold") +plt.xlabel("CFG values", weight="bold") +plt.legend() +plt.savefig("./output/cfg.pdf", bbox_inches="tight") \ No newline at end of file diff --git a/tools/figures/feat_vis.py b/tools/figures/feat_vis.py new file mode 100644 index 0000000..55f2045 --- /dev/null +++ b/tools/figures/feat_vis.py @@ -0,0 +1,42 @@ +import torch + +states = torch.load("./output/state.pt", map_location="cpu").to(dtype=torch.float32) +states = states.permute(1, 2, 0, 3) +print(states.shape) +states = states.view(-1, 49, 1152) +states = torch.nn.functional.normalize(states, dim=-1) +sim = torch.bmm(states, states.transpose(1, 2)) +mean_sim = torch.mean(sim, dim=0, keepdim=False) + +mean_sim = mean_sim.numpy() +import numpy as np +import matplotlib +import matplotlib.pyplot as plt +timesteps = np.linspace(0, 1, 5) +# plt.rc('axes.spines', **{'bottom':False, 'left':False, 'right':False, 'top':False}) +cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["#7400b8","#5e60ce","#4ea8de", "#64dfdf", "#80ffdb"]) +plt.imshow(mean_sim, cmap="inferno") +plt.xticks([]) +plt.yticks([]) +# plt.show() +plt.colorbar() 
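+# Note: `states` stacks 49 intermediate hidden states (presumably one per sampling
+# timestep); `mean_sim` is their 49x49 cosine-similarity matrix averaged over the
+# leading batch/token dimension, so bright off-diagonal regions mark states that
+# are nearly interchangeable across steps.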
+plt.savefig("./output/mean_sim.png", pad_inches=0, bbox_inches="tight") +# cos_sim = torch.nn.functional.cosine_similarity(states, states) + + +# for i in range(49): +# cos_sim = torch.nn.functional.cosine_similarity(states[i], states[i + 1]) +# cos_sim = cos_sim.min() +# print(cos_sim) +# state = torch.max(states, dim=-1)[1] +# # state = torch.softmax(state, dim=-1) +# state = state.view(-1, 16, 16) +# +# state = state.numpy() +# +# import numpy as np +# import matplotlib.pyplot as plt +# for i in range(0, 49): +# print(i) +# plt.imshow(state[i]) +# plt.savefig("./output2/{}.png".format(i)) \ No newline at end of file diff --git a/tools/figures/large++.py b/tools/figures/large++.py new file mode 100644 index 0000000..070d13f --- /dev/null +++ b/tools/figures/large++.py @@ -0,0 +1,63 @@ +import numpy as np +import matplotlib.pyplot as plt + +is_data = { + "10encoder14decoder":[80.48, 104.48, 113.01, 117.29], + "12encoder12decoder":[85.52, 109.91, 118.18, 121.77], + "16encoder8decoder":[92.72, 116.30, 124.32, 126.37], + "20encoder4decoder":[94.95, 117.84, 125.66, 128.30], +} + +fid_data = { + "10encoder14decoder":[15.17, 10.40, 9.32, 8.66], + "12encoder12decoder":[13.79, 9.67, 8.64, 8.21], + "16encoder8decoder":[12.41, 8.99, 8.18, 8.03], + "20encoder4decoder":[12.04, 8.94, 8.03, 7.98], +} + +sfid_data = { + "10encoder14decoder":[5.49, 5.00, 5.09, 5.14], + "12encoder12decoder":[5.37, 5.01, 5.07, 5.09], + "16encoder8decoder":[5.43, 5.11, 5.20, 5.31], + "20encoder4decoder":[5.36, 5.23, 5.21, 5.50], +} + +pr_data = { + "10encoder14decoder":[0.6517, 0.67914, 0.68274, 0.68104], + "12encoder12decoder":[0.66144, 0.68146, 0.68564, 0.6823], + "16encoder8decoder":[0.6659, 0.68342, 0.68338, 0.67912], + "20encoder4decoder":[0.6716, 0.68088, 0.68798, 0.68098], +} + +recall_data = { + "10encoder14decoder":[0.6427, 0.6512, 0.6572, 0.6679], + "12encoder12decoder":[0.6429, 0.6561, 0.6622, 0.6693], + "16encoder8decoder":[0.6457, 0.6547, 0.6665, 0.6773], + "20encoder4decoder":[0.6483, 0.6612, 0.6684, 0.6711], +} + +x = [100, 200, 300, 400] +# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"] +colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"] + +metric_data = { + "FID50K" : fid_data, + # "SFID" : sfid_data, + "InceptionScore" : is_data, + "Precision" : pr_data, + "Recall" : recall_data, +} + +for key, data in metric_data.items(): + # plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False}) + for i, (name, v) in enumerate(data.items()): + name = name.replace("encoder", "En") + name = name.replace("decoder", "De") + plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=8) + plt.legend(fontsize="14") + plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5) + plt.xticks([100, 150, 200, 250, 300, 350, 400]) + plt.ylabel(key, weight="bold") + plt.xlabel("Training iterations(K steps)", weight="bold") + plt.savefig("output/large++_{}.pdf".format(key), bbox_inches='tight') + plt.close() \ No newline at end of file diff --git a/tools/figures/log_snr.py b/tools/figures/log_snr.py new file mode 100644 index 0000000..f9d1b28 --- /dev/null +++ b/tools/figures/log_snr.py @@ -0,0 +1,18 @@ +import numpy as np +import matplotlib.pyplot as plt + +t = np.linspace(0.001, 0.999, 100) +def snr(t): + return np.log((1-t)/t) +def pds(t): + return np.clip(((1-t)/t)**2, a_max=0.5, a_min=0.0) +print(pds(t)) +plt.figure(figsize=(16, 4)) +plt.plot(t, snr(t), color="#ff70a6", linewidth=3, marker="o") +# plt.plot(t, pds(t), color="#ff9770", linewidth=3, marker="o") 
+plt.ylabel("log-SNR", weight="bold") +plt.xlabel("Timesteps", weight="bold") +plt.xticks([1.0, 0.8, 0.6, 0.4, 0.2, 0.0]) +plt.gca().invert_xaxis() +plt.show() +# plt.savefig("output/logsnr.pdf", bbox_inches='tight') \ No newline at end of file diff --git a/tools/figures/output/base++_FID.pdf b/tools/figures/output/base++_FID.pdf new file mode 100644 index 0000000000000000000000000000000000000000..4aa3a98a4deb58dbe270fe4a5d97752aab9726cc GIT binary patch literal 17051 zcmd^n2{_fw+kYzJSjtwZ=t#CKXWwMsg|d?^$1WT?PSU0%LYA^dA+m-LkyK=@$WpQ; z5=BLeCDP)Z@4=Jsd!PRA_5WYj|GM7OGiJ_w@44rmd+zVdJ)e8dBdn*YE`^fD!h~~s z;5m0-C^!=CVS5lJD+@;$9rbj8BUH#VvYUq!9HB>ca`1wq0f7NrULNK^v4;qWiyEkV zP-t)rod99B&-S2$9Su&H`>N(c(=eowX%2AgT!kK)MsuK2;5g_j3}NU@ws)a8!SVA? zsUCKQ4m7w0=vGx7h~nTwgCn%ufe6a;xypR51-Dux4r=)TAh{vwsSZH8Iq?w&4qhII zsdivIko=&3T?cy?va*K{=n)D2VUbvA5(sDM!dR$g$@ zoUBR|iU;_Hg4!9terpDLyr4=Q2Z|HT8IGFoUfsnF3?7b9cLN$xb+GfWcYwy=MWZ^9 z-C;ggo|x-;Z6k2po5~-)c9L1`^s9B@U%L;vmwbJtYD{wbdclAv-=O5?udKV0XnU@oZ6D`ArNGavrF%DnZqcI_RwDho;Fsa`)>WlQkY){`sEq ztpC-J9)zAvLsDu#>BQHgJ`sJsKk8Hx?CaHTJ2zd?;Arv^yl`le;^RxQJfSW0?aNrl zqrA9^%0bb=pN5-{eS7(NjOruzZpze4b1;!^n@4?%XFWW-^hj*QIWM z!e@se&037ZiXvn>?lPYhofL zvAv!95a;u-JqK`jpUVwpw;(=mJO;9z8cX92Hy4%^d#6r||dprrlB6hlL=hq!h8;=J|U=7?K?nsg=;ohH@ zC%_ruBbnl?$ttkb7ax)Mu+C2OP*}Z;hTit861!ZnV?wE%q_og+5rnhbbQjip$JVgK z8r&A{5c9peqE(-6i5TWs|u(1Sqhq|i^Hkmq8q^*TPP^*UKv z(aOw=yVw%ZqHGX(w);??CtFX)H{-_yV}am?Eg6X$wd;xv-k$nCwdqvdChqOBzJxFO z_xi1DLb7)@J?G=o5Yf;-AQye2<>jgHJ_Nt5HF*cC73|xvE|+gvbXjylb4Gf;$LJo# zlicrJ``gW826w6_S2g(wm{Zz5s!V-yz&+c~nR7 zr9%TL_bms1lIi`4--$P2~g`4TYA3e2SjPHe>lHc}8+@hiAnjE8q&G?O+XtPfa zNAFDd-DBOL6Q#gZK|SP*`=T!Q`HgnX{dSx7*9H&V5)~ajeHaaHFPE!a9etsr>|-Fa zXhdibO1NM>)6KdKE(bARbNY)FAIQG^@bu}3%t#}Hg<}3P{r++mH z;>fv?Dp0L&qV5`xZ{XP0sFxU_&iLFl_@(AHbG67!Yq0^%y$MeZ7>zoI$JR{2wqJHE zaK;iva=1y)y2Wj-v}C(lwo6LXZ;P&b5Y6g4bMVQHh~hW}o$n(L)^_PSzq~4Wt%4I#s7dXRlv1O=(p{VP5L4MERX1Nwh)16zNcVETp`NhnH zc)k0$)QdT!Fa=hh)9<-D4W%|B&yXX<4Sb2+J(PqNRN zSbWHHG$r0tX#DEYARM~>*>qIAk#+CmY@Hc9cZY_Z6L(2e^S0CMx4P_G^Lu2rSY?^9 zTMXiQ)Wuz-3WdKq@y_hdu&hXH8{)a5sDDuVB{iUn<;>F;Mc4OJib7ibht8bIY1p-w zqkn6eLhfzfxMDp^0h3VmXuFeZoZRi~UW`ONr4?{G*OxfoqFT63nHTF6$ZWT8DzG_B zbj*z+dA)G4ktI*w@^vj~<(IwCPqqjot9Jn{+6^A)r6$|wpPOA1o@aRk^8DsKm3beK z?$$uA4jPRlEI5XAe+a6yfLkrpp%;5+^F9$8g#nJxzc@s4db@9~K{0Nr1 z`B7SDPeqvLn01k})P=QWA|iXf1#6z&D1gR_dNcK|3l{%aDCzr<#&4g!VbjR3{3=`t zVi?x6B|6nA?vrCsE}yp<=1SMZ{XpT1qB^5KCHp!)DAdYX`wJz8`5hFf**t#U93G7l z+dh6Nrn2ajL_cVCU<_&xSQmnH&wuW;S4#HF_EzTtH&_c~K7OP!2v@?F=| z3xFy2$fLtTG3TG1A)BJ&$0eOi!^hKoK3bsZXB^n-k!m!ldxQkXvjNU;xQ7i z|E?TN5z+*ySAbp3l@S{Yd#fG`6A!dhln;t&VB~IKDAJK?^pg}~U9&SbI9pJ2V}OI3 z4NHt#0o!S8XTr8i??{$>mM*JsRhvJoAQ9xuFOdj`{0oWj1Op(^T2A$iOv?oVJ)603 z=%{g+GMR2!b(_d8yh6{=F=0v1#J|u5?uj;6C35YmmV1W8Gn#9s-qG%8eyWgCC&<=m z7++{1ow<1R<^T@>4WIaz+ zt>c?qK{n{jvP3rgUl=i%C!iH1qb5GCr-^Yr3pB>}){#C-xqOG{XbcCqic=Ms?*20m) zYCRi!PrNvhzs~f%n*Q#?&t#(n2@lIQzMv>!9c|H?Bd|8f`8 z@j4~lw`!axdMi%!btG$}pRV~B!bEz0^0Rl$xLyObwRzXDyS`^BVj}lw6X$-gRacu_ z%(O_TbKM=6P)Zq^V6ov4-9l6u*>Rlc5T!IQ|=opbxSl+$60{O`1K|6vv(1Rp7ZkIcRh!OwjbRf zlYqvd{$djJNtCrHwhupHzBUXWihm}toZ%^&^-yt%%+u6FtF@NlMS4DSsdTFj*2IOT z-dX+WoXabL>tkewE{;hS7JodB{zig2|f|+_RIh_&c=B>=0dM+n; zGE?e{0?9LgruotqbG34C)G_-{hdk**FifN`&+%1mK zpT4h#b-qzxEgN&GNc0T%{mJH=n42L_3_fok#Q6rUbG0D$e|$l2HW`W z1t*DCuV^mQM|ocBqv8E;N?yMkdG8go5344E*kP*5g%Z?iemCmymTS{n$E!+Kx0rKt zJQSmTv2>bo4GmiV!GJ=GPl%b0(y^E7Al1bt_x#A ztmaqic^ab^_M9(WqtLsYF{P98-6i)_#A{;{^{>ySzldxsw8jFkHvq^9kC zaen*G%$n?39*yjnSMn_3Gz~qYe!d>=*C(+$!#93ct(~bJ1I5n-zC>R^C$) zd5L@72g(%~T-el|S5U+M)SbfVm9eh@i;rBuMrND$MIK44XRuurVPc}SezKQSHHWOx 
zwYmaGts}u##%c;uDT2vlOv1ZeMnUR9Rw7d^GPiQKO11glnaP;V(DRc!&XjUjd{Y*y z6$A6shEHUR%STTd-R>t0_<4V6^L8JRv+vPijPiVIe~tCo)0lU&CqI7qNXh-iM=9SQ zEfboV-fexoX4lW7Tn+wCqI}+o(W_gMwCX}Rq$dXLsH}xD^=^SNJc^@igU^{Jrza$) zV{R6GdB4giS#z6{W;bf0f=S}u{T+cuHY;dn1$PRK{fnhp?ZR$oX;wL&d}fnbu;X6aNV@G`d!7iTPeFQ zhg-W2uVZ?NCwoH}bHKh)fvK)zVOeH(9wuB8bgIgk-SRFbr`2s^*L7`@xae{JtvfP3 z(0nF>wl;5bqNUng3*y4J^q(9d(W5P?X~+frwzptY_4C(JT@r8v;9yp8`%t*QSdir+c5_q%&*DhpnN7?#KS#%JGs~n+ z9gq-qUZ=8FSE2Q7upb}Mj4vU0B4CqY!;9}1vSi10Bnh89{X~#qV_iaLG1nD>M(A-d z*Dk4y!!>uGt&TI6w}1y;J#>-vUdj|@26tQYIcJd!d!CzLW|c&#mU0L|P&y#O>E@)a zc%e+pE3>J~Z^EQB54*qZ51Adp-+T9C^-U#Y(h9b*f_sVjt8IH9iDNZt!-sL$sa4`h zLiNWLgKWQ@Ixor`;whsJVI1{tvfO50@qkC4w>b?dK^_V=4OhySl@FqHH&FV)%=j9&mgg4g11i# zp?W!hV(aH`Fk0#G zxkPrM7-Qa<*i)O5q^6B&uL=rX3NvKv4zo&b-NTw3;4ie>mSUQt3lr-zo=Ntr4Zq=V zjq8$^^5(`zQR>$}Oc|cG&T@~@hpBJ~2a8_E4CUq44E$nxTT*f0c6W}Pf`>a%a)n{8 z<7n*BZ0)*>mWKCf zY1h;CeS6pVDYiLilpIUc@MlDta0pyyS!y=Zf`ZRRnq87b%qCBGw2G>$U7Ha z+kfbd!dFt0-V6k0@ft5*Q@2^u;PA4L!5bZg>TUMB&6&z&L zUl<|=U#N#XUCpBcoB#p7HO4D&UcqfdVbFhJ;}J?6DuQ70ubdHLwrw5n#9DQJ_t^b# zw(l&lIdm`6cvC?7>w@;q^J5m8a$icb>rp{J)ad$tH&l$rLXsWe+!dXt-1_nu;@- z3P-{aI^_8>8ixaqK&VR9!OM>7;z{$M!hznQ;2py7uq}-q+k-m7zz}=M?jS<;ANd0& z%5$MHDHM``giB$;DHBB^fIlb#4xl^qessw+stY}SCXGZwjtThpPYDW}0SWa$SPz5{ z=OSeYH3|rPK|#OY!N!Fu7(&&>(a`~f?4TH(1ss}IFHf=^2pcDd+kO2Z0FbdFy z0!MfNcHj#j2cbDHID!U-2BLBZQx|)hGYIg_g))B!_~`0f@Irsj*&>D074t9Jbt5}@ zf$_~L@xRqG9HHc82L&-nXc8DIm{;hF6dD5-JQcF1rUP9KzzED1ia-ewB)Y1YNbDQ#=rku9gMqpqy?((D<;RB?^Tn(3^lUEtn6CJ|jS*P;+|1kmzIkht*sG z5)fj6Bmv2w;X_gZDHkjOvH<$T0gH!nNL>qp%yGa&L&DSb2U)^D{2?kxso41hzQYI* z7my&~py$vt+&lv?IZ*3`Kae%h4FNKSB?*Nj;s7tu@w{a$@A#$3r0+6w4Te{TZx;|jRLv1l&<(?KVq!od{M zOJI$i?{sb-0k0^%9$Z+N>-hH?2wtY;=mIc|e~(gJc)$rc7N1NCg-@SWlzU`PovxU5 zRgAK<=eL>4-tBim{~V^;>GdQ3XN6c#58hB$=6s=%bRX9ymWF{fGbkZ>!qq0 zM!5TI+JTt{Wc4Mphb8$0(;L+L8+*2$UqP@HjxcB}*!cb;Rurx($o7F0fcU&~OvqC} zJx9v%Q{Rm>3F3M>z9pEQEu4NDY^dn`0pYx>H~cEjD}SN9{@`B)&rvaeHyv9aL>tM# zU*wkIzk*R96YRn()lf`V-qAAT)`^h5SHaGDZFLwi(G(57z9alSCR{ve*YRYZjgq4e z@fMB-TL72F?3!(<#=*Gh7gy|U-lB{r068HDmhqu&c;c<(rwEBo-8 z%&@&~Q?5{a*Ry>u;-|NHmLBM~-y2g;6KI^iQrzTE;=MC95R6}^8`eEMGR5t(3t>>s zZ=UFTm@DSx=Aq48aZacB`s`=Ft^0M(e8&nJTnQqs30z8?Xo&KYo}(GamE^O=q>+Ho;QNsO+v;v--W_D%Fu_qr3~@ zPCL0Pcwai%ZnGI`y~%Kc&qy`)_(*+Fc2id9UDX=-eY!JZu{u&FbreIs13#ie+fL9J zcV%P8qQ_c{i&|LCf{${2CH>Hx9%^?J43wd?j|IQkj;Td+HQ8vBY*OL|G+w>ff0*+I zf7$IOn&0-Ac2Aw+555m4_x+;1Fn@CRE|ob|@r|z1=W>o4>V-8mX!*o2@nfly1zQsC zs$(9IH*mI3RX<#_@0_ zvrF(y_np9l9~8Zd?vHzxBr9NsQ{7jX(*^$w@|hPRsDJZT=;~{QtUidy(?ReLt8Z?WjK<@+(fga$jIuvsH@E4?-l~88 zj6`%h3oX^8=6d6+qPOkd1Qn^VS6L^&%s89G*#xv|Pn!802#~*&9q;`yQu$cF?(Q9p z029?ju;Qr;@XWF{Y1P`A%S0en9@l)PfC%Yf-@Dv=^;M&8d zP0|LOqYX}$)w;yczBuE(Bv|vf`5ifqMPK8)xlKVRF1H)?NEsiUQwgSSA0iN!^%7z0= zC?ksbMYH&A#q`hnn#QkRKYt%Aq1^X)b&xg#;)wlEwePnHzeukhwAel+T&Ua|F`Ce5 zP&>7;h^JX9U!$3`R0_6MpIgo>9>Ke#t$AWtCd+eAvY#n)DbLos{#vX24^FQWNXNkS zh7{|{ZP5m6&g|^;ayi5tOce<}m3B9&Pc$zrWkPO$d_!N9c6tlH?%rXB+Q+q8SFNR6;!^uMBiNMN8S$PfAqthe!VE0r;ks*O};l>&snS%P&cO4U+8xPx537iuT zHDxWz5^x){;EqY-&l}+DLiRQw4Q*5U3mDtp+B%&1_+|i?UllZ}A7xeT`7!p%5$0Cs zUt2dmer3lSd0Hd0kc-`gSQ%8AWKFEK9_J=;R0i24CP=c`Ig4Z-3lFYrNh(Sw*VPM{ zPmZrQdUxX73KOvcn2&+5fPeD_a0Ux;)|2qlloy}6;F=eIRIq1v@vJ#SGl<@8Zsgr? 
zF2Ly42<=;OOg9%{@Q6b|4a=h(u2F+P&8)j7g3mw2qw`B~PHP_;q*o254IX-QMxi=A zlePX^Nn!q*hx`!&X1qcvs+H*4n&B<6O&<#!Eqr5LDkFT(UXf^*syAaa*!q$<8Zwn~ zRO%*E(7pKB5E#En@J3&yFIP3fPJZ2NWCm{?-KOOk8mxEmw6|ZkCL8Ivq__U|35*zu zd$@b6T91;SaO5RY?H$@g-*iji)nA>VgbJBm?<~ys+dY-vIvQOkiGFv?Y=So4UhmI} z{gnD+=s6n~)*&piZe3nCDKIr&-^8a>C5rtCs$OuMBdMB?WzUQ(Pep_-_fz1fAsVU;W!W) zSoQ+{RS;ZCUyHg%iMhO4V&bf0W}jW0#JR!Q4(7+}nAp!N(#&694NZnin%;eLF~_;O z4}&_>$9T?HqQl*6P|*mkWv7Qr@&0mbLZI`|C6>mh_5wD?-1R>~#e0Tp`%!#Or5z4! z8%C`1<>Hkd4rM7kkMO;6v(WapVgWlPwsoaBP64V8HduX-#XD zxI%BJhvBI_-_sEhv_RvqEYygXl>32ZU7~Bth*|E~V6=eMA!*S~h}>-xllV28^j&ht z6g4Y2_G8Tf8;hV7 z415JY91V`J|8C&=Bo46Pk4*9YQrsDe@Y$Ct_T*Q^lh9TZfd`g>#nFQCV+K)Kf^Q7v zkHt%0xzp3pcn~`jTP)>u`gX}(L|$ZNuw#(_1D8v;84_d}R61I-^i?(qw6vugiWT&% zl@BSUtbb{D%P+1(a3~mVE5;|O^dwu2d!u`^Xd14<`ux`3xY3sn4b$slH~TR$m=L#h z40n8vk=rNIO5;Vh<=`4Nvv8~Jko^{(@tpgEIrb@#Le90!Sg{{ez7HcWDte zItRBcg6hCpLO&1wiH*YNvC(C$m%?@63@n8O`!k9Fd3Bj-bJuC`a1d9UL z21o$=KY>NT!5?wVfeL_24+=nVX-7bb=1g?}Wk>M)2JqL#8}oEV_q-@~Dg z4}xz0KY%#_jEn^{fyK{3Q~;8Kks$C8U`}WN`C!1k1}H5+SQtQ$!J_~o0&;+2KoHh{ zlVA(aAX+HR!BBHZCl0_Wbc_rSKyo}0Le&7)1kg%6Kv{4Q@CYRU{h&!W7)SuT13)AM zPy$mY0P7<_FcJi0LOJvtz%6+28yo^?5Ah@s0d7T{tAld-U@!nR1p|Tvfd&O7Jh%i3 zwZ(yf&+!1XpkI0!fH`OYZ9xe~AklLY0O9Zu<_7SxIT{E$MA5Ng3<+w71&|noqe0LZ z7z>mjKoY|H=9NZ=yr8}z1%ezX6bAK)18x2Q4j}&g00VW<*K`U%;OA(dg_?fg0etA|@;tCPw`>FRqjQ2*a=KRqsOSqkZVu2~ zw1NS`1sT-=WC;@&5J2D&vt&#{O^nWIk20O`b zLd<_81e?omGVFgOLrQ~93=l%Oy!@gydJ+Ejdn}6}t>2Z4F%ezpU7b^+E_gZS)P~Rn z8k(OnU9d^d&DU=R^cD-+oZHhRUrIF0-m#n&@%hVIZ3_kV+8mDpQq$}AdBq=a%Z}90RZk{&Rx_`^E z`&+lc#Lr>=-0Y9m-36Uc5|2b58F_`Mx4)3+0!xW zL_?1sM;eZs|aAnf|h9LMuW#;5PzpHnDak`0e3EbugWUm$Gp(33o&p6 z@O1N_xw+WFy)n`xX|yEVnMU*Ul0hKc|9K?sL3I*`A;1j@`@?qNx50~BdD=U|ZOL}7 zpqF_LfRfJM#e<>>{is?@RR)bjYyOq9_t6D!ynv^`SaA!hpkphk>~E z{0BI6II#UW!vCp*)<)<9r_9%(aL|Uiqz;P)q1UB#7$n&3pwI8NSS^wK&U*oc?a zK|ws|^G{n)Ue*T^iGjA)Kk0xd;BfM19TL4v79<)E1s4CHLt$~iu~}M21P<|R=jgSN&cbR^`mwpcV2qWzN&i-&xoC3OVg`7Wu$At7((PdXe1I)yE%!>>R` z1Wx}FIy`iHVrd;_nVlj*ynHzw5%TE%=of`SF0(5Xnh0S6f6}3$+Zjvhu)qUfQimjx zmg|{>hHeo3!4FBoET3QClP{f1a8d`S!#~?XVgKcH_+@ql?k6m>Kj`wwQk{Z}Prxx> z(g%tNdAfhdj6$Qp!C^@q4i8|Df9hydvWuGomHt1GFmyQv9A4lu8hCg>izVH?)uK28 d{}ynb!x+41WGaomJb|?fhk^+UtL;;V{TIEim3sgH literal 0 HcmV?d00001 diff --git a/tools/figures/output/base++_FID50K.pdf b/tools/figures/output/base++_FID50K.pdf new file mode 100644 index 0000000000000000000000000000000000000000..779138f7f20edab9fe87708b557c18802ff7f120 GIT binary patch literal 17760 zcmd^n2{e`8*ME|sE6S84T*;K-&iA@RnMuk_q0B?L#7%}$ilPk3Oo)sj^H3C-B1%d^ zq9SCdOi`5gJQv@D-}~+Tum69o|61>@(>c!>_t|GZXYcdb=Se_QQCS=%frSZVcf+&p z!%%P}+{yeTY~MaOLf6CD3XV`PBbzxm*}xH+W;Rx?a5Si(4VRULS&=NEijpfDC_9nJ za15ma!tjXsNh=F7oUr()rkkY9=wmSQ!y$qI5%YbOBF4N>o61*lsjAE9mK>U7%0 z0?Y@ZAB?YIWoc(7@8k|fM1nt9BvyimLgLXlI0l85K;w~U6c$B9fw2$@V3wdOS2%Hz zRyh*M2|S^ocG|GtoPmm$S$W8cWJ9)vqn3tOwsQcJha;3707n$9ESxN@pgFjbU98L; zVeU7d9@B8$MPPa`lRI?#9KF(ocXYvSqhW-{7Qb zpXhuA=RVm7T(#G9CIa|n9~Sdc=lAu;Y%2S*|CiSeZ!U=hxn`9^UO3Z)YxmD$L9k}`|2=vuWI2s>c%I( zwxA^IU1MZlS1w%as{A05DWoCGK_x2Vgt}MFr}+60I^)%;Svz-`hmW6_sWis@9H#NT zbn9w2LesP^BehkY4ZkdWBfa$4X_k&u*$PJ#*!|erlTJ?oBj4tHS?%e4?~5 zC#tf#U#S15&JNG-Kfg}6xXXN+F>qDwkEh+fIsWN(*{C|k-3A_ejz|dWPi1O$DqEg) z->pMFE=>C=2miHGu?OK2->}1EN-UbF=Xuo$S^lt##4*n-&^Qz{7DBTIr)WEu@Lwe`i=0#JdaK`%s{37+ zBQ56x+%F`>SeC}u?D@ZAUz)>7sqdi;XV+FqNA6(dYId<)wH)in)#aw$(%f2%cZ^_)k^Clp>T|dNM zWnZKwHiagHaev(7!lrDL)lujwil9uG*{R)%WpBnYp9c z>?8gh;et|Zfz2+pERUc5!JK#1eEGCE_JI+2zzC1Q2q8m1LVq#WoeVFw5Zsq+A=tL( zAfNBbA%awF`9##@$IA32X+Mw5@BJ9JH>B1oCqi23i?1uY=lvH8pPdlN3-1@k8){)8 zy+^Wj6WuwY&G!blo>r@W)Tq^FuwQ_eJd@ZLqFx%!5!0X)B`@ol=zBuKN+Q*|*lu*B 
zyqrbzqwf!wvSI@PVnAzpp6l%^kzB=dnh)mhami!{^+*?bn&B^{&tyLh&pOjr*ZXt! zLx_j1_=l>`pLuJ}-!clbRU=w_$Wsr8XT-*~e6rnX4J$|Bgtec)n`sRpK zo%&C>Sxfi*HwvfuOFk4iODhQ-RL>n8IuIVhkp6fhmyAo5)VVNH;mbEeW{y*K=`EKv z#l?gY>|TxMAyxMT_>YORUX(|O(5u%DMAGDCnOf>bq*%{$oT@<31rOuJ6p3ozSYLLW zKjWJtouDN%P;2)7%vSd6jEW48=~U$GH=POQP`WtqeY*MCcpN#7yH|Ou{p};tzfAq3 zu1^NrJZu^F)A7=B?)8OdY&|2BAI{iEe7CJitwA6_Z|It20+Wry1_hiOr{Yk!P`MCW z#z^pyj_caZDxAE-*I6W|6CPW`k`*~b5*xn@cihPtB#02a9tQEWD<^)Nci1fW1tXU2 znNWDG$9#0StPRaUqj;U|+ZUV3_pmqbaPRUS z^*2A+zIAwTB2Kg4>gmHN;gYPk9a)b)?7YdNJLc@L*&p`4^=`%|)96>ZNe1?7-^ciO1CM&0dOJFUyL`vxU~bB815C*Qjt={%OeJtksE6?euY zwboO*>E)}fqb3(jW9Lz`5!V#50tmxCh!eZN71I`OkAhi`Y})H};)`Zc&sVDLw|py} zohB$sdRQt3iujyEUdLTl)hWzB!I(}18~EfF6>0tY6UM&*uEa`k+s$lyQ1YS|4XkZ% zaiKif@HG4Mj^G0Hv8JM-m{ml;P7seIR?~WhC zi1)}TV6hJB$mZui1H(8YkN2eMaDE%Se=E`~cMX`467rv-X%<_@DZ5$3N8n&aRim z`wRBm<_($7KO!oUk?S67J9?w@Zs%4{!`^xmuGhQU4myl?=k_M|p7e5vdX$oLzF9}; zly#9&?^mbd;+~Uke*AMa4htp5)w9T;{V{4EKjz$glCuA`e#s&EmN@=&CB_cTI1RwT zoZGE_S6}5R9j|c=ztPTOWPV;hsnX)fbk@OhhaYWbOD$;=H;9gh3qO#_iv5Vs+9BgE zAzXGp`$PY>hoN&dS4Rj38R|Zo)!8Z%+X>F?cznrISrH97E+03CYx<(4b#A#zO`EbN z^H0+<=wGp@mx{rZ#?#ib>Yo@XwdQ*Jh?4qfzK(>yTm;L77HyljS`0N)v6#FhZj+z#94&t4Kyux_}x*KXMMn z{m8U=a9>bso>jTLgE?V$@0`J6B2-9YBH%-qH? z`YX2@SArOVJ=+@lx!gr$Yu~YPAc-HEN?Z<*2Fy`AmBcKlp)DMY9w`KMY(QSUvkz`zI zqg%jvM^)1gCf_ZK4hqCvdwJ2!02MPSW@8XMnd<)ec*@NF(#nPHqb7;GY0QUW*Rl}} znTV^*MgsQVnS&`r>LWGtup3#@A|qh~%8@XU%f<&}{Ue%aSevK{4~aMXi19N}?}-e^ zd1X=W}D8eECf+9RY8&I@~S$R9{ zT7tmHhIUMcl$Z=?4R&rQlVlWFXJqKCuxez3 zrBYm(u&-W4@A`A%#n4*?eK`1V8-s(cRBEw}swMeIX=yw1W{JFqpUc5am(TfjqSDi& z^ZS|hIiIhl;fkEN_eRX7Rs0a5WwvV!LX8+^y!!T!gvTAya3Ph@kD>o+yQT zA6Z_;4Oc~d?i1)Fb(_r1@A@F_OXv~Zd-LSC(-VyMy0fyaZ}J|zCL3JgJUGbfv5kCT zokf5SxvLfdjYIv#B4`mwn^0Rue!{#>sYZ%^CTzaQR=D7#V0R-&RTZt&dJiwy{gp+b zOR>K;Dlqx}#xGaw-tpX-Fr(^Zn%=zmlC^xZGo#E#WzGs9xDX_p<` zpZ7{8^pt-Mr)bBBBe0F0wwHZYjVNTz#tjmA4<~eVNX_ zV(=&O^1zno!H+zmOHU_S-ptoCN%+$0&d6$)HZ}0^YNW}*qb3wI?A|!=ZC9OS6)k=FyG6Lh3G-MNCs~<&zn1*?=}DMFz|&pERF6s z!+@;&2@+l`vii*Kiy8HC4>m>OA`iJ}j!iMfD4z08>iFO@#FWl^I9w(x`pG+u6S(a1 z6Eq?h^CaZ-=!-v+Hw=hbCx*W{6tA;q*M}!1H2I%jCLn0Wj81(Pxzzpg8&dt@>kFsC z-G$yAjLjsYc=JIsc;YHUEH#cRvkAVdw>63@%%+Ys-ji?0Z{-Q$4fKjD-_l z-m9A+S^JyhqhQ5cKlcxwnb|RW-&-eZ+%~B(&llU_GOP9^EN9EHW601{;fGBp!yTlG zm45j3cva=g1@h4a(IPf-D|Nq!&RY=$7zia=wN?C;ddH(6^^T-sQBDKu-I|u7-XCTL+>@8!P~B4d zgsO0bM)lwPf>vvbxg1*@XRJcoBhGF2w%o;T*b~0JejSR|F~?BEztHbW*VS(Tb4+1I z)-U*i+^lQD;Z$qXy#+k_yZC6jJPC&^ugS}>-xONoJTlhcUv)eDeI@F3RAcGa=2Dm| zGF>(~+MjIfvJu7DubjP&y}T)4bL`%S2a2m?$?c~OZT9XtdJm;Cev)}k?~{wvdrxNj zhVi8NVU^Kx_meIuhu>V|-IHFMxxl898Szebb1+#&Q#bL)WcUlkqHM;{2kym1ub&0S zu|QpzB2Wek!{y%C_sUVI8~cudGs><~>Ggu!qa|ZKpRKE_6y@ zxQOkn3=$Tq^Vt`dtU7ocU8li=R67%JbD}mcnZ%c9hKc>OSJz+J-$Zccxm0mBmw4Ny z`#;ha(lmW!yl9guL^v`SOsME*wtX=(it{+9TlSvN=i@fs=H@smW7&O(Cd_%j@;1ZE z7ZINp&V3&LOv?VwNh&`YE)|%b+I9R+?cSdrEKQeeggD*e!#6%pP^%AQl9=kZaAC-o zdhBpHg6+WAmi||?)3Z~evk?V_IDrrR@JluU**K{53tYZVBv45fA zl&(ud8cxA0@gj#|)sHUOmQB~#I_FB6ajYr3snq%lqxJ}v^rgD;MZsM@s9&ysBsE*~ z>3z$rkv+%ZnopIJ4yBCVOb+db*sW7boEe*f=>3k2@yv9324xuDZ-|ZNv#HKn*!d|U ztJPt9=N)yTh>+JMuHETQXij}TbJKyWaPc<#yr|$^xwc6CZMuw4AAGHt-`QKEUv+H& zlRf*XZ1`1bdCwWuq@xlKTCycG?{~BxzE*;_RPR-t)e(+0y>mU?vtN+SQN^0%{6JvF5d&Qp=RkFR4VV%D)?QMkX@kYz%onQDN6 zb|(HJ2fgXfvB@%esg#)$q5`&b3Y#?cw+;mOa7r3-#s*CJap*L?nY*5`Z(?_Xz_|-g z`KY$n$95F4+$5+3dI{Tiil?2ft$4XHN>A1Ze)-m^P=*IdGo&B5GSyeig;I<;4nFDC zqNQr`R|$L)ejzpm(;6cAQW5VAXW~8viL0J=9C&|qVG#e|)BMH)Ib^~*vayaOjQT6< z=LnH$BWl~oB<%bKkp%w7UI+Zmzn{M*M1R#*aGYJ`FgZJzEB>ombGLim&xD>J>$22K zQp(P>ZXbHesCRL-vu<}w!}#8GWMt!a;a0mJq0DoeA#~6%``weU+6e~(Oz7k8g5(U^ 
z8O39#jNSVsE2r3d)ewqtek60}Z?`+MHpFyXc&h6ce2|#)dBNMjt9X!Z%n{kP&ivLf zu~9h5f4AqCuo9+Fe;q#XJFsrgG}i~ktQNtTBKVtWK_2p*Sn`I&XP!*+bYNXgcG z8*Mgk^{T7d^7@Xj%z$~=)n~!p5B5elddBEA)Qg??S)aMbX~JKjT>|43%q3>FK|Qbg zU5%wsOQw0cC(GUG0?XLX!ChtXNrH}Q#{>FQ-#x_olpDOUFp#e1pZr1_&VgYo$xXSn z@111&uA`{|(#h4|F(iNU>l<^d1ACIY1X0JB9g`GQHyQwxF#cZR=dZ>UNA=7&Pl|XS zsO|{%zpdRFoR)JXxUL)D-dSj-(rMw!pRuPX|Jqf-J?vRezM9*axpdat7r6bS@Zmwn zd)1ltH~hto1Pm{?m$sTGsnx%*513BhGRE7=lwn0E4+U+Rav_Ks^4MyfL7rI|64$$492_6m0^pr2?U zEXEq@ytqSqJ+#H)+{U<~Q;wNo%{SW9;?WX%0p52Ts&(rRu^1a0x|yog(e1b;<{QD& z(Dk1C>U4YIP%`4&ucL7Ep|z92u47}N(SKpl$Xss<0;*i$%UyX!F|g%&S~K?-{5)Y0 z9>0rL5ZTX~+94O_(7IJr>4BYm7ZyIlj3)^lT zwn?<9plff->tZ~Zy*(=SV^~E2gW%vsr-I>&HA6C*r%&0{b<##k`V`R`_i$_C?mQ}x zI+;}X)~pJi;L7S;R@T#GP@T=bv&!e7(4`xzmmYT;t%Om=dC1L1_c#R#3#=46U985#%!LmW191TNYC*gxn*e$f{zjzSWUaB(a^ zG*CnW_ya$0Dm|kY0UP`*y+1QfdSn%q3TDdsE zA$4wT=Vk@RgB(RTpctVHCxA>f7$6s+4u?Q1T{w{dN0JO#g(JvdvcNHnFtD>E+X5f+qBs4wkD0>Y<>SQfRa~(I6juEUza7kM zT)~VN+4|pd9*&T6wSfHSL^Kgh6|6M$AdbcWJ)vOctZGGJ3=j*$>qn>O%<)Gk40g;3zEM@rpSuDF1U2te5I7y_I~#DRJQxFkSn2^c^a6a%3o zfcF3|35W+!J=7- zbhVg(2!y&o5&>n<^dYJMmCKRY+)@W%aiG@A|3K0}5d=sWRs|GN5(oMM120L&a=a`V%kiRUElLI? z6@LUUHsu+3Zy*U-4iw=4WdxuxAVF9!fCUGHg9U~pV?}^1OU81%C=iQ+0i{<4h`8l) zNC1`tBpg71U{L(qiL_G3E5a#0C7@`O;S~P~NU6cWD@jp4>Ts}rl$0(U2m~c%21hU5 zG_r!DDVr*6aR)~)X)GH!*x)Et?ZAt^m?D9@K#M6SNIy_Yo#EgG11Z=F9YHzr0TL)$ zV^LLrRispR0Rpp_0vqn)X1=6mz~m|2fD2Sh15;EB#g8w)r-_ikGLBfl<@k2h1@H*0>m(N5_$o^_v&+V@}M|tGw z@`cGHt78l?hWqSqN;ltb9`ru^=8VjaoV2!Vf=lOX8L6f=&Ev=Be4p7C{S7=bR|loLOikTubU4JEk*gS##dw|DQlwhm#-ItsER zmeFb~P>Fxx!wgpB+cHA*Lwwyc!SBqYoF#7krS~p%tcd0z?-I-^A=UBUl7iDWPa@ej=yI&4(49)(&V)O$>js-|8nV=Qmdzj2`SQq{nAu&TzW7olCweJtMlG)u(W9fLbqqHNA{_F67{r~7r~*zR>W zxXxJy19snk=b*MasqO&%n?W|?TQ9!GwRmXjy{!Onzk+t=-iiQJJ@VX>>-RpYo3Nbn zYb-r}*JxkbSqYz^Ocy8CGsV?RVEnF@p~0Bp5AFqnw5|Aqq~0i zYZ62CIpVKA#Dzz%^BBwC%b?}*_6&%~*ORn3V_D>w%?3l+Gg34Yk5>0;eIA)X)0}h) zOl7dxHMO~jM&41icH{7Y4T_`Tx;fW_ZT7HMvcI)4-(@=3T0nJ|Q&%z4YxJ>y=Ch2z z3dLI4BN{)1BM*t|*OPQOPt1o0ww)!@?9IeZgik!zD}2si7~sM3jX1A5JJ{~PcUg+m zJ`wPd7gL93d1k6kG);=?Q+fC1=xOG=-1o|!k$re0+MN#-jd(XqANfUobL{Eq3K#n1 z10OZyzLqoHRnD)iMa#wqiFhW9;C1FBM00HA5J=#B<{xyB|EOOrYp}B z|C6^uLrd){HF8V&Ap4I4e8gt&()H|{8c8&p5eYPj3Jit-8(M$!$&q2><|mb z<5|&1pHYu7HZmSFeVxgrb?u^Pc>88@vVQHI=68hy?QR4G@ridC=f;27>PMOSwW?1W zx}Wfqjn0g5`+P&*^Sy@j?q(iX3R+rRwAib5nB?2Zwo=&&}U5tnLhi{FqC*G^U_2J9%nbC zJ803Z8FJvkdl&J>ZknymLg%}l>)wofdB!eOO;4mdg~elm$wu&chDWPK{0xJ|9RI#` zs9euw@HftHfCd9Z44}b>llW)(sbyQN(IWU6CyNdT?W&iC@Jr>a7u&in1`J)iQX=q{ z=1a~d8;ki*coL4Kf79s;`YC<3;o9f!)!9Xm$8GVhq6|5#+}2E<;kP+&ZIfc}8B6YW zN&EclyMpudiRnu7x5T&rpT;NL^iPL%y1(P4lkeYe{FU;!JNvb_=%}8V-BI z)(19s>kuJWj29;LZJoKU;}-x{*1xYvAW%U_dPT1^(@oxj-+&6L?x|$lCO|3^dMjhh zy*WmNgsD-WP1quK+2oYZR3wtYSVc29e)P84tB1pzf~M27C%_F<*BhM!9+`0A;Y@w_V*bgC1`$P~&U#LH96L^za({IE( zBakoO6EYUttX(&=y^yU%JXfWKxl|msNsCp+Fb2WCyRBtvNGikmV4{x!eJLAP#U-^3 zmrl;o@uXtlnu7=G%gxc+)ED=3xZ0hf4{#9-IG<9H&?}Uak~AfAG^VLHOg;5Ex5nWi zs=6n2YPU?(zTb+8KB#yjYUEtL{M6;D8~bEcP>mNx@d1v>G=hUXsRGTlY*KlTH&$53 zzS0{qt>(ER5@^6sn8D*PVZ<7d!kyE{*@^6FLh6_&z0ad*8!)%J`1xZWF1Ol$Oe@T! 
z#`$yP(=+s~w!gTxKY3@te&d2ldOi!Ion)1NRl;$}y5p0qM5Zc#i}+ZvEf%(d>7Ky> zRnHR&Q_boh^BkL=+^YNO?3HyEVjYYh1K|e$W)0x9m;L|e;1@`5zI4J>L+2|QGdtO+ zPm#65rH|>lHC^%3EgmI*FN)}5A@rZI@~hq4n8h-teOWc5LZ9!|ml$+zDb8k7gLZ0l ze@g$U#*6!FQqvh8e=o_;{n)@A!ehwJpQKoYuB#o|8TssUp0$y8q+L~ryYEfWcJarC zG}>HmCC9GLBzcGz(E2}!iM$Hq)(_b3EjNBkCFtC@9lD0_*0Eh`&Vd1%p%>hIx>UCi zy~NzKc&9MJDAu7aE~RccAAuXu#JcH|ptfx`>yPW3=?$OBKQxpQ0Y~O)@3aaBe>MV;@^( zhyW{hctbA}W77E(JhuaH;(Im@B=2`*r(^n1SabZI5Bot5`x`S@i8>Zwx$<@y4Xu|#ar~XqhbF6UAUTsCN9Zs+;fVj<5cwK=CJm>E!Npv=L1E$ zhw9#=IBiN_TeWQ)HOZBUk!#F0t`%VAIP}Vl--c;_3|gYdRfYf`R5pwg9@UC(Y-i2*06%j(dtQVAl8g&(SJkg>dY5#oGFngjuoX6ypgb)WJ zd)L%7o|;3;E_>pDYSmBD)@j#*uwHld!DriF_IW*RoT%0{t&4k2I)H!mp|`@nZ!R!* zaPa#)^$AlW|8)d>9XlKi4!8d<;95i`u;Gu+u>U%+ClKL&Bw6_Buga%^t@=EVj4v03 z^TkYPhh^}6)RFa!k+^xk`*rh4>|kV(xa)!TZ>|!s1m-x0Jym@rkYW4~cI_zEypUPDVuQGW%P z3Nh~=92FcQR^X_MK->xt55OsudT@y!xsif{!OE`Zn&i_qaB zf{IWA$j2f~iUMFF=}#0C_?;nc{6-8Z?D<<1bPcxvEB*rjx-Q@TUjU#OFoPA-SvIf! zKL9`hz>EbefyFO^R{*Sn5g|AcpiyW524TQO5hyH!TNqFugGT|N1jGQ~fWWQ)Ccu`< zpt?|41gI7d@0DQqgkR%iUD1;{BU?2ck4?vUz7mLj)YXUU~ z0ft*zvwyl;OaOsj>;hnCP%{cufO2R#Au(89uH_g|3tQ$Ar3WlPlfkqp?1SXspZ=j* z(3~j31bU+IbxAm&-oZ1BQnu6q5l@6K2S^YWWdM?cKLdr=KrSGeSP>}l0LjF1pvcOS zOhEFmGC*`uob5faG9RAOd`tA{Rg#;#Y)anOKe?rhs*aQDP{) zB2dj(MHf@RW3>b zYViTufGLMffLwx)Fqd&nuu%gcqg0{n%S#aI;ot$D-0gQ9e|Ft;u?;UKV zI#dQ|`qHS%K)e=v{9SdW=`4p@>6!cIi&2Dv92gj30eolBE{Tvs|L1BFK$HBgi208d zfpYm>4f`LfAtiuE4^)I=S=kkP^a`W$_gq$3Nxw5!B83{zySm6k4e)X-@lykEXlWTW zfRbL#OiGL86<;(imOrlj?O0WIa&ZLa#y?(DIPjc;w-5RYbo~+*13db`H29aVEF7;( zLnfP|SE9KDlD<9Kq$~WMa;&iS=ibYDDa=`^a;Eze_S5h-xpztKm_4TUu-Lfkdzbdq z&mr#Y%+JTWsE)j6J}&q6c$(NW|ACl|9tWssY8j#&3igW3n+N(d_0zyJPuVb+KF|V|(eFcw0{GG#x}hNqjsVUMPGko= zbGREuf+&F&gWHnH&aP4jgyTO&5>75QA}|EFO<{T30(`i>vX!%?HQe0H!XAvW)B~tV z>E6zXqzHZSFRUnqMxt>@3=)SzBY{U#6e%f$L<)(l*iYSEtgKeybPy;xA z|EvdQ25DaC_YZwYG?>HcJR}+kbn@zYYiY)zfgk(NdMF&oT$4utmfos*cqC*f{5ggs z268Q~%0of!&p-1BcyM=RRUQ$t3s>a<@>Y+FL=%D4x2hfnn51j+(2#!>`u*k=2D4@! 
zKvEe#2*8!Ux*h?#^t7ri0lG5qXC4XzU9(u7hg-ud6!;%4*3`oifWiG| zTO<(=nTycx_dJM_;2`#A9Z4u$IcTN#=sQRc5;F?Rf-v` hMzRKWFrai1i*O~IxsWN_7z*4}K*9tBl#VFF{tL=E&+7mH literal 0 HcmV?d00001 diff --git a/tools/figures/output/base++_InceptionScore.pdf b/tools/figures/output/base++_InceptionScore.pdf new file mode 100644 index 0000000000000000000000000000000000000000..03eceb58b13caccde03426330b6519499c8bfa98 GIT binary patch literal 17533 zcmd^n2{={V|8J7$7|K-1(2>l;nI}c&F`0)9Wgd!SI;12?M22LHk|=YM*awfEX<_^$8z?zKMa^IiKD)KOBAL`h*`g1J5L zoCh!z90_-|J_6gj7mhGE?q&x^C|XghoLwE@2pua2J5M+oWYB}l%EIi(wopdmvH~iu zWC|QZ&44h`wmxEKLxB?(Zk2r~s`?ZwiX9xgkf38lq1btl;W+3PhR}DkvUMUm!10Su zJzQ<{?I>_F(5#XQ5XH`i0!OI301*@xW5vZ-9d5Bq9MtL#KypLUd)NW#7Q{#B*?GEp zdDwvdK=On3wd`!2tQ1^*K#NH54~xV~kx)oH8VARqkWwTJhDgAHh7gLNi=ZA)IB7vp zc{14*+@YXC%hdusUee=!JF)}C5sq4HTgAy4^c;>*aR%y8va@luwSzj~N%63=a)J3= zd20H|gNMLUK5abNDtmU99NX>--%LMtJC8joxarY3%8%16(@^g>@yPS{-|1_gWSNth zDO?~Br}>SE=}e{Xx=oai@19Q0lkqS2mHEA%AM-n=Tj9bdJ>p;8e6IIf#oIwm?Vo0I zCx;{+zV_-p5z_#pKUv;>nD<~zvL;H7{3?Hb?7`vEFk|G`({`?Vaxwg$pPqX4J%G~ zQa1Cg@Av6_ZnN7?Mmvj!MKsXHZM_c{D)0=r92ILY+gbj(cPiy};VbzGs|w%sZ`Nh0 z9pIh}vD>UWz2l1CuATRWImTs5#O(zc;Ey55K_-AgVwlV0yk z!PPTgUD`BG;-fOPq-eR$Y>M~W4Zmni)1aMlaevq8=3({p+QIcDpk$PDA#0ir);M)$ za=4x~Od)c&OFbxe9+bN-(ixQpcX!D!lzWh=t@^l?sNPWb?VZ6p&j52P4^}l*R)St6 zQA6~LG(1ICz+zXw&3@8dUGXnna#!&sOt!nt)kWWuRk!A^e}11UTiZ-q<;I)D2=2)9 zft;so-;W3?=>#4|eDZ(T)$6=(l8}=B+5gwcukN$sFT(u4ZBEJe_n#SluODMH{govp zdbd+)o!BAOF4n--&TPsrI9~MAnP)-=HI%L?=3>nu$8{M=qoR9g!!ZerYEO;#R(_Nlk37IAY%G23aObb*j>j~+ef1c*Y&guX z7ImoMkNvEzaGw9rpQVoHvMimgmPy;PiDW~Y@nVEV=FT);>2j-fioxT#Pk!Ukv}^~m z{W+^jtZ6KXC(OqieoWMs$~cCdb&#s|V!0)FrSXMPzzg;|t7eAxO5`e^{X@tzjcHtm z-@mT+Y5x>obO9|JtZ?8Ri=*kA^pl(p*N<)IG9Eg&UIC7^c$EHPWYm3UaMUG1@v;1> z2U_$#Pf(+wyE2&eoDTnb>|8tgW5Y~rkX@B@K?4c>UG(E=SfF;srjJ9REz%Je?7egI zC&;$tM?lveAM88FcBa@b?u=(!{g0S?zMsp5EX&e*XDcZ~BZ>esrour;l;*^CHm7 z4QeZ@PCFIfiumq$y=!cPPrNEwUGBE>yi<3JimPkID?Yx|c*5s+U^bzHcZ=8Tv+(iO z&x1%97aeBO%JAf?xB#4>tx1~)TTz*>9g__z@TtmBO9OM#;M7#ItV= z%?tM=cz2)6l8Gz1WQ#wUsHnlsi)v>PnY&vkV2sPcvisd^-E3%it}mmDIqJw(rM{sN zScRGHW4V3DkGqvjw)~~_$x`CkR-y_XUe`Wc)5~kP8QxJt=E+jlNsdd*vgKQsNUO=s zYXt;{6%`6t;BF%6<#g8N=1NL5%nGW8X-eJq$SwZW*>e4Sqi^;0iC53187^1dta9Dk z=BRDX_d2KR*p{0TlAno>ijI=SijL|c{j|wq+Crv9gr{2XOUjt!G_X~E^;+`G)tl`P zwQ|01J!j)bmJ%xxZJ_%ow2gBI=l%n`o+`iJj1t{CVKtp}?7FFzw%x4^k3ui4*^!JU zy4{|sjXY}I9~d%w56kDD%)_=v_4L=j-v0dwx@tW^f@rEV?D6peXX{RGenVDN(yvOi z%a5DIqPr!3)di+~3y$`hC|&`wStycvFrkO-Y}NmRNG#S^mvXZ z7%-QKiqtjn>bt9$aR$ywK6BjsBtIf!q5!1!G^9%CySqndqM~oo<7nfKXc^iSWq9?= zNd}DMHU$ZO-BQnE-aW#Q`%UiRwX(dk?YA9SX!ICzC8Q+Ofd(7HM&HN3`f=jr&z*Og z&8|Zra1pmG!a*v?fq)MhjYchDL@GLh63yTiOKIrkKy48r;gA3o{Q(_e@B}@i&g~7X zDqADU`-3Yku0O`KTvE3+7HG~ThUn8+-+T0}I2!&YE3BKFTM zAK{JxqC^_PIh*g|N3y)iA1sbmroaQN&euvy7;gZ&QWmj{1 z(Tjm(jnq@RdO8F&7IX1Mth-c6QUf9fQNXOnqpUI^tT=IgQ)sdUd)P5g{ls-}*fl(-y);`hn674~ux(qc5TDXg?HA^NxI>T*8+MYmh=5-u5In>t>1KKsM z5Cen4|Jh7rqBSE2kW7lxq#&u5(^G~0Uk$FK!jfiP2U`aK?(urDH-?X}`K4wI*JK`D z@ca6}d~r`%u5U!!38VaP98@b z8a-N@#fDsT-R>hMbB{k`JxG5zq1xzTrdd+m(LD)C_sN5kn=Dy`caWOeuBf!C*c??U z6lvJF!?)6n_oRWyi5d0zc}GnAFYA-NjIWbY_Dg+=cFV4cDc$U0#*xg_om1Q{>ELCS zKT9bM9I|vHiwZf4^(H8myk~oHd89JoQ@>yrxyNF5j^~}^SwgS4*p(ySye61SdUA5@ zuk6$alZ`BM8ywtud<*6D8iN4sRaXoG8bkQ=AmDU0$#f`&SA!f6D`RGaJ+B(ec(wNW zn<*MN?cQq{&i|db;jK@$KwgyK%yV1|T}9xW^IR^6Vu-*udB`!j@?H51$6&r>zNgtb%9phb4U_z=uE`)Vn0PEZ59taA6DzBW3K5#r?mZ0d&G{pZ7kU(k@|N8TyB?Nyhrg58j15R zEqIv0_t|Zlf84{litu#=R5T5p6{pWK|Z}@IbSkZK`Ol{f}Xs^)Z#w@d*HAqxErnu z@tp&>Q^E8{%LacUgFi5}L_Rux>AqKr?Ue#Ui>seoeVEyuGA2K~KOb+QanK?wdhUu= zyJAtVrystz=hsh0*UM&NYv^z}>9V+X{?#g>9IZuujbu7IElY>6l$*O~n=Gu{qI_cB 
zr6M(EtB((l=0H&fZRm{(WSGZ(Mi`8e)wgf`-d zP1r#HdG`?6M(f1c+kQlgHx#Ft;~ejF(eStL@4Xuy9r29Q#wrUTb{Q+Nq4?FChsW$b zux?LEE5g=2WqH^w2)LB{LB zr^ec4&!$)HXGynkWAP zI3b&!Z`5V7mgl(4m897Ah-*s;(=F_}-LYHiR!f9|_^B0d9g6f9a7#H#Gny8~pg1iX z6nR>H#`EfdbbD0EJRbdBax_zcj6;@I-_AO^;bOJh$Fcg5%ImRjD^M>J8t#ub-G_N1 zGiB2*g;2~r)}xpQRC2d)mNtfNN)mf0cehfO(&4^;)2ZHrB`DP|M_6YKhdrd(=ZMFR1Fyn%?#`^qp65`_j(a1!DUzb9W03OWYwUBSqFm;S)gl zsaZK{4>UC^`lg)OW>Wd1Tef*a7)RIaeO4U%ja@YA1BD5@h3@sId-5m1J>F>s*FBP+ zDH?v;JY#lr+kxugim%T1nNP7eckJi1Prv8J+!(zr-E|@&%j7|Q(j|U}s+{>9 z!*MyS&Re@~YLY~SeFM37WxArdjQFiBKjg$pwjI5l5XqD8h&0+_!2Gm)ylieqU$s$X z*auAR%y98YdwOZ_F}2i#QsvFL#Ow#19S6eh;cYeh)MoTWk}PjVWu6!i;&4&5C;R+t z@ay;WKA+N+o%dStu{}BJEP=qSuh!U>;on<*zucdOeCp5|upP4oY)9e#!mea4>|v<} zck|xR&EUb;+%Gbgr%sc9;EL7SSqr6^^PB@R ztHkfCE1W0rO9fqUxHF|CS|Ax>r%O>*BniK&@zHpX0h+;b-_ zYvZ(%sk^z)0I_0{vriqNbUBD@?e^_@SI)Y`&eKl~f+96Yd7tJ_8T;NHq#tuZwyk0N zYanVAj`(L-qO2p&M$6AoIt#mcI(q2G)iKJ6&R_H0mG@o*3Z30AS5V!5z zjdQeMB7KHGQUYotZ`oaEyX2{`y{REa<>tp}{nLlCT;g_`y3y9;S=X=dCopF> zT|=VllUVHWY|Xlh=K2pQH*Vh0{yyCFIleh$jO_^9LqUn%LJoxE!TAEL`aLn*-(jR> ztG|slS-1MuRx-W3DI)X1I_CVd$W!HFaV{qk4eRS94*#sn-t9UOqSzsY@r~q`uv({i zyXQ@{t#EU;b><1STT^#zlRiau7bm9*xu_os?N@vA5F1cx{L037PnE#e&l_X6VL0yP z-?+B-4Kb7FV0!4Dw5snIa)@=*`aJva-n4EZlqsuAs*>7zW7dNs$t5YD#?8wP8d`B3 z5j`bW)fpLbU9T%LBkx>fZ4bVqtI$fd%f?e6Yj;sW*m*-QJ$)X~sAslQfQE6S(1Bn6 zRSqXFOd~pSPVHD@h}HnesJ}2o48A}IdAgcI5r6a*zPuOvj_XOndChS}?k?FeP0<}0$!ICV&{MbSs|@P)vzeQlcw4I1(r>>e zaW;;xzWXij`KgY=p)`cwuY+*({?%$iEyq|3o=U3w5usq^X}55tY^1IzuPv@*cf`ub zOW%s@DXHM1 zQ>e~Rs38mtalpz2xD@|m{h$(s1z(~h3Q0i1C9z;{jv^7j9~9(9gi+hkvZ8o6Q5}v_ zNF)SIz`uX4AXg)hPzU%yfzOrdU4$W&$-r3%c}ahJXP1&-2qh@!%nW8gOU~H3NnS+}tSO-v%)TkAn*Q zrvO|23=2961*nmwfCF`a#*lCV*ya%MXkc1c@DOMO*cB9rRH+Cs3?vK)1YmV|z?G0l zIFOG3Cj#o108kV)25$@jm^_{cga>>wR3|Vx47fmn2cQ8o4hHoGk{XI3DR7{1DBwZG z1n{=uNFYCm@dTh2%t9IgR0eI6p@0WvK@3!c1ucLFcp`yX04jt9g>gVKEGSHUhC>4@ z!ou-bAT_wg0m-P9fl8nl2(U12kP3kiRMmn;@jy7LT9%}PVyX#2{lkKiC={MREdu(q zWIizJhyaa3#i;>9qW0~dsunIF0ii08Bp?~oeMl-G<&q^p7C;?1VDV54scT7)g&Ocs zm+(~mL6-1O{ZKAQso2E;?qLL|77!repy$vt++qb_aG=sle;{k18UkbtD*_5h#DTg% z!;6-&6farEQoLYV3zh*{#lHe*n|hBz!ypS;3RL3&&j?^~fDB=+00tZg4h9&qjAa3~ zWEo5GfdtOFcF2SSgrHpiAo6aHA*|X0ai+&2f zM3kZq?F`}|F|0-%*fq96#J|`fYdtMxu7opfXt^r7cm11rS>4!OfIeVmO)W^F2q#;f zgOS}$n2PrBLGdzjrVtO2S_sM!2| z!Mtm?0xH53zL4L246K6ZDC)tVo!A^g8O^|7WS8c>g4t(9u!*cxMs2(@Ov#X;7b4A8 z!OnVe+l?Bj35VX?6*+>56um0uo8q%oVyqr-CLAp@0Dr5e7ut}X68mJI%ad4`w0-La z3AU0qw(MHO#!_{?aIRSN2w#TxF2CHp_1C3`Y_*={3M6*D(0-LT!{c`UaJTJ&xW^Q} zrkN{6&jLxD52pJ=@$^~|-9w|(>`r0`y;5G&zwJX zH8i-^)(Fh8f7YO$Cb?D)qVvtKJs-c^d|c1)b=j=zQANFMv9eH99n$YfRLOfy3pV$l zhWm$ZneELuDP?Awcs!&uEJ9HBUQfa&{(BY_sRAOe>I5S?SxhPG>CM*E8HM;D`ltir zsf<;Aq+k7r(~n->W-fj!gOYVaaw)(xGMgH?UHpFU2b>5jOD zr!(5{Ol~ToQ*cqMSwAATPH8mOATKJ?VK;jP=W7RR9?QYjJ2bbr43y%1M<0h|Kg$X) zQ>u~G*7_k5zhBa*j;zmhcrG@)?IeXxEE_u!JJDiT*urQMdYtVWX-;isu)~=@Seo21 z5&C{7rWVci%u;d*VeOc%uBYAc)1zz)HuLDNB9mR0?^-KG zoOkw=Dc{hs|8p&=qnf?S`Eznd>KzP5jQ>pcgWw|{<-7|Yes#N-vJW$q=Ca0^F5xo> zFfS72fAv;q>8hWnMKYBRa{iFxKNM4zshFrZIREkqOAu+pru^ub_)8{?n^9+XO32Jy z7TWtm87NN7)V*lz$;fTu=0TNWU|v zcmpPt#>P(_U6S=7j-UMilNS$*>5Z!5qt?~$Ok4Lx^pIy&?CpfkGGDM&=-GZc1paaT zo$V5_csx7$;4|7W<_2a{%a_^Qx?yL;V>>oc(u`_uHoYnQ(BVx`l$>~z<@e=>qfvro zP^;#YiO=C6*-P1p-k+isPQ2CHv#W^@b_2afQv9xO%?LT*$G%3r@t+OTGrakmpU#=^;x_~T2V9Px;?63(z;-avj64A!kJ6m#y|e8d*?Ny|Z&PiN^E?;Tg!Nf!qW zGnmk8*!1XJkSl-dA=%JFx7khDzq`fYO8kpsP8Zb;MSE_r9iL}$5Q@q=-fELP&1f?# zuy+k9mv07_j$8hXUluTzfV~0C<$+Xz839_^W_z?Ke%jTh(^;?bg$X>cl>JOw_nFY4 zGw1FJzNY(}x52??ZWvF-u?=kSI&E~L-!WFtucX!~j`GD3?svIVf~tPIW(Z*2`|Ib5&fR;_&sS49yI*A@wg}VReFa!()p`r<-u^VvD~E$62+5%2n_;=(l+ZV@1XcA` 
zFmDkg7Yo0ZG3VWsC`!guD{j2XBoS=k9xxe?WHeXRiA)~7Zq@#9WJAPMhThlDn`5tz z=DrL5NV8sHhy5C&twFB?F#ku`cOxcL{or-0Ao}9JOG)gsYe=&Oos9vw`;X$_z=RI@XPZTf<mu&IVa`u5}XW%fz!hC`NBeCI^NjTsBG_?#!q*yC>S=Jj)RA$uE< z`qrs$Z_~AXu(mt%>3u&gzba%*H^!pc?Nj{IV+^g1zqq$Pd1J#FeOfiMfQ{LSSQ%1z z^$@Z4&{uX6OJ#^na*_m-jiXTJiOA5(maB#7R&|g0OsBqXHW)s6ZjFIh10iA{TH;S8 zi_=?j`TD_6lV5%Af~#Get6=*8|aHSlgc7i4gEl=8hOuA7Z8aLg{KW>Z5B z+n8RkT2`46fBWY|bpCyu!-jgj^s0dy1MUrH_Eo26GCux(uOR<@J?{lR6Hb9trAl;d z&CrhcXP<7{o1Kbxs=VNH_KJ9iwGh z2)}RJ4NTyzV?63^;h{PgPkRS+t1*#$CA@WaPGUq*>_gq$%02P{g3*^qwGSwheKRcu z*M4<|6Dp*|hRsY5+B}!#9*eD$Ko6fVnWTK}cpS)x{hT&8*v`a;wTp<~UaOVNhrK=+2!If9YZ}AZGwqx6#3+ zdVe`F$=B(AX;V{7$89G2+|6_0qCG>kZ&6$h_g~tzZ5g%5mr0aw$ThDKWZ$;G-AcfL zWnUs%s>oA@03TE_NwA6{8|}cbMPh^0_=9dJX`oJl?c2C$80*{0wvfmA`3O@ zDd}>!S&MkIWz-~hVjz~!!d*&u8zPryatcqoP1h-RLQbvnCwcRf=be~7AI-sMTVM41 zK5dw&GO(<@{E{q(Z-3WU7ScZ(of1{>->8@Mis1uQ?K)11HL zc84Q;w9`bM{;GHy-fG16$UL|xmOpVqFD8rsy}s;;M5!wedR{gi!4Ae3NqU|xzE_6G zi>?f{4+(tabg7soX)lf9%hoJi#cg~oZRz?Vw|h3op1)7t{MzPjK*Bx#!BDuh2$zKX z(`;q-tuD>NH*gh)!nk`A#$MO!r`N@A57dR;UQ zsCIAxJh~o__*l z7tp9e@jrwsbuwP^G${e5dSP`?hEv%)z;gnp*doulFrxwvoVtEcXI$(e)k&R%fjX%3 zIvMx`7VRFhU0a&7zhGuGDe$?xQi5#2&S-1mpI~8Yv6@ohD16L#TVA*^I z?w2H10i6C~V=BPd7Ha>Ubva&>M=f^+{_}|?LQx*nhp@2%8zfK=A&>sg*`$Cw>~}`Y zf6NG$n%~*5|1ld<3am~bBNWTZF6-eki1fQJ%P6u7Fx%Nt+96-O|^nzdgAst{yG`fc)d8g#%YDco(7n!0NxmAy7SyfA}K7 zC29`@;*@+-tfi0)9hWS+V@p($MD#xO1smRA%}JNH+;?>!-OfgzZsPVCQ}u^;&AY#M z>rMU~;?2$ebf}v~`z`As`PYXsB&G!964xJ>qoJ!|OmMy@pV!g2Rzu#bHj}Z&2@eDn`dbbHmEkQMqE@){SQ~cGK*XMl=*+06)bzd{)ehy zIR~^b6)b$lj?i?nh0aNU^$VJbsb6+`f%5^>33KsVCV-WG&&rD6>ul&qhzQsexH-F0 zoSm%U-WVy86j}oANTIlSN+S?1|2&d%^>7e{A;9qsTQ3{%vG?*)ZnpMtYb%?hpq0fM zKu&7)POfAn=o5JnC22GgjYDFPI20NQ{Ho$eqA(IEEV^ty_3^N?hXMFOfC0bz;t#+( z99RtP;s2yTb0hSFlNZxacqDKIuSmlZAjjgr^I;(vxiXCet!OLqLGWT#8h{o5P6Kzq z{k$>_1Fgu=?;rV)Xgt6~E7Fi?Bv^-6=7V;9|4KvQkN|_POv6Bn>56LBU!8teLs(fe?fY~d`VsWeLLz2({!u~rS(8?;i z0&f;L+Ojeq36Ki^P6I~}z!901X@pgJ0Vg{kqzV2086y-5waWfbpu?;B0o1Z`d_hrg z1Z_oqXf)X3txSU;{E9Sy+?SQ5cvv|(+j&s`HzE2?CjcY{$Hw$rU7-n-3VGGZ_5kyO b5(_ZblVasTp-x#SJO+h<2?{D}tHAyj%yefx literal 0 HcmV?d00001 diff --git a/tools/figures/output/base++_Precision.pdf b/tools/figures/output/base++_Precision.pdf new file mode 100644 index 0000000000000000000000000000000000000000..38f410a83e3489a1c2f383caaa1a69f0a0672881 GIT binary patch literal 17753 zcmd^n2{={V|8L583}vcRbR<)zGY^W)LsI5hnTK$QLnK8-<}zj|GAAUFAwq~mNr_0L z3?+n!qTIC)dcFAl>fYzx=ef`GKfTM^d+jxR*LQvQTA%g#uKn_BC@6^`#W67c+-`Wz zeHapsfIC~9fE_phm(cNYv4%^?o0H8Qoo(R~8s@gvZg3RHpaqwfhFO!Wpp1kS1(cjg zWH_3VLBimu#R+RmG915ntLRBq)+U>it>Ku(1PyaC+1ix^$3nL-32i%bD|?bH9Jln; z)!9AcEXdEWZ@1!cA9*gIwDINNz}aS8E{MqWBV8)^5)3 zu9l!bko=&1b!#hob2(>E&>{l-!yqx@7%UzOMWdc3T~!`38QvK<_`)UuMjBj`6=Ldg**L&4h8*~%Ic-;L~QZSDl~y!O;c z&AD5Nt@5k>xZUm-!?$jq-**A@>~9#nQ>rgyq%On*i{=& zE#H1eM^w*e+NR}E;^&38Be&0;nY+L6D{CzE*eH+W@IrOVT+g?2?}prt{xqJi_#%2+ znk;8gr+hi&&FQ|!2_q&o4aNxE)rC*@d&@)gw@mmZ_IO+`2{0)WPN>$2gSRWyR`T}j zOLZ4GGCBALd(!i9!!LuPudQ2VcAe*)udEt7H}>U6h`vc%ibd+g&!Ms757414_NC%g z3(VH%YO}704VYQ7y?efgCSKfV@O513O{3betTEFC)}tQY@Se*@kNR%E!|}x4Z=c12{c4y)np>K6#(C?y7^B&$p7LPkDB-Q;Iv{?hS85+*_<0!pOu%PpW ziR`pi>+`X87jNG)jA#j%0A;tcbE!&eed;H%hrUosyL~%}DN3Zo_QM$si&>TCS87+P zX&)WyUO4S`%5A2y>O0wI;d55v!7n^#y(VUU^}U=vxBINu_af=UB!Ny(wnOO+HHW$f z^Ayb=QlAp~TzW!AV^=7r^BzUp3www3x~X^m{OT)CdmBxg z1uj;P6wr)IlKZ@5lsUr8RIHd+X5fxMQ+*ffF;6(&ATG^5q7i0t$8g7p&HMP*_U7y! 
zM(sO1*jP30@>{+{`-yvfa6BT7f2}(l{|@_JU}}fG!+46K*ZXVKXuV?%5(hWWKe^T{ z{pj_(teZo99l<*%95V9WHS}xp2pezX&eEN)Vc)D$_>c$xX1nW&@tijFghQSBEcQkF zxV-JwyFpDixH=}%Y|kIBK_3YCu(j%t6nm9Y_r>FUPRs|3Y=_)9Si);1&vqY)ReXL* zac@wN*UcST(q84##=pM)I(_P^FLAQzEwZ}i=jKhKk9gY&;vuT5QfWDMtTpY?=0dK4WR&$xO4aP8(cXEY+aCQpwCO(&66k_r&y2?w*1a5Y2@7j> zOu4lGX;Oi>WTL{JA1(vGU}J+Yg`j33p0`g9m2JWADIOkUQdDa>IF8%e;p-mK?N^e9 z7xvm+>sP|J;pE4&U7d;K#OtPIpiSlLQajFXJxrpGZ{PLI;PbZ&shpBeB}#*PXu8>7 z(oP8NpEHs>a`GI_e%3&<*eoGae82^y>z$x4@X8j5_N+91D0!f3KBJXN4}<=cj2+L3h^Z1 zR0E}n`K49oJM2$nvAm`~Q2#7VhdBs)W=~NmHecDwpg4Z>neZg*ri8K0XU0zQ(f0Bs zH<<0c^e*k&cYy($Wc@Kx*qmna(O82+yd#<(4$={m#ru>Gs{|IloUe(S&D*KvbHOuo z3&(Q;nmWF>+_&5+9q-^$bB;@Cke3C z3VL|H&r6l$h})if<9J8I6?x1TFWWEs}{Q;|Fmt)fNj(mwj15FGpNm0;xl#y$8L zbD7iQ!p31Q&ICorQv*yh?x#!$#Yaq({Vi{frv1`6rV=KuP$spfHfZ20^83!CK6$v* zXlhbV+6m#iB8J9KF8O|B_A@7JdHnd8aF(NjPz_6q(z8;$f?{z|aEJNW9>~=IsU>lX`AGwC-cC-%G=F$bPH5oEM=z@7u|H z`4&j_5NcVhn>Z$y1)F^1s7MTRDu#i3SVJk^m%r$hXRBm6lfP-5I)mPJ5$ul3LK*iK zY`)qtFx*mM;DLa%S+vAA=-yA! zXf;9R2*-4%CkmYVP*&cmR21n((7PLKvIQBdywg3-R;TPDt(w!_0kZ>X`{=Ig3cWDM z)K*IzC83!k0{BG+1X{P)MYvuRmjv;4%O_jf1dKY}kivluY)|*SJaI<%MAV6dvVgp; zCKoDn(=spWoobA3D$j+;s;`_ zj~8#s77b~ft>xHhuYUGvk#j{a=5(oH3GtUwLa=u8cwj_-fF|&0gJR0psp&vc$+m4&p$jm-#ukqt)1rQ~Y(&dtvII4@D^!zU5Q2 zebJ0j&WOb-yUcYHyY7XnwZEn*9A%T%O_#eXT`BwE)-5Rv7sut?@TgZiwdv-~vCc0B zt{;fSA1_wCRWV^NdD>YkD>8AUwcoiy6xT(_>)&-UCf9;p;WF(|;#fb8n)p?zQdPb$ zVb!DEXt2fs4$(45_wUHHyKEL=@cOT*^r z7M;{1f=?wcWp9A68=EFvhJ1V#Pbn1?tZ|+huat( zbtE(H*{WEPkCv3QBd!(8dHK5@%61(quou}sJu-ig<$%k%8al4&wVgmg$ZItk!(sm{ zoK!fGO`exY-up-&pD+I1frC-1vSYnUlSf||Z0=WldHBAgqH0^VF;AM$r?(Z}t@f3L zLat9n)Laj%j=wU{P8vf47z# z)ji@UFC8iNt7}rCtmoXAn&Z(46BX9iql{J+urPNAet!QNS!J_ubc8Bhel>|Spx%>m zY=PtO)cl4`+oaRi=`Do6SM?T${&Q#1g$ON#27<-{?dC?~#;4}bAuq0t;JeBw5WUGJ z-Y9O9p+J46is&Y|la&Qa7B4E(Wk>4%3nI)LIBs#Mt|JBn^H+(1M&kY~hg76mct3(k zeun5T{@i!E@WWS~G-OD^obym?{IH^oZ2Kh+wEmkrhvp`WdS0GpnfoSUY1~#+r9VQ@ zY`*(gaA(1YyJGj&o-=RGr>8U!?+mWn>da-FNn3nj;XWS#^t45P+YsvTe z&sYuVj$f_Si^wugsCU?ZHKBwwG{s=XBDjOt+;&Z=RmsvpsZgkqc87Ph3-=ivq0_Uf z3k!DW*k2ZBdKlj(Bp((Zi*(7Zi7MahYRqw+sVk@Sm6)x&asC{+{K7{w7m~1m-M*fy z@?`^TFB6BWua14-?<94b&du+7FLn{%BeL(>iEr+c%w^pQ4 z?7Pk&K%1*ogMdQg|2zm-O*Ik?lK#yQM`d-?te{(xzLa}w&pBgx9s9ip%tCp;6R7(< zv-$F_@Xx-$KBuX=Fz+~@%OM}g_f6Jr4{FPxXa0D*hF<*i-s1w5S5L)0CI}ujQ@qz; zfIK)3i;Hl^caiSNX4(btT<3Y3t)ZBBR9E*pk^7U(w;L{~`kmhSC6*eKQ9YQk$b`)z z7lK0`xoMcFdD19sKlUnY$NV;y?6Pp}yL?V}9z+zA&-o9>c$61ZX7YS-nc*FGbu1fu z9M0Z+%3BhjWs>}Wj@TQ^J~NNXIf?B-7`gu(M`VpNoDbF`VZpTyU1RkIYk1K`FMD;RL$J+2p$v`s=QS zrrzK9<&yn7o?DaVRGlo-3=9`+9 zG%`q5(aYLy{e@LE-v7ph{wCFOe*2yIoA&378e8$1)@idIc@g@Y6nhE#=9+$%?A3cL zp+9wH+lt&2U-|c4oM?Xj?2+B=JI4;h#1#sjXRn-Yxr@FV{8VdvcR%)Q0G&g{&`(6b zAk*{khhA|d?#Wiy3Up1AzO;HWv)gA*4GsjynjSJVy%{-wP5qU8QIDGsuBZFgPe$iN z<9+Ms5Zd;y(jgl27t@epq&n#>g6ZOnG>wFr%zT8^bwS1Fil-Nxs#2o1dV20s>uITTllALs4QIdWx^=3~xpCSy<<%E;&&MjhZiu-Tc2wetWk~;r;FE#UO&0NU zcYFw@@5uJEUL5Z=QSiQj;`g6MhTUS0ViW}=_UJ3HA$e6>K8;!rvTYl*xgl$Mk3KiY zSte>eKgWCS@*uwSRgVyOU#DM>ZciI?8K66K`4o=$jq#(~5Mx?^Z?hJc)w|vLGWdKp z#TSXCNo^TsiQ^mdP`z~z%;C~^cQJjIPUy4KKiM<;%tR>pW!a-il1-p#0TOm^=g)&f zr)IX#-ao6IGhvs~c;_PKwd<_vlc+o3ttOL&EJq?&f_&RUt}SIuw=o;`MsHmQYM|EROo)g-zjcaN)FL+_>E&mn z{lk4_XWf#HWY{3f7I3KVVk22{BrKx5_RdWQ>WEsG&!dk6tJ9+Us*oL58%xHUOJHt@ zEa}v^K(dMJMkI5;QtlSc@}?k$gngAV_o}7I?I#a2ob53zLn=?4V4c(bv;J=Dtjg+yR`#@hNz@fFAfU%7cUN0ql|eIn0%!Cx6+#p z7dj^~p2u`ng$W5h@;i{2sxov8^+=rup?WIl+GO3GR1$BpIXdCfKAk|NKvRL4=aTnw zcZ#)Lxc?(_AydOo%9}Q&LU`LvMpG*KnJr(;jT60mbV~d1AN)Kf+B}>_q^!CR(?q!p zTBR|*d=c|$!Dnn@jFkJGi&SnHEg71X(RJ)r-M*h*Y)uzz1-U$~M{j(dq*@=!B0kk` 
z>B?9j+29xu!yz-u)c=ZhdUi@=Hs)^O#P9~4WR+dEDqYB_Dq4{Tm3sno%+}G)IxrK3 z`HQJp@ydQ^YL@p-KEKVN`bU>^3v~!b=UfRZmOXtBm1=+C)x846A2Qr{ufkp5s|D0Q zl$Q#pfqH||Il@7nkDDOF?l42Pj2K18$s*V0@4=L;8U&Re8l)D9N1X z=$BO^QlcssjOP{izifMVT3xt6GUlDZOyWS8n2Nj8U|;aU5bnXJ`HgpF5lQRV#yY?k z`B%&KQ6kGmTV$W?$5}VZ}zq?alUyvzj>CPy+^AKaQ{{_DN z79{-~b(m1E?vG@@N8z`v)7awNvcCK+-D! z>Dq52%@(cRkE)qEZV5>ZT0{ju3qSi{UyRe~c-_bKqQ`&MXYX~M43uvdM|+3w6gA(V zcBlJYt(9O)wnf%yw%gNptrEtXt_t8qoO=mGpb2{kepc_6lKYiiXq-DIDkBX62;K;9Y zHMVCi&q%cAoZYd`5Um4{k$+)`Xk38?!nc+~9)JKIu1)JeZ?HF6v-%;y)b&RmLGiLI zi##tF{A=g==q*~mzQ&loo^#&+c%klkb-*h~@qoc36iL9L1^My*@gFWH!6X#yZEUQ87Y=f} z8N;Drb#pPd1RgoumUNg59e*oJfRA zSOT{jSoy7CNCa5MY{4DKY!BoIxk+#dXHXHi0fm6G%?&O=2Au+aHVJ)uE3zGM*)94d z|M6o|6uAt%{u!+ms-~#q4|;Vpw{-(OTU6hF)i7K_*3A;~ND=|^m4KRoZp2V%FxAPM zyQo-G6aa?bVEGdGJAo(G-cpui>u3#zLPDEt?W6}?sGECIx&=pKfG$_`Y58?BhT{;3 z{~H1Je_9V3Lx5wjI5-wCg#1K)?gT z!-2_&NW_ACcsKzI%A$c_Pz+udJg|2h0SFHWQm9T~bZBsa1P?$1C@c)>4J0)bLsDQt z<50kXit*r8!xBM$5aaMbE$GEGJg5xXBSQfP%7PfE2m@LG4{!uLr2te20}5k-WEfDG z@(hatR)m4$FhFW>jRlfXDg%{3F%V#J+#nSKAtBz#Y1w>Wlo0_Mg^E)Gnn>x}@2VCrAOWE&5JVsu)O|=QAmy?pKo&q5IAHNm45@2b zki{BsP?vBN{Xv%SyM8Daq*TmO0QWFFR0{}@aL{w;8Fr}xFgQ@@whhb+yL_+)+mklFV}wVJ zu0V)PvNmFjH#p#MZGUrG^U&ENZ%#>V&&zDf#k+QPNJ%!eX&f_}yZFqm=%-+ZL^<-< zZhsCU{aT;_=8H9*CH%#vSp8`^b2Xfa`gxM@fsOALq&1`W0LFlsHKib#EO_1Q5{z`x zfT?KzuDsbc?Kq#0JZzB~Zf_L~(EME&(~tjXLe$0?JTA`1>)&>*+}gb>WF5iQ*|MN8 zV7dE~TS!V{f4BlK(`TZ;#Q5GxJ{KOP95I_Oy|*_d2x}ZZTa4cOoYha62^pRLfj{rY zZNI7zxe3zy&lhUoIr3WYXQwv@l1DOe5$uxO*U$&e@s{D$ib&dPpU9a~bOOY=8rVg* zoz^3IDuO|`_Jj|k!-bRfc_({r6&-zyGZu`L>WBAfX$3WABu76v==3BSCTY`3Ey`B* z&Wc@~&{VFf70MNj8s^FL*yEFX;BlJdM=SMbxqR`RFOR;7pWWqBa=gpxNK6Bnr+M~T z(X$Ig&igYTf^c-|VO<|bX4vibNobXG8(lx^&KC1_`_Oi_tG4I3daV||)BU<+v}YX+ zuD2TkGwYuAtO)b9Ru|%HCHIgseySJh@Ufpk~T;(!a6f z*lptjnPrZ#($ao54 zjSunH2MOPY9e0>Z`=n5^clHhn%hi)Kxf0l9SpjvCnEmW#XoW(Z z^ilO6La~R%^y*34T*v35L)*@fY4&AfCZi{x>lQv|Gzjuy`$n8snH_3(_$IAu{|?WBbudL{h<8rjiEd1ZSJztXJo(KG3_pgi$0%yJbm;R`Hj(2_X=0~RG9&F z+3|9g+e!s>btviUVZx_VMepoLs8B*bG~dG7K2!U6(~h666+LCb(dd1==ezC)ocJu`QCRuar8xN@`eUloI%B%d_UW(Q z3Uy7@;7tgo@*&P2GQ7v4DzfC`<%bqJp0M~6sTuMkqhjL>7&jv??iQ7rjo@{||J->t zwX@vtbYx@aLfp_EZOvJ&o8lI7rhz+KQdq~OQXJR}3m!}mMbqNlD(QCnmr|oss4RU{ zQDrv=V|dw*FuCtyF-or~b+D*?Z`67qteZ5aWMeIGk@>QfT=#b0K=|j4cejg1<8bUK z!)KdDnH!mn%sR4nYKEK_iEd{gr|Q++YJOKZ*zSRs7n^){(`Vv`o!(V5|5mkW1JC3B z(s9}G9%GSmr~B0R?`h_NrK9$XiQMz98z%YvIM}2${u7Kp*Fw9D zCH5Mij6c2wVeNiBVpP4soj0`(m&WmYx$HtvX3E8*x}qn81vY`sbx=SO11_0uZLxR;TcY1_i8o|Zg-D9W>6LI&(eTL?f8S?JInP-x0GUr*QuY8E(>Eyky-5%?`eQRux*CFnL1ko86 zMqK5F1U5Iui5dEF5ns(}g&U-ana%eN*<^Nb%yWE9$cZ(u*vK5>pV!IxUhl22L21sc z$KEnJ23MrvM{O3xVBwPLGhCvb5k}WOyBfk4NIn~PnEMk(3PTj$c~x&Zy!V%tuVag} z@yPJ7BF-a#@aeg4h;QQ4y;QQ=gdFKo^SKk{gG2E`f&^lFek!t(`N*J}R*!yJL(2`> zo>}w(+f(5hED0L)Q*SrIFTJf+x+s0GrCNMbv1<7#%R}FAt~k!Q*U|VtluF4!^n3XMyTCm$%j0p#WfXEgJX>qE|5# zpaBHEf_J7F^;v;$01ediR55SiCzT4ml``RGh!-ZIYvpN^m_!3iPx?*8A{b4SHNvlt zq?x~}9HtJN&eZz)Wpi}WNbdW9&r}=bcG#>V+B)?8-!QajT#*Jin}9*kMJffz#6|?k z1@mrG*-+O!aGqh%*N|3bM=F8tw!kG3=sZGS5p&!+O|Jt8!A9}JrMz+k% z0N9{XEXCm^9T?L@!`?(y*eu%YC(6foC3cs%JfkJT$-*(+%*KM&$`Fo(-$#|c*btY% zzG)q4)wl3}2{Kw(-Bo zsO>l2J;Prh*K>I^p;_zE%+^AV7O{Ng7S<9m7_}z5ltH`%=bpBfsgII3T@EGt>C=~R z?5wz;y5Yi!SvsB!G+bjyroP+)rM2n&-q&vSC+UM+1%l3{S0wcc=B1}hNg2jB^+u^> zJm*$F@{#J%lSisIOjW<%h>ts@5P9{pPl4Q2Ky~B+X=P-i?+7l)DV0WGh$n-;xsF5f zPQ%6un}k=oAI)lbE(wR~GZx93o0KApL=&Ooe#P|*+{^(Tai)mrxT`@nrkZ@zSbd27Ck zIlGf%(@C;c^nN29kETohI`>A%--}|p*zodRBJghH{wCx9vIx@Yc~?sxF~H8WFx8 zeqAa|L~l_K&D~RIAtd|9uAPeAvVQ!Lal}XW$y2?v&kJt+dL4?dlHB*n*vQcGh4jwR z=z3Aqr_%;gLQs;+WF|lE+!y@bH^16rtsqvb6o+a{8%ukUGykA+8YPlE= 
z{W!o;b(x=?JNj`i3v2U8C^aniB8FRVRQ=Er$xn?vLbRrcYJexDZd zC+<~tbnNCH^f!^xmuMhC8Hsew)Lm^eMf^ZH~QLvvL79VVOH z&GVtc-5($IA-QZzI;`8ajF{$2#mhG4n$+>LZ#(?ToX?i!U_45^$W00lA5t>7YJP>J zw*$=F)jh(YdTe>9Kj zNpZn#61lskrg59LY1-#b%BWQTByFB{yBpQ(sW$X%>&p+`Pa7v|bj%(lc93Lnuip1o z1b&zc%^w>2KELU>nQ`Dc2EHEH1Z2ve4P2AR0zk;f4CgPIy`d7GM^lBK{;GN!+N#I% z&?KNJnm2w@E9xfifVTAMc=2oZyE~dsV1{Cg#N2#Kiz_7ZBCCUJ0xvwYk1M50I6x)e z(Rx!;ejCs8whV2dJKfaM!6l^4Z!PcnT`lGv3W8e*af!-4%~oXJ>eM2bj;%Ttva{#v z=-bEI8TGN-{b;H52)jByc8teJ9TjLLbE4XEuua<;*cJC2_#U45iv6<@<^_tat)4aa z#HNbx`LlyBz{+r#(*26a1{UM=iy8I{F5|4V=^KZEi@VUZW;Q;vueQE`f?s&jVqKH@+2!LA#JUPUl1N+AVJH$gQC&aTuG4vd;fjIEJ z7ao*{>Ld~Xtx8x-gJMc&(16qj9fAabIt2wBI8FnV#e$A6)&cT@b}83@YeWI&6$)59 zkrERDM~s6gVL<0D=7E@PB*m41CPL*ffU$*WUWmB`eSreRmqT>)lF}%QEYvonKoA3k z!k{*>pv=E`ZAd^QWIs?#kROBMo}+kOpsqnxEEcDX2~->e7xXhdeWDl>sEMN2CF6i<2lp__ zv!x1%I3j#GK!&hr1CSm3J5Y2D>;kfh6@g+8kWDNHimfc!1Y{2@10)ya9@GKZ#&Upc z1&6?bF;f#2xD{d9CYEDJDPY`Tlo+~R5h!E5_#DdfJ3w{;jVcuX zX5T5Z^YUD$0H^TQfGt-94vfWx1Gq<*=I|Rw^DAP`J1vg1l2E1?f^?s9xYDXz#JFvz+4Va27znX(o~OKk^q|hDWD8gL0J^Q z8wwR!G1r0n<=!X(=g(5xN{}J0%DNJH%OY1gjDLSBDj_cm3Q1U+gWVKpLqZnypR@=>M1zEJOcf!~Dl=2yw7-fs9ZrExn?LD=_Il-ChB*{z+U(QmI34HAMju>fkk6 z(x*Dm(9&?KgEfA&GASjNwYi8*mcOd~$KS8y?CJ#2%5P^r960pBdkXyrpu!Rr1I|zY zWbub@F`TaKhY+8Fcd~^zg1$Y@v@5zyDM3hUtT#Y6oi!&z*6d)?L7Lr7o?V3Pvqq|w z_e{FJcWF)i{K%b~J$9^%>S!P9G1<4rGDWBPWa2k^$xzYMFp4{R zV)EE^ZRhz@O#r^@h$beHk;?D3>%eO*1|K2rv8VkV1>mkeUkk@&I(q% zU;jf@u#&!7oC+2{$CpsEw}MX7fVB;pi7DTxyMu!ylnHa`D=vVc{+X5K!FT4+!66~A zgK%+lCOg_&z&+67L~)cT+>T6kag&sgaQgj7+}YJu7$yM@`B=GIg3s1hmU6MOfm@hc zI)GM|Y5+MY)!RFh6rfN2g%l)F2ox59MqrUB1aQEMAP9m8grM+>`P9?Z+6D%o2>}KU zXe|8z^uvOM(gyxJ4VoLFADpz5hQy)3nSoVl7$W3>{&zkM5kSmUX;>t*zWqBN771Y9 znlu18{+$Nyz&VB0X%Jvrm4<@g#;P;~3ISGT==ZxHKs^Av{W}ebMFD)iIt>dgrqJ(q zStJ$@VCt$gfRI<~0D*^2hWtAp0*`?l?9lIbec&J%fY+b?QPfI!N>>q8T;Ytk@Ca4KU}J`5Hd23wtmU)x`RNLT9$fklB` z^{TR1+&cLP0LHD#2ax2N`fw<)1zDXBzqURc9zfiG=R=}#;Hb^&G-#i;It>AVF6j4L z_eczMUgzIw2m*19y&#ClHDwV*G-8d+#I@rKkpCJzLx+=A+W|O*x5f^@p&IA}$-iYr zqQEv`bs8GKHXml~dk(y4Ys#V!YitsU!a&&h-|eFCfRk9222lA5oswP6?H#RMDgO%; xZTr&zH3KN9w2;^mQYfyOJr>7!rp?qG9~}ibs`T{|obl(Nq8c literal 0 HcmV?d00001 diff --git a/tools/figures/output/base++_Recall.pdf b/tools/figures/output/base++_Recall.pdf new file mode 100644 index 0000000000000000000000000000000000000000..e01ed3434e221d9be3aabf917d281174e3435b43 GIT binary patch literal 17311 zcmd^n2{hHw_iqyE8p>3WbfwJ0ov&-2ha~eHnTO)ibrXrGh(czOAqoj0GDIkIL`tSa zg{UZ$D3y1<7bW@q>b>>eT5qjC9p`+{xc5H$d(J+eeZKA0QC1a2ieX@TbNk^rcVI|3 z0`6vW7$z+Z7dP;Bw}*=>S(B|@+#KQJI@XT%9&i+>pa++gh1nDBpo$VJ8mPJv$#68K zg1Cvc&0%|6G9177sNzLd(|AS*`7p%W1&ZwxW1FMoiou9j$0}vx!LO5 zli}uISY=f}ioF*ZF0SDUNKjmgm6l=+xWx)`kbg1&(G5{gvIo>Hk}s}j@8NcgWDEF# z=m+C#+1oi=E4q1s5fR`QgF=W&AaN3SEF6tQh!N0eC}{zt0ayh6c)$sZgenk;Zr}+C zHCibasCZe9`|XL2WG6UsX>3(z7r;4OT-61LL)qTe&CVWTz=KS(w|0ejB|S9L^k7qC zshZ6nr^<@aLg1s()r}cP$0Vkor9jK zX~LnZigGXSkAu(NUB>5Ih7Nwm94vP0o}~^u^?0y4`TL%X=hv#Xq z>3y2J({U&;w|V!-5N+%8?!j{}OJBYs)b$VbhxMpuT;o zJ}W|W#LAW=pV26KDPO+-2$Z@ZkXTx1NhpeYrZ+#YIxck+* ztdi4G`|I~=mmg>+X%soA&({D--oCZB7DW-4jnTRp6-=T z*;>wEw`a)Z{C?SIl`wU`Z^3p`4PQ6L?u%}xAq6(6z!UCHvGI=^rOnN~Y!ci7lRfw7 za&D93RWGw))Oq9(8+ofEI&Zxmw;WgfOPx<$jctduWVbc3@7@-`Cb;cuQnJpgWAd_o z%IV^q8sP`u<{1qa1xSmMIGWhBNCSNBT=wh{H!rGNxCSI<>DjyK;7tn#WjMVf~KI z6Vp=@6VE4}S4@3=^02WqmAFt@u`tqH-7v3>KlOZJ`ct8A*ZbP6Em=I*cX00TPHucn za{~7zLeg9TE^QZHKt0ejJ_*m!BeNM#@Ipm9z*B7>tmy@dO}>}KeKF$XuapZ z>A5KXuDD3@MN6faUV*p{mVj)vTOIAHdKV1TJwcT{&3xwkTiwOFBG!+Tvc+6Ky+6iQ z!XMNK0yT6zS>mtkx8d^MaP-=ET-wQRZC&qe3wziJK|M;@W21__Auc*#oGW&@774x& 
zT=XMYb^Mfcn>#qRPRZ_+?;Ueg%2p8)kP6Y%Q9Y7$Hloh^rM*Dyy-PZKWQ0-8370fS z`8QuKsISk^c=_<*%(I8DPfa~I%WzR{e-m4^39kvA66|1#K!W<|w0e8Ul_u9NSB$+;|_pVwlJ>LU{ z<9Sw%-x4-lDc^#f-*!|IFTIY*X5d*;uJEA(RhIyBvLsD^Y{#ZAkKQp^MY~H*o@kmuxxQ!F1Sc+6Tjw>`T&;e8{BC<=VK;?B0)7 zYg_n*ni~r0hgv`2CZ+_B^)T1JNKe)Ynk}&soqxVNN3{2honwq$ZfcwU2Tw(#d3e`6 zb?F(J*n5Ny;-L(SV?IONab*`PeBI8wax`cX`OQ=2%O14XJJ$bD;nJSx6Rv6W&eaTVlpbmN%&OUaa!sV`C;;y@l@EYAVlI%lEPCAts zm4L(Nzu)Gz^E&M$SI2cZ;uG&hUU7$Yv*SBbX4M8KU(}hpoj>t|&v`27!YENHKi2E{ zyu`%73C3X80WF6nxuXgtJi$m&54C>(?faCNu3tWX0GTLJ81*t(GL=n(!L|J6xqYpV z2a>SqYJ~*@vD1UZdsUU_453c-ZnH0{0}Nx{CJe~aM@SKp2_+Sn!oxJ>Vf`XK2riXXwUjpS?dV_a|EP|e_Yr(fjv zfJ*AF|KQEguh~Sw!HM8!+|HL%rM#q1ddFK1jQQs6iIv-K)}ByfsK}H5q>58Wle1h? zQ0vD7+mfyZ#|*o$YE{Dk&KIWz7&hlzVLDs**oHh-BvfD>&n2C!NAt8@Ni9{%=5~c< z4hqS0^H!UX3soDT=c%B7nuxBqb)P-^H~3iY&0dH4g- z(FY3ziKbkc!j>G>Zx75vyOc?W`Kp5{nWx#VZP3>2&f@xN9f3+e&M|Ubx~t*!&FIF$ zX5B3;JC8aZjn`%>y)`RPZRQ{#l`?!&%lbzn&Tiz27GeMW5t?wKKG7{|TH$7VcR`Me zsN`h~d-GQZ_8F}f2}kXuBI!2kYxYK%neUumGIOT&uo~QEVo#@Uo1SGwB}*hKy_s&? z*^+ju1aX7QX8)Iz+Lv|RO%~kF)5ezcw{bmsm!q(goQ5g_BPd;|!5My4^J^wQsFmWi zdcHMFDPK9j<d}<~Q$_XGI|@h9 znqPPYX%4Rw>e`yEm3jWKZ}kKJ%p*xfLT$WTMjt)a6TWFiJ#i*8Z`v4p;q?}HU`=2D z{e)sE4@^aq<-pNY%lvcIRJLt#RgiVIWKb;GX%y=NvPe)!q{OoEMloBUOmn!!avo}> z1+`?lNML}u@)xrO`%sgm2C-q?qr3VNr`0#iz9HXH|4=QeikEIsGmLmdI30QA)({r{ zwv*n)Kr-XDqq-egyR@_mkyN7S9Z1@jO?qGGEG#uM{!NZa+Wpl1jT|$7GJ%3js8wba z3XS|baWXm@q*??+#XS4T_4S4)E=nbH^)DQvM5*q`3f@kf`&&kx5 zO?|oHb+fzDOrJU7l`Cem4(7)ENfR|a&>V)zFw37o z%QWpNKB45M>$dOa*3EY`=17h7{S~4T)Iw;Qr=6x zs$DOGYZIlHsp<&3EXg)W``)dx+W2Ll-vZUJtQ5`Om_(=AFgW#^VyYt+$%^cNC5vQ|aG2aPry7{EfzMRdl6}J&}&)#&7P{ zjOO^1M-w#CbS7WeLzd&Sgx$1J{%5us+2=yEJ`D<5uT_{&4pg5U>`v80J*Iwto|f>` zchWQVqfQH{qg`msRoA^#{8O%X8?))gL-Dmnky+-64M(Kn6HAGspEg-B@pBSdJCjs9 zRBex_76~-da{AP`?>uQBaAHnlVZjL<_rvDo0R8jC)cs=bqujIaM^|hnnX@G`^yQR2 z6?HsjoyWf--v!z5O83#xUEW*B{%cGEbkbTi2`DuF?~{Pl)g;m&>7I?URncwAJJ5w`R4L1mrW^{_p^f2Zq%07Uip(< zIz|aIdk*nc#UGDrlHlKOrEIse7)mC60VeU@jEtO2LjpX5{ZN9-J5#^2g4i}xGIa(i) zf?g~n54`j7{ju)RjG^HgQiBeptEGnhmla2sK-_?2c=sLVXGQg{_8U;iYFQlIj0U3_Tzo%J8iIlth(F>OuN%QUlT(-{ZFR(D33^~`~S!qH{L zQ=GEL(X($w<}!m^JQW$!&*g-Szb302Wo@_rz^suFc;(D+i$=v>=N@&5v z9`!9r>#0)lfQK(`p#R5rdbi8wLTktn!T?su5RLhVWk@ATi})PDaCTOfM%+sNTclkw zzse((6AP}@Y0+D~ytp)nTO&F@+t(iM+bQd6-;>nn_r=hoS7Y9T-LWOFp}7BBoJvn) z?CmgZ@dvgShKJ4{4VG=ONtiG2m9Th0cAoQQd!>tlzZ@xf^=AC7N31qRg-?99u`&yi zN2C4Cg#BxlZLb}!C|KO4%gu3HpHNNcWfB<-3XL=^KUw!XBvT@(V<2b@+ z`Z14D`l~_yt$G}GFLoKrCCXYJh?s(HCT7o8ZLW_i{YJY;!7vvqXTno z2L)1}l-Exa9fBkbfS_Vg&5VP+cA$j`1|5{I z$_^=46x*iGQ1^hUh`}FY;>@&Pg~^yUc++EAT?QKkw#yueN}cyOwv;p6#H`yBv$bIj ziq?Qnu!z4eRcliu8ywq0!`ebsB>$5bW6pJDwv2oK92pG(IjeLPN%)7MRU$N_sF8F^ zv$BEV{t9y*DF-qfkmUx%qZ}^5J`SJww zk*2A%Z*SGcE4&Vqt{r%GVb`9ly6gouwd~jzvYW!mYB~m~Uq8k?RxZwEjJ)f0ySTe8 z{4!f?q$eMj1#9^AyvKa9!gu=)Rmf2}GpIVPA%-=KKos&HURkV8IU_Y#_GALL(pkMN za!a8?M9Ez;mk8({)_Nu75vGwEJtj ztMpGE$G%zceLwY{nEQo;SYa9?8Jd;RXL+Me=({&d%Na+04$tJ6^^a0C8bX=GJ`LNF z=nEwqU4mlSWLDPqux-1-9P#Qf4gcwcS)2!$t9X_)3N{Jb2>aq*t0j4fiDH zir>6!pEEzQ?Lh6D>W?m^jDAcmUHjP`GD_STTf*o9wI{e|dryR2H@VZ4c!|gHe$E2t zo7kKVm#w`wGzo(IK4*69&T>O>81dLxz0Qdd?L1NtAI_EUgfQA-!1(a)WaT%`!CIr5 z3$M|+b8pJtKFz2YIIf;%Dt5O$S0ejPPuGD9B{)0HLG?L(fkdkt5m_gO`Pf|59Ee`u zn|+6TJkO{0X6HRuYIGn*oW{4F*J-2D1}(Xc*C7EM5+)ppA)f-b>9;(akGL*9y*p*E zum3|Hs;v!)J;f|Zc(qU;fg`=5naAoXpRA8Jlr@J3T{#*_e>ZKG_!V2G{*<{$k}=OE zAnU$xsfOZtJdar5MaNq+T7rd=u`f(!FOP(YsvmQG{qp?6DDLi?Z|iR}$#d`9#l;g}$1SCujK1?w2pD?PnUh?>R*F z`4orwZ2cjK=FF2C*D;jBv2)3cys?JtU*k?~OA(zjB)=#qbS})4v^_>IvSS~8YTy}O zDI213juuQ{(C}+&Kz;a4`>QOMJQTOLHb<-8csHx>Z+YD{Ru`tkv^RwR26{9vw{GYM 
z?dy{2LuGwAGIDON5+ZA`xrW_`#{9#P1skP7R7TP%1Dj==Q}cfAwvShhwkFTJeS|-N zxwz{ZO1M3U!FXqDHbfrOuOeT&aZUTno7N9;?ZFc)hgqujitOQY#Cr$j^D^uAM{j?H zmXxjiJl<;4;ZtA3(0xNd=Cw`q`L=MsyF#(9ClU;s8bl6#Z^+)`HXW?gC5HA1-yveX zPP3r@MXeowd$vv136`5Px9k$%hxe5wr}4RJScVL#zo^0lR2VkxJ)xbq|)zyK@)>*BG%_Qgbt+l@gdEaUDK4+*BdjNzD+JqeLs1y($vtJ z>7hX8ew}(CF!K96|HL17u&fcqE zi>mgymfz1l5)~|JzSktEtF*%=O{3xQk&u}zh8gxFhB@fQH?0pJ2ejy!@8YGRr{z2F zjyK?+E*O;O;%u3`xC?XnHsDtpYWm5v?2R8@x8b&{Q4TA`YM*V}1UsPaH z;sKMNbefmWrsHD|#-eB5O{!^O@T^38=mFY~ZGjoD3c7kOOq;9AOqFIQd`vIcoFF97 zuY`J_C2l*>RPVDha?o2(6u$Eyw)m)Pc64i0S7tIw%rM07X48FxhW#uD51M#dY1D7r zensSLEO%4i%bn+Ex{Aiq#eIL6!cqJG6cci#Ew$*6|Ft&7gH)m8cGV8L2?LBNs&hBJ45*)}K@_&iz zAG0A-yoS&~Xqfl`YggcV`;YyDP81iNZlXv89sw7{08{`;z=JgH0XikZ#oa)E-~qHHf#KldWWXEn`H34l+mW4s`*6`q z`pc_Ik>|1*`fJiwNSq>>zlhhx+R+0rxG2T{DrmU4f`={SA0?mwPtY*XgD46ORy-wZ zcXfM;9Do*>EGL2Y75IgnZ54=)F7{wX#P!Mcu0~Ko%i4>=795EIB3;4L^80EI#~~2^ zHv;Vcv>!A^0*=Ar;8*~5N`SW=N5FykL!jUicoZCsz=1+MC16o#N(Hn8@ChSGzytlm z0i=dNz=C>sxCA%~;ejPZiNRZg2WpQ4+9?4rUZ_u?bZC%3f&wrA3JZg{fvARJhzcwi z911wlF&?~eSOTaIVjLcb1-+Pu2c1EObSU6JTMz>gVZaEW064}e4WLFC&=?CS!+^$= zGAs&c5eAOK0IES63n-&>20DRaK)~X>K_Uc1P(%v`#R1|dVp*mRiYY1t@rMB|kw_e# z(gg6dtUfTxi~xy3%_#v*pz!vytHlIFAk+nd04Rglho}NnE^7j$0hEaY8V|*gxRwc7 z>;VU{grmq0(uAM=L$x5GVwM7UhT)-JK!AvY%AqpsQU_pipw`QOAZ?&10;CM90$4<_ zpf51+l4dN&%bKwqFRIp}Wkx$5!MP|!U5r6f+5XV z5n#)ju^cZd#G+zA>6HN@Zn+#%faL%w2T&jw6#w!*tUP~f!YR%Ou&?OBDbPPi8Nk6S zN!e?x;b8tKDO)&D2ujKxj)Dv~P|p#LqAaQ~uo$D3b|)7&SluY4MDUg`rrf{lHf9(R4^9r~gvSKp;(?O|3f`cifq`(@xH0a_+0$x!{KDZ0BIPkAE5WGx(!V5q# z{_0(kiMYoMSr#9R3itZ^%gZ#|k>;uwj>tzJwA*Pln=KU(p?eNp>-ehq%#%WlyBmAx z5xRWdPjp)AqigE!C3-Q?Q)O94+KLSt#u;YC-b@Tj>Aq{L@YFdT%jnt{Fr9Z%?ml+l z>Llklx8lY^0W#6vj6T6c`bd&g>($m#zXQ*X%WTif?99cJdb?#LTRL?t&F0UxITe5B z?-s8>TJ8#DBhdY6d0^H!3Q7FKDOu}b1!E1IVZ);oLFx4`7G!l}b^{23kvXj}mCT=P zbq+>6YQj(~#g#YLsej3LjGHA&$K$!O33_ke*=vWswjgRRnKYH;7#%rN|=o?b|EMrzChIoAgaL5Zp$su)eYJSypEualMM2X32iXSYn@VAKlIp?|6!1&~D+&#vkX* zcCR79wT?$%js3d>^)!j~@{lv(;FZUdm)pJd44+rdyB$%|%ND8(K{gWEWtat;zHP7*^>VF_dF#QWU)etoodtcsAM7K>kn^qq|GVD1#XNu zFquYw-b`$Bi=sK~ z7MekC%k^nfF^!_DdfocB^6QkxV+`^l!X5XpR_XS0SlX;K307(GLuxOcnI2=lxwE{ijU2Emw#$8g@jJh!8SNkBXJ!wNRg&n^ zCM%e3sutGOp=6W81W%-k6mTY1s-o{%Z(;75t!<*_{7$Oeg2_6vyVzv+<=ZyO zVdq^uWbSV0+W(=R&{fNN-{nJcSK2K!T8QU#-<_bt@8ms;sy@1xq{^Yk(p}e>(`EY% zGMblQ;D356v~)GjQv>JfDEn7=9?R&;ETsgc(S_~@Oo4V|lzQZrQx5D&9RCDjbrP1B`x^sHh#cUKUf_Jp1F;B{*9bqvoygNk@xticnwQ*Nq z*#>kPm94J^s{H!vSRU2`49B>b%&y%pJ7QD&%B({=te-fq>R``zmhqyUV*hskVEDWB zx3-JK;Bc%c(>Ce}#%4w{tM2R_x))9h$8>EXryJGXXnj%iy2}%F81t*n$=T7sFYmNQ6-G0L3hStue znk}13lPetLH}&fiAQ)#CO!D)eyaHS3kNpj2+82C*?pHAp@#GDPt_&o=dSb*&?(H@)^ZrIqJQ zm)Lld@0>8?kAGRrcxjEH*TApQz{viqSg>IAhFrnzf@SgVa7Di`@n&6rgONn~R zkE{fynqb@HL=gsCC%&u`;UP7TQi?LH8ydOIW$}hz_ZfMfWxIcVt_-ONKx!Q~@`o=FMh4~{*J1=sZu=A!V*P!a_#yI2J z-WNES`^7oeT=Y7dB-|z1XhNg6&}>_K_(h9=<#*1t}db7kMi^s1=VvOIu& z-?k%BvsBsp#{%9gh9jPbHaZOCnRXrQ28mjMuA*J|(^-eCLECuY zbHfYWbPqPtGG34;o4vXcnhF;&t{jQXajG3eBTo;~obwazb~PE6H-Kx{>R{76r%rt0 z?m2pCQ)_fr0fR&C=5L{b{bThnksOYt-S(YZ#x3$?5)_(q57zBv-M0U!HLoL+Tmnk0 z*h2;nA5}Gpw~inhaiUqmF@fqlfj2H+Yuli}5_(fL3`b&{&k(;z4l)e8jvV(8bv@Ls zC2{1@xJmBxa16J_Q8E5);<;R(W^mNobe(gj<<)Dx6F1Lz+=?Fb(j0Bu`eexGVe|BT z1FQPW-9&lZ(^rF)!9(+*`JzEK~tG7nxu!Pm01!Rh($3a(3F0(;H)Ec*}nJ)z=W z+UWuhe^fsV?J(lLcQB|ph9_ZKFZw#qh`#KJ1hJ$${oSpHF{5$Cq8|QbC6(fNQ8ghB z!DsF{Un-+Yl%`VZ?zpb2w2k{wXNJB&LH`EX^QFYi&uwo9#Fy}lhQMtEI7Ad4W~;Dn zb#3RrhOM@|uwx*8;(3#PMnl~809q;|39jz3?#WmgZN3gNJE}7W+p>KVtIBTaFX5R_ zS>Ksq9-~+~8<=wsQ&&oTM;utUj`e8Edb`wpF-605ZiSzUXpO)Aph{IUi2GAMz-jnv zfBJ>E@idwq5%0?@|Sx9(4aamxgt-xg$k-iluAy`W&N73JLS@;qz`$yNSkwS3% 
z7R3P50Eq`?3=#!S!vJoDpx^%v$--7ZvH&px6#Q>NvXGm_-Jay^W(NoGgZK%15`a(L zhyX|?9|tM26UiQ=9l!@1;IFeMNP7Ty7{Gbri!db)G*SkrGr+4LWER51AQ1L9#1)`K zkYHAv%vV|pzh+{EELR||e+Ujp5{r3IOkoBMuv@?oLrI)CJ>G(g0FI z0e}q(SUiCe695#5gOERf+AY?B08AtWu|^Z1b{GK6L1-QX&;eec0KuXV>bN8{3h)Mv z4G9p$K%g*aOe|>g8+-{7h=lY98VQ1{Db5%QxDK%fb+Oo-GAB@T5MbD)Is4hwVge}q zViy3XgPKvO0+d7338}&IbS=k#TG+CbC_P{Rat+v~h!4_(pZ!C%Af7171bU*#bxAp( z-oZ1BQnu6q0Y`u@2S^bXbpX}t0O`bXpyLILv*qr^~pMWD>}VmVajXMpqq znpG(N39?ewH*$3~>e2f37A5HkMx%(f_d` z*lT`O!~Dl;;KM>nMJSe)UBTfBQ|T8kD{!`7nJa-BE$G!<6rvV*SC+tOEg+$#8Po#X z_u};Z>VVQ>S(=NR-SYREzg)PgZX{P=i2U@=!hw$#yo}I4UUp@_RjgW$j zIc1+z8!-f3*CmU-m~z!b0loKwL5A0ub21dH%lUNVth8pAB8wMRV&r_K;tF`}Wc~KCSkU59 zu=t_3xTdombU^}aV9-iT`S$r3xD-HHFqb|m0>;v>RapsqzYN_45den+cNaIZi?a>f z6D>v%Ly5qh$Ygg9NpW%4pG9JBBu7D*IJmcAcgz-i4!yFKyPX5v#@hA>7-gvkP?OTV zvl~$v`iNXWSrUanVG(Eq7KuUtXR0tlf**n47hJKPdXelMV8H&s!+^7W=?B<%Sg;#9 zz<=gJYa{dnCobiI%WS~*T9t>vf@>XryJDe%>eYG3KgLB!NcQ^3Y**@KqgPQ|MJ zFi7C@U!4cpSF7{zf65*s0S+>&>S2+PQ3d_{lpg|%fxOSF^6+cagKXAS^>9ewgkPP9 zf&AaU_XkYdKkI@2ydMhm#&@rV8)k$}6^^^hpY`ux2u61w8C zDi4Lg{xJp${U^=fp2UjzAd{?}UF=De|42jM`2?_u!Id#RH#cYjrI@@LL33-?!Gc)_bkvoaa1e*k_;p?7h!#KaZ%snuZKY77G*2X@h6q zg`wa`xQFe2n7lk3VSL!r0gg~5)5vZfPH==i*~x(lM}r22a79I!1H~R{NL~TXE&!q%qTb5^P&Z3H!q9>0anQ>S ztOueW%&+HQ??P7b@BuR-!5u1<_t&8&8^|$238J7Xt)7(s5#hq*gHTApwhe?$nG$o#9VVd zDi2}Z{pWdcBK+r@a+_N?AH9QR@v~ijBHXP+szwFs*&1~Y_P976YD{<=O<+-Pb=;@@ zHm|hlsF(Nj#6Z*do6(oUrw>(VPt(SU^^+_+2&n=)I(yz|b&-52jAwTr@;67tyV;yO zcX?b>v)R7KaIDOIYWN&^l9+feAWqY2yO8JeN%?(e#gNfTJ!!+Wz2{EDb$wK!iDF_oUq?{s(3^0Cbec9cF~BJNLf7MG z=DSvNnH{R=``|f}pmDS7!=oB+ z<>HbRUA?aD+G;`E!TH&YjfzkVLCIGtCt z?~72!jwzR8KaNcIbWQb?;9JQ)8qKVtW%FFtQ9^ma(Z>w*;Yy7C%N_s;#Q9U zne;Z#Oy|oC?|JkU1BlEAFSU0mtz$kLZ>cF!{QQ(e-1#c)Go&=Zus%dW>OOqDZ{9rK5P0urY^aqWd9nH|OvJK4``5?FEC zrP74D&2}%3-&V1YFDMf0j?j9?GFo;d;Z&-R`?q`xu7+tcs!mRMD*QyZpOb&UYt5UW zVdXfo@@YEJ?%N7EkLQoR4qpk9IEw16e*ZezRL*q^r`J&AX0&!F;-RF0u|h}J3Z3vs zvyO*13GB%S)O(Xdr#yD^w^pW4Qj=?6N_QCEF5B4@F23WQ+d{M406jy)K!G9AVD_e4U4dSqU_=^NcxQ8l^ml=f#&Mi1&urDCDFI6XBCtFM3_ z)2Yk{_s)pLaSDoOJHyiS+jNIW356UnwuK>%Gmkr+WfQgI9D58139>egYaXpUw~h2( z#U|>Q8s9Eqfz+~OBL{C`(OTt8Gdmi__)0zbHzN+F#-p;W+YWoakL47ihPz{_NR9#l z!bRuuW1$u!Sz(S2jn#VUqeB-e8gdfFA05~cQKMfcDV8+VWwh~FUzcLAc!V7f@pKP4 zrGJ%NzT6JYo0X4bEj9UqrY}DCGhy8OwuJkwKwF4s*S-PMfp6QgvOBxeFUDqb4Y@zk zyQpx$GwOoyOI_(a_@wIwr=s`Uth`Im?A97wCdDUf_uyqa`!-rOx`*#&->i`$E6*pN ztY@dKo!ZPqqgq5}_eSH>%h%~?Av#rdvLz|LY*YSWo?@WiiWdvbH*PfDuAXFY?L*FcaZ1Q&0UOVW95e2+*DplaM>HgHoR@%JjIxOp&&qpEuy!{ z!Ntt>f@_o!Z^9W>?b?$}_0r_dizuxaoU;VGp?L zY?b6R?qgl&x{LRS-Wv6&Rbmo-7JiaEcPjJNIH~JRtBqB^gGFjMW0O=Va@!Z~F{P zk_|P61*Z--oUniTE~t*c;Zc|O<%4Gp|B-?5of&SoNyWw2;$Z~!)Xu@LTun!gzo6|8 zGO-dBDN~Zp=T%S?DL3zKPG$0;X+=7YtzK7v4{{tc7PGMw(@DT|G_BlPrbI2OnPvMACSeKGct3wn{JUHD}}@9GXIluajnApy5AzA?&1kqS=Z^m|}5Xmq3(!P&r0`^(!HunF_E63eN^unHPj!LnP z9dr!N;q^AhB({z{2ok*@t~=m!dw0v{9aRc@{Y5T?`Ry01+%$UD0v?5u5E_k(E-QE^ z)rlQrF2-}0KWMfHzt5iU@{EAqC0gAp71x;6(aXeg|5dV8k(2RF-fLR=fiRUeMRZsw z=IqN;WHVI!sI-$=_-MM%m%V8dJBmtYHuPIx7Run%O<2lC^g^Fw&e#*Me`gM+0BMTU zzkyxFl@SvI`=}8ElMJ#_QVfo+UBg|wvOrg+&QDr|m0@d4NS3hHhCl~58sT|*Vo3FJ|gHA0HSvWi79IL2GVG3*hm&d`!(fV}FQC2PGbc`mWN&Mt0}aXM(3H$^M)@3rxyNQyacYmZgE^MUK-#ZP6iU%EtFDQ(tM zKltCv1Q6P#wk7WWdT@mOPFr@4W1`T`vx?#OJbQYC4zHt~TxJuXgT|swK;uyVun7hv z%4!tb=V_R)&C1V((@89+cnW4bR9zx+wY1Rc4R`QjZ9`nDuhqIMV?$H!t{Ob;@=oyD z2zg~I$2bd%zoSZ>C%eKb&Ok}2$l}7`&5DOG6Q86fGlJZ_RoGKbXNUBEq-mL6-Q+OH zsU079#lO2&yF}E5FOOk+_JF0mh;^kA=dKr_^^};?xHpMrSCwOLafD8rC^r^T)tdr4 z1BUDBpH(?uzp+<7I<7$c6!(Mi=QlAo&pb6866(hJ1~Iwb>zPIdePpW-e|R{q=-_4h z#C#L$q`?Lsc5au9v5z0l#8~g#W1SiCBT=tOwXmIf9N*sdbDGuTqUE+_cnF>C7V!{+ 
z{f7riJwlK27ReScp|}QNqx9o~{Uverdi5hS?xo3*8+?5Db-L@$H-2@f*#BBU(cR%y z;v?U0Ce&8#De8Ks+T7~GwjVL-uO3C;3fql%Vt2N?>&$^*#ai3=sT;?M*6(O8lZScU z8=&ExA8x;FWJ-i=ikLu@fq<3b5*KkplG_{g>KqvI82>s!n@*&a%fKk~AVPM!Nm zC~j((0FpQUYG7097taq&JI@`$lfJU{QhQiagHF~N^4h->GE*YtajCz!Se(?DZgX*H zRW7=t@;(`^c$1&)vtmN0v)O_6$!As)mtWqg8lgA_Tj!%-xA>-a^c;Drwfm#oW zGvB*pP4X+sluOQubMH+|HNL)_92VPfXMJ|t%-E>1$K?|%BMAOo%k0m>v-y9=F5yv- zT|%zp(#QaI$&rz6mmXw=+)nDby$nUm9g9i-V3^d;>qG!QLUlqhF#M$QBsFPQx+Ch&3?BVWrvIu61&1uDxRL3) z`a*^0=Yhw;WvNk}rKpzJnxdh)A{Z5URWT(lm}cd*3dP>7k+W`nNo@#A!nOxWx5^Z0 z%?ET@eB1ZjL1_-}=bSR>^OEa4!s+^WIQhpX&HfUf{a$Ii-kcTMdbKiZhDS3i`kf+6 zI89UE`11GBs26I5IqVni``jvQc@}<=C;Ec77{4`lcxLVkF-58SulJSgSn0y1;k=X$ znE$Ceh10*o&H$8;Lf{5wn@x+d_bb@9>8^?N{Ai!b`tn6|-^}qZ!(S*l-*_n{d!pn*ucp7=d#!TY^kJ@A zeXWpqLpfx}y6wDJ^W`471x52H4X|}LF^*4;NlixIEExW@!uYZlzmwK$ z)L1E_)cprrf{bmJ;m$Jd6dL;vs#*Q~c1Sg=9=&{Oqj}l)*NV?qpXF(tD&oX(r)^oO z-CYp7RqS?GI#oCp?)6?LsQRJYWMN5wGFkNWd%x$1%DH5InC`IsFg)9895 zHk?1t8ELxCnEmPfp?g0zcT|{`o&AW(nd~e6)RbP*en=~MkL>;DImE2HubOwAy^Xim z>Cl=ql1Q+*cK+&-ZZRHrO-G8)bj|UuquysOw`S$ORej`0IUhhE@EK{Q7E^T6;l&CY4>T10m4Ie}N zc!}n`2_a*F8;xq;Or6h^AK8*5di>;5;guVz6J8Z^B@#44k4m_<%48g@y!Ucdtcju} zJm|`S3#|8(Cn(=>#ad0A1#;}UZhlwGrHZsw&JcuU1J60#9M_Y~my3R9K5_9wn2gpz z_m7=tW_s}V`+lstsfW+(pX?=esndUDxJwV>+ol%&|D#Urxi5(ffUSF#|teGxrAX3`{_xJ3Ehvmw@)& zXbr(s)(bu6Iq#lCR*txtVJG`8BmQdqw=K zlTVEU!*`N$zs&fW9lh1VG~kYGTxNaCxTq)`@$YVly1p_OgD`W+6fEgvMDO>c0osvQ zKWAQ--G1pW7NDz?U(v?VTkhROJ31DF{YoATTI2t+%)J+_eDL%oc9Ce4_1|MoY)q1w zG@-q_k?)e9A!m1xRhn-n>*YXyk?pn=vur(>M2E@u%YIei*Bw&1;;1T{>S`i2u6>>` zI=MH~J=y@K${`veeht%;n^W2Klkwy2(tX9RvlVuDxD%z9S>|#sF!mo@U|fVYMGeKQ z20Vw2&NaG%&qh;CH(Z+X7==H9xp^9x5IvqmVGn2NR9~<%dO%CNmbUv_U)^BL^WXul z{ag=3rMHSX5e^6CiEtXVMQ(bJkyEVr+FxheaI~t7t>v18!bjW4Gta_(?{ABCKN4^9 zxLSJObamENkC9;2W?9V9a6W1B3Y{Bm?<(xYpJ&-#J;HT;{HA@vm+;rcmy*TYwfBZ} zX}x=Z^(!%ZV`sL#Tx4{RF=``*=XPG&75R6>tNeS?L$;@sf5T9MZO^aDbqsA!c`b&r z;B-$`(^_T5x#!cRJD0x4-XLgw8s>8n)!Dm*yVur)j*_jGtw8@{=3*j?%QYfc#!}QgsJW=YHd(v+g=@(8Rkrc< zt|r--M}2iqU-;D;S_+A*WMvfF_47oz)3I|Ch~{kH&C6`jG7d889|RGD&(}wutl&`v zPJkdUgUK=@FXJ|%fa>}yn4o@6nL||=Y<}V?5oX(l(N|dOS5qF_AJ23I5TA$cVjO)I znEw7o^Q*HXmRbtKMOpEqDL2-}Z-{i)ExGcO%EE8DO4F14>x;gxn5M7sx#ROj)Wm}$#Z2i z4hL?5D1({<)y~Vsljh+CN6iIMj1Jn;=rI{+A`A?%i|h^}8~<^9Fo?=*NJ9pNBp~53 zSa8ZjkqF=)3W{dH==122Xt{PBM9Zti(W2H&NMh4NR=8K z(1*}~6M(VOh6CadI&cU!Fou%|a0D5K0+b;fK$r#GcLEt`-~v_w`cU8q4=^0K0L?*E zg$hT|z|ufK1!3l5Pjd!j&&HH~hpg!AoIkhx-m?V^r!(d++;t;6QNj9VnfPDw8IDk< z+Cjl55}E{-3icJckU?XB!c!%CYB|u^0EEDNz6h8Af^IH$$`mIz2e2UsBbtM|DOAuS z`_PvKM_~a+7p!UiakYfwk;wlW3HCo72ZJTTu{b;&2XHYWh6u-#@L>OtXgHC8hGUR; zaFak!I5dXd07C=;Hx!7pfgFR!K^^{6f-Ss<1q+1|G|7Czfib{jNH_r?a0EOWNE8;_ z1RNn^=!r;Xz-J0~TSy z4B!TyNT7FsI$=R)9H0yfI@9mr&_If?a6A@J4azt`8GSG?2$TZ?X7>$ZAs~XzS}-Xd z5JzXrJate`7a?eUSkM!N!V~CSz?$a82S(o!z)`3>Jz+@nwf!1uwg3?b4S^&9%An;# zQ~@gIB>|EE`o;l?hjNHr^MuTffQOcZr}GbzgkR%Btstgi=MuPv5ui~(f{26eL-%lV z1Axtede8rZq=7C7kT5JtC?pXF#sU-1NydCWFB$XstZ2?FNJh6B-{mmJ_|x>AJ!WsIIP zPA+gXJ(dosWT5Bh?L5HqK3nnxarC*;FC#!du7KG40$TtcJbFhjAUm@qpkwDIowXxi z&=Y#G*F#hUTD4c)D2{{&nrum{LPbw+Y*n3Tu&bTT?TGrg+hY3DxtShgQFWAX;n8cV|jlP}` zmeg|JPPxe?7RT=X+HWNH+>UZwd+N~UenAzcdkMFKGhZHvD zW;Esyyjoinh6p!Ny{@{pL9IvzlCELkk^9}dZSPfOU#;B2ew zlO*L=y_-=qh}r@W1a{8k{L3`)OE#xrlmq5$h1>aaCmW68j`s?3Md(xCs+nU%Uk9Y^ z`(BHzh%ux3W46|F`f0MEqVl>#bFWASx%G&(CEA9eyiw>+f*Iahj_1ffPL=Dm*L#*D z65slA_nY`he$S$PukCk5KcWfNO(qsT^CzvpJJA(_XVMFM-P=FG?XnGFSR!C?$@d^v z^xI86o48_~PVjcv&wOM0dD>#jG8|m)bOdzl-#KWgL#a|?e$&Hab>+p-#pj0&P2S#{ z@^Do(%-VJ@1XYbZ{^b0f4?5Od2Lfw~_FlJ?&p0M)X%T-oxa4e@sN(Im*e}Ajt!c9P zL;=kaR`jK9YMD>3H6%}}!~`;(-!+uXT7I1LvkP(ZVapBn;!Xv$@{O$@B~_{^2D}Me 
z%ADk$3+pq}^e^uz?=blCc>=v=zei{~s~!IsOW_(7cdg1*pOjXp^+y@!o)350%3Zqt zt&=UkO;5wkmDhQV)nbnJKMKxzmKl0atx|Ef-gk)@T^ZAAiV^R=A5o!=$7pM|Wno96 zM(Rxp>RHW04s(4a{m`21X?7D1lA|<_gnSUfRH3<^+31jLl4H9x-@Vy$kn_60o#JOS zKcVPmPu;@LzK_Ru|D?UKczW=j7jufz2R-GX5{~N{`IVJu#Yg$j znD2Qq=a53OE7zX<`@SYuGJpTC(^I&YqHrYG$FO zm{wk^dspzW*_)s$Gx9F;`0#gU(^#9p2Ay$ppM8Ofaar-+Um{eFbn0#2QYQ#YLvNRn zx^=Yj6UFcQj#|T^>6(d2fxNX(e^@0h!@*^ajlkdd>qDU)_K>p-&b)fYSeXR|d%asM z^7Pl4uyyETHa#3)0@kqJKCB_MCR%LodUK|oc5V9Sl-B-_%s4$p8niJi|p++cTUg4J$HM1C167oG=y*`R-b-~f0D02}~b?n)M!6k$+& z?ueGePk7k9ax*M@X%6=<;Xc*)`cz2osnfSb->woRCQ@TLB~ zBcZ?qA(_>VX4tIu75fTIP`Ef-&hQi)k;%8 z{8E1^x#_{D)nVfqhNFXPqmue_-UoeNxk_cT<5JRqo3_L{(odLwd0zp8!5cxL4-k^i z2wEBeual;S4A>D1=(rg@U8SVGLQSiIy;ukC861W+bp=KFG8w0$rzj6u( z6#5zz^PA_A#Y*WzyPu6-yLR>gQ0UzE*Xtr}x+D>MUg$h%6n&Fk(QPR-A)2q!er_P4 z&ai4?LjlionLN$soJBIQ)dt)O=JAO2TN?!4@0#$3e1ch6sYh5!CZCc$(J zT)#)Dy2KW3$Z%@wE2_%@<`6HjkP~V5k~+k5(~`#&_QcnAMCzp13+V0YU0L;{O8bhn z_O~nXaXZx_Vm}|xR~ZW`i;!2;MAe+^$A`G5tP$%GOc$-I5A zBSDMt(Y3~X$4)P^5zBap7zmm8n{R?MoR9h*ho7Xp8El1XUHDPTp4H03aDZkQwcWzl zyY_UT@vVN^x5DVxT!ijJ4uO>{HQ8JPhCy1H_e_PG2IJ9rML4I`j}6nyyVJT4)STK; zk$#o+(YM?Ac^@7NoD(!(FOsZQhOVmY-5m4m%MC|M-x!y&b3OryQq3}t%-0z5y(JEu znMgh?bCWUnetgUsn1E@>24CgjE1F@)ziu)%hc^uHYkP)<=wCSL?e|)XjdWDn+dyax zBZ1=Xea)xdrtBvg5l5=JOB?H$tk1vl^HnIJRBl_JrNth*7m9oXQPtAuz9Z&iw9)2A z{;b%+lpj4!Y+P7}u!w4=+}EU_lz0PEpCZ*r_NS;v!lN8X6}&7vzsvKKo)hI3hzO#-7gX%Mr7p7h2AgBf+8?2kZM{{UC|;+c z7Kg@l{nmL3@yazhR+XaM8+Dt=B2FAT;?c5&R0RUON5edpe4b*u8N(Hh4b&12ymm3| z*=l93(CZpucrTu*bi_GYkV#l3s-G(3zVEpn(Y3zcJZGdkO3?a%toTMm4*%FVo?)Yb zOU{UrR@pRV?Kt&jWQUJV&$A6LyN*7s87Vimsk+!gQNlO9@3C?~#Q^AK;L4{Gm@gs(jnZh58 z6pzHqCf;ppsoRh3i7Aw!o-Dq750M*D7UCG}|Ij6_cx8h8O4XKzOas-8g7uB*MiMvL zRx6$~I`;2=K2W>L z{_@VKg6>L>{ILQ?v+qAwUe+{@S;`;qtpAp=-_Rd?0bHx1XJDXax&W?)nD+~*1k^IucKiq6by?p1zW`n_ z0B2aRoOy@u{{!F^1S7CuC$RWg&rYy*eedeJ9N|z4`6#d5yJTZ`UPArJV1GH5EKd}07junI2cF( zP6W^>1epSHCjbc~K)@CRtwK3;AHY6%@HGbk^oPbI5do4$oUMa$`eHBu#{~<52!R#_ zB|La773zxv3!fbWv<36h%K%hD1F#KBI0A{DlK?D;hfqL(;LWyyz)Tb!y~dEBepmp~ zLC78i)`7J^34%x=1agjPbPx`j8)6{H0YhQXoH)?u4=58N5CzE(G!q0})8jI9pdDHo zG{kIo`kp}DL4x7t_UzYCvjrgVvqJ!+4(djy3Q!JhCnN^*+clp9YGLzSqK|+DC^lF& zoqdoT{2Cu>1+9rLOkgBBU+07a8Xa82==bIZAmK^y`2-2VtPDVM@Mog)8ps7C6AKbu z9w3>RPjp$ClL<&37AA-;`ZX8>l8yNU$qF8c1A90p1CSgnN+f_w)8ztKMErs@FB9`Q z#1ydaFnSJ^7bN;#&)$dH{7R5qK)VX%zW`UdGM?9UYH&J02_RKH*_Jni{+mLBBJr$bm4< zUIE=s_p^Y7psRe`?DPxt8E809u?E2O=O)vDgpEGN@1_fV4Q15AXyC7py%4I(pcBH5 z3=T-3BSIPdADaOS^1B)4KQ=?k0&4{tLb;;i0!6$4H2Y2B0@USqcGT82!JT zLfy|Ra|FQI)F%|w5kuNK>JFb@xE9CjzMn2)IH08a7F+*7-q+w~#L%wvvdOJ1I=c$G zkGwqp;;wG!rjJcGB_3TdkZFE)keKvANeL^%-7CBb$Y4eMCZ)^-qkMT{o%fVUZb>l0P*g(LJng9LDAJ>5KLZZ5WP zZ;UKS7A*~TrqMj9atMU`uUoPnUQUuQ1bFwr{-7QBU~*wEPkTqWE!oZ$%rZ9uXh|R5 z#e<>-eP$}5CWl6%aY!^8#NaS^EJ+GU6h|V(B^Rg+A1?<-7%-f87>F3p{eWYF18Uq6 z{;Ljpe*pTyDRXrwED;#CMRgb=r0oA}hgqf_772yU{%i-nHU&CqaUJlk|EvR7AaJ_4 z4g*;b==Y00BpQidJO_xsBbSVaMnd-HkG?1n$OZeos15_(8ChHh8QMSVkOUMkjEm~< zBtY5XIxL9FEvh4;z`Gxd>Hx`$=S8BSuTW0AnqUepc?n7^bB3ml}y?Fi7L0sVf{heIxHhl8A=#qDrVBzI9?0%ZOF ztiz+0jt4$zUfdUj#v_;30YP0f9tw?wtl}SYpfGswlEt5ONFrt_pMjshWPgc-rTYuM zx?ZxEB-|4DLxS_b;&#A)T_RWDU4SL>2LkU)cn@AF0&g1qxeo9?1Ms{T)uFK9kg~Xr z00qIJ->N1X95}+z!vj*5bjM7a;s`u5z}?xiGL=U5 UqS4h8c)Jyef{BW%@793*9}eR-9{>OV literal 0 HcmV?d00001 diff --git a/tools/figures/output/base_InceptionScore.pdf b/tools/figures/output/base_InceptionScore.pdf new file mode 100644 index 0000000000000000000000000000000000000000..ad8ae23a403abfc721192a887f950e58493bb3f8 GIT binary patch literal 16380 zcmd^m2{hHw_ixJ7HI%86p(`PUJKyV?=L{)xO3FNhOI$KU$q*TmA(0_u3K=s*5fUXs 
zp-3nqCCLz_^3M07B)?z1x87Upt@XdwanARQ`|Pv7`|R`C=PRJ4q#}wE!@>k|d*L~^ zVJJ8f?r3%bCM5+&=y^Co zj-gaQ7#=h`VQEf=6BZwp-O2lO$);pWICe2Z%alyEbSA-Z&?5|?Yh!9*OR|RJmr9)- z&2=rwa3e6Rk_sTj(wz)Ps5t-<6qaJerC1Gayn-CmpA&%ShNyS81k^2(kI=DnaddSy z2lIjG2jgp6TG*N@IJ$!ok>C%D#EKD7NIV(`$Dpuc7$P2p6Gx-4crX}35lj>GEahgPX-Z43F z;GF0X>%P-3-{vQsn-S1%t}$MiX*HcgCRCrjLU**1+o`pos8SG*vJBol9%7inSv6`h zlILDs7qgk^;zRe$G1?*1pvOp>8++0978tJ`6;E%rnT%?GaiVfO&Gn;pp+)Lc#Flzz z=7bRX5K~7U>-E8%nXg`ZiZ5JZB)WVzeP8fiHn6`L<2>H{jdb()ru>`db4)d?uCeDa zuzSBXv~j;gd*3E)5Kb)=E%a$XQOdCDjiJBJ1I4-!De-%=W31aM6%M|BwdvDDoc-n| zfmve8fHLo~#x2R1_4MNhU$c1f_Eh-VtK)86m_Z2(R|&oO%D6e)_gV4$Y}u)@Q-tBq z@4x=|{_I8TsZAe~zT5NOy;&_)CHbPP!qRDXsh-gAi6TS7zD=II3Ks7#MhtKw;zC?} zP*j)2;hGMqL&YCece}?a%Ts9wA5fnxsoyD96l5X&dd%|>$fik=;aAox8Z5cuuOafh z(DgvkNAHvYHWBqwxRIF~F;-xENXtN@LBhr2Dy<1L6La@#^@qGaqVDS$rfSdUAOfu8 zKy9pg;l?`;Ei($(qc_i=%QZ_n=R`1z_4{^vi|k!zZX4xQcT9R!emz}1t`x7B z_F$hpRZ2`s_{5ZtgAeXh;A6YN_s;~0o{6@y%yLq}n zTbf3%*`&P?!5ofNm+K;tU*zN--O?xVMKt&HOB5QTGVA)J?s5iW8b;jDsEM4X`x{x!>$`hkA2EkzPWyThKj zUc^R?2y+q3ZTguPg-SX@do&x(7yJkP#f}Oc36JjhWXx`7-&w7Rh!>I z--G*-*O!0FPh$W6MbK8e?h8Y;mu`oltXsE~%*DXlZc+!V%Queq3~#f%Drt||Ov~S= z%&Vc~xA)3}_HmcpSGt%c2U9wcru5VGdA^)%)Q{ZpN9fuX@5&)ryn-Pg&K zDUCm;I8BU6P@wEy6xCJcL=pYlg6Fv>M{dZBi}CI_RmD{Q;0~wsj#FOW*UeyAR3w?g zOA9ZYA9x(WT-WcwZg=PK_C$Z$aJ=vMbrU^!z&RQlhRfxN>bFFK)w4`)IuI-0<K>MitXsU344UL4ujEH1yT9&x@)`6_3pV=C~C-2 zxbAg)5yL>%{`o!vElb?_t;!$j--)1G>lzQ7;6ZN8b8J`_R?>lGiWjZ%uEb`YzBVw7 z5V+-GfNIm$vNx)aHWn?Vn=;w&QT9EgdAxP&iThjH57|^XpT6Vc8hhMre&SVUqe#BW z4WEMcnCOu=m=xog-P4s0vWW&sPf^LcE~eOLW@gDQTfy^t|(-KqGJ4<1i$I_D(Xt36k` z;jiDLPdX^2&w};CKDYna-mZCMI8FA}=V1>`GAyz*Xw&nSsLw)OFp~evu(b5ZV=n^+ z)1B>ujrMT(b4;~A$~D;D%MlV-{o-8qOl>wBPs~v!16z6|3*D#+lf@;at=4I@1@6~X z2`8?>C4{oMaF6S6T1@OsiV`@LMqB-wMYSU=e#jv!Kc+j2DwNrGf2Xpa<|(}l?;YDW z#8|wu*LWbPZ+X+|T4YtB{+qV3C)Xx@1BK2AXFo2!;2VCC6ACEr&4SS_kTkZRbcAJIlUcj;Qu zAP)YngTY=;BK?-Nss;IASy>k{p;W;mz*#-pdA!heujI_wH(6#Wr!%#5oHJ|C25r`> zxPrlA{((3hNs<7H@%;#UmE5}b*4-l>_sl}ZzP}K_sR}5Rs$OuLuT_w|D^gJ0J~~&L zIj)=^cE;}px9mXuOt;`ULH|uV*BTp}BThAEzVHp-B$nNF6lak+3?!oR=A>nJ~2% zcMD!4H;Oo2NEnObRJZUf39% zXv`LJ%7w3CL54ONy{)%(Fodt0XWV4hVCJpHl&Xp!7p@pNUe}{>C&yTaS9tHT+84`x z=_o)$TiOM{$H|W1EcKea%%iWt{ImBECtTNtRupl5jkew8r{Crsk^k!T-bZ zV2Y7IYLPT%7#9~B7k-TZc11B=B!4V}VDtv7xFc6K92RWIP!-t#cd#&H&g4U9IBidr z^xMm{j{OF^+8Sg)7VRoBFev=r#Y8$%BYX(Ss5nau5NkU-Q#?4Omxc<7pLZN-k9(&q zC*O6R9dqbKf%^PpN#9dX=K0Ti&5b%rst=8cYd04)3UL*VxhnT=={x=6bUxjocgot5 zu1}?+cnO=HX+&{O=g|bdY4pqA>mtMXS=?gMApbM_jLgeDv|bJJn#0O0r~9f;_diS4 zKtI_qeu0+Q={?~VGo|&&xxIDITL*2YGQ@PQM+?hgy5q5R2A498;v4KFW8=$6BhwpA zn1!|zn>!Lz+EvW$REl>u(QfyuapF0xx7%}0ZDGL%bNPqa={|;+@yYwe#v`4wYojVR zIUBJjGWO(@cZyoO8s*QEEB)S@IFW<}ZT9rVDptHP3-ufKsTw^c>WvrYf*nf z-`+h5CqBDQGF9~Ee=p+G707TB2N9vJ)1u-LSXI*u8pSJ@0vtjKzKhmT4R;!)ZO1ypwtw5ooLv#F zTg2~BaOYAf`AopO7`Mv8yBWM6oo4waob4;d8^hU}PkBiYGLI(Tp(77XIC}y9T_;x*cN`z<4T!pvPD9nzH!pWc6TN=+l=Yq zHy17&s~;3Va!SSlmo;7#~op4w2@Sk=GZQC>qy z@3+g!&mY9x3Ok5+XdW^&c;RG_%p-02OS z`o&Xt;%A1pE+Y(SfoGd_I4uTt9FimCvnoHiTAtLAZgO>EeIB~M{*EbJrih#Iy-fTo zn?oo2=30*KPJUW(f0ASsWL$`X-QxN#J927v+uUs*-JDN0DNP0b*yql3Y7e9G7>^u5 zUP|A6_tA-Hdx=}hUjzEQY6|5;`RKxE5$k!Bd!NK;g>`bK?<;hxq)B;Bnzzk8bMdvl zfy(EnbDspaRB+_m3{kZ{P#x`He_>PvM_?>{sK)uctg zszyDFZ7Q2+E`zxsGi6e*1d)$AuSYQrspM|vsC*Q>F@Dcoxmz_d< zRFcbd>5luYl4mX9SJ`7Oxe0O`vxR5nJrR`Id#C4kr7V>#ql(R+Y{0BxbfQuJuz%yU zDws9^{v11a$Ic#@z*lwZwzEW3HRY#7Y?`!fR)lLN`Ox72qlK7HMG4Z2& z^nz4^j0I=gByQz$iFWwi{+h9nq2(*>MVnG3yfur#n2LUO^G8#os~+BZ<*x{XzHXm7 z+#JTFEqeFUL^%yxq%k~w5;MBsJ^pE&l>3E~RCzdBA~Z9-=h%(NZ8ANg4eak?c& zuWw6IYY1f)n;tTEW+;?+U>_L6F86_PsFQYPZhG%rOi}TtckA?$Rk^KIdr;HWw0rN| 
z-5IE7vIck7FpSaIe^AZJ5t5KIRl?fOnCwV~qJU4o^9 z=`MV+aOc+=ferU0=1N9iwayvYZ9PyoT0Lc7#^l3n-?g8^D!tT+=}{Pcz`+l^v)!Iy zS%$Y8o#~Z*r&Lo8i`{9>70H#Cf`3ymD5XHpFO@;z=lqdT2r=tIM|m{+>kRqcsgLK z?xPp;5m{1`JCg*w&pzg(+R_mJyo5D@urJhWw_UerhHHJ*)Ah0XGDh&gYbP%;+)0@w zeZ`flcCr*pFy-0%X4dX4Q&YG=;1dhDXk9d;DO@NKGhjG-^-Y+ls;k5Bs|yPw_&cND z))&bm*D!oBYZ$($f7-SW5}DVdHou>OomnTG#Q(rcF39Z5nGhlR3pRqEIQAVN=Z147 zPN+5ax)*#;>I<_fPxq5haiVp5-B-SWo2!d$i(3Z9Kf!^Co!^;9?RJa`ZyLj;5yRYp zhf(#D_J=T+9`qKaX3@?n9XWZ_eMr1|nxkJ0p>#EXWajiatvhF3-1D=K^#a1xiFxA- zK8L(+jnI8?Kz6Jl`)gR#0Dt|vS)!~Z&$@w+zH%OxbT;zs*Q5_*&*wiDdTL6a`U(2) zmn*F6Wqw=hHc0lGzKs2B`Z18k?`e(0TeQ6E`9voE7=4bfm(OfX5}ng04-^#I7G_A8 zyE2GysWT)8`0-1ckq+f(!glxTe@*tiAAZv^jrEF)!nWq7D3u%UXLZjW%W{a(hAA=& z1Pk53jO69k5B{JXF0DRZ-jgFO>*ydZvc@#muvswxV*cx~8yBfYQbN%y0n1^lP1Anf zmZ>y@Es67vQ}Bl{dnavuamRkqLU zewSF0JU(%>>ae~k=LumSx!UL9L1{YO;Td`7!|(UvySj@__jQ}Q@MrBRDGa$FxQip_ z;e?s3sdIPTZGp6}#dp;mDr&RsB7;PY1PlYa%G%9R)Eb`H1A$lXKJlaUz|mB<@ju0V~N%NzyCN_GcsVz$VS0nJ z$scAT3aGBXF23?bd1ggEu=ok*`035sr=DYtpU*o=HZJu0i?@azpq*+7NPk_>^*m(K zNLBh%S$5o1YQd(sJ-d4q(GN8d<{ujGd+}WA_s|iA^Bl#MoOH;JYL4v6NJNY22m9P? ztkrAS&wBKzp__@?eY$PeMEqlT8+%^yT$t%9ew&K${&5(N-v1|?084&NX3=d#@dY6i zOkFG&Q_2Qvit-2dDp{T|HE`86CAo+yINDo4({*+sE83Vk!;vt=e$%Bi8ixZ#z|W#& z>0<6|>qK^RhNG5TNxH6PWQu1B8VCbJ958hN{+<8WKj=hZ(Xk_nLK2X0Q7kx{qlg6X z4@JNM42&|4rYYIkmg4;pLn5KA7ySK9K|UWqp%!pT0XGlD>jOh5lYrL-a#{Uyl`Ut% z5K6XIR+hlw1bLW@;Lx(VIGLIQ=NGW@oFGq>y`>d!Xi<{Rw$?VlPX;z9H8>y$p#c{M zX5Jh8vwo^29QnjS;07!u0auRR@_#JJ{Q0#!9EAn! zTrsERchv}vMJxP5C}l(@IW6Ci8v@9E)Jky0Pyhx% z<6zKiAgZAlq5=m7hXNjSOaQWnBZB%M#uETrFpGHv&>3`0h5{b61u)eG{ynSu%I!e42K3%goWd=fNGEiepyOBL;zF{2v}S;Vr=2~Xi4BndzJhiXAg#V!T#3?o3jfB+E(l|yB?r4GR2K&_Ym zLDE1G1V|WG1r$;o2l@g7FGkbS}ja0aChf^pcNW4-Q^V=&}P)!C)mRDRVdw4NA%qj;1J87*NLOrJczZj;35E zfK)QjbCi0H;C)|AIf1JLOR1kb0OhO#t{<$}2q?xq7}Oca&SG_-W0wY9+(f{_Q@RFs zEfxpeoqUrPL3R*cJ%rD)AX9_qiX8!#k(^yP-U84G8gODzpS4bb2C0H>De7~ z`A*wd9Fs$j?_}OZ*;-s*+QjxTUIn_s-DHyG5r#NJDZ2#8=CtM!p93#WNpH)`=*T5F zcR!Psc+{bF?8v--i%rRQp=XFn)Uh1_>_qxMtq!0${+O)zKb(*?A6GKfz!|CAl7yw! z4=l)NNACpi0TW9~VKP}L(d0agbkdNqM3Osiu0!{V_gh}pNG+F_N`@GL9{=mdzdl0N zT`_Dd%_khX>3nxfFL%fqgspLYL1V$j_cwoUd%Tj}w}^ab#2X1kN)N$b z>F5MEr6)%}ly!I*4U@2HrxszY7_eZ|6n|8yrW48;jef_Q;kMH|SE?~h;;n^dOD=z0 z_tS$f;^w%W%8vI~9Ef>9=53x!C~5H{a@?LB493%GhV{H1n`N`zgV3quIg;q(${O=> z+sHQ7SnD&K{T2&f=zg3(vU3d%u5~m5I`;1z)X^Z_mxJK^qt~8HTy6Ex(SKPr?`Wr} zlf9=Z7}bFEei%{lM#Gr(WI$8dv71Ix8K=dJj>LHcRfdEK$dvZRj`Niolf??fdG<{* zpcD5fWj(&po-(I!Ie;$Wz(fi|tvB(5i@7@E#t&ON=cghs(ZwSN6Oxphio(Rz6i;nus@sySX-n{k_r zv=>p`^1fvD7z&qv`VR7=Ah=+-Iw=^8Z?k}Igfvf4(&KirrDE?os6Dr(=Tpg zFbwu!{Y?C(Iycf~&le~`>Y5CGvjcM<&Dvt3K{QE;9o#qY;;<{rO`eMK7P9Y-m@cRN zCGUM2XAb@#zc}*PwaS@3RqldbCVpn6PK6NWu2_DizE<)6FbhvvrLd zwtsi7+KkP#+gW0`^Xe@#rLYV3F4A|XyY_#)Pwc8=tF`}_*p*U*!R+BX*K<4Y#Ctin z;=5B$rOC3Gx2X zq~4ex8Fl%JA;Tt={|*u9xl4R5gzsEFsoj-_JtLdC7p{!#)YYET$r3YDFb?8sO<|dk zPO)P>TzKabQ6w$Sz5tC=xT-qV$_C|MYB&QnG-)J5v z9`14@D2h%FWO;x3YGV*<640(OW9WW7K;}wzoZEP$g6Atu$(_x-ufe!k zzsf$+nfTr`JI9m1>G8Lt32Sh0jWZ*#H~#ukD2Lr+se#i|FX+i{gu!|a3Pzp(yb!h- zecEJ%dEB3dq0~cV2ThFNF%CmIb@N`Wi*k2fIg2*+(rj`PI@8mpmvH&%DcehG`og`} zSv?k*tpy{pJlf3@XBo`r`K8vNa^+$0=N<%3Eq~K~0J#L{4L~jrr0~!2Z;)xVLJQ+( z9nGKH>(o3og!@&po$Kg17yS0z`BH(GG#~S*tBxA!^sN2G<8|qIJA0?%T$Y zKiS}2_A=zL@mMi?MyGLJ+bqG+_aSx2Ib*!#i=xxa@tJC~m&B{VzD*B#=pVn+?frt6 zOuc=(>Bsd4y}8eP_KxYB+8)+;w6QF)(n@Y)uPzaSaCX2XKCdy?HK2Ux^!C@OhJq~! 
zNw4IUVM6UA_!(?LwSCo0n*~VaLNBF{@@$L~CSmFnX_FX50*z1lPG3ec9NniCo;a3f z+IjaKb=XXX&eX?E(Me;uuLIvxtykD?^(Se-O#5SZ!~lEeue9NS3WYqPt-|GU=@SQA zrf%E_xeHV%+Z~SmNRvTf#Nj6zcRK`Lq}L4@?U)rPROq|-A--AX{_K`w_EypSeXT5I zqA+T0Hfh5+1jo*f*6FtrSx)N7zK7_`*tx3w)Ykc(n4{xO$H28lC9(YT2|KWYLYsP9{uEkwZSBi{%?_H=c z9as}7C9@CJbao6M?2t+$IKrDQ&|J?hQSe}Wl~sJF{#%n;-t)qthZu^pcR*kZ2p z0ls%BLg@O6HSL+6*_tf?Dy(rbsV51ni}CtH`E$?)JyX<`1GMxKkj zh8+AUN;T;F^>4RdZW%AIGV-}>TXWIfKVff|=mSF<9j=$+A1=(Mc!(C!2HlCfd;!K| z5WK}l{?oO6Vcws&=^4V?KX9u#g$8R~I_u`!qsmD15^>YsF^$=cVtd=erQ9p;D-d~w zc>gwex__>%@Y;{(p@eFQJ)=fP4x2xb;rbBWAc7wCG@K?+bv^K7zp#KF*7u8mC z-XKq3*}8Yy-zu};Ja+H-p@nDk59w%`LgdItUSA7MhKn4kdUGkqrmi1@I@eEg-e>PK z2g4ybJ-C{=7B0o@ljk(=^OILLHb-?8Fk0nq`W7nO`}Y1T6sL9BGs}+6W5)T?aq>;M zN9zUHw(jpV<+o;*jYEr-xJVP=BPxcmrV%8A?HJZ@Y=A0Xz>TZdTd3t(LvN~t;hov% z(-9ZRf%;)ts4*8&hvTi9;&yFghPjhN(Y(ed#e}vZa=E8x@Ef*j+vZNnsn&caZJKc@ zit2aQ7-`w^bkOT@(`2ol$^EO(NOJhj*ZoyNgY%*JBO_nFZ8&aX6tsqbuVIIy!4dZF z0PKAVVPvV=C)@iY*e7V#js}+86C0`|zatNQ_u>*}VGy z0Oka^KNhS67QYBl0Z0l)gup|9IiUgMg8`Qsps);KVL*Lwtpea6AO_e41Y!M`09!7D z>Ox@=hFU~AaR5xAU}Si3T@)`4p=tnY0%#>3U@ABWc!UB_?`R?p1_A)@01ycQlt9!8 zK>7#}j06FhPz;p=kOdDuA0dGDP@hC`fLn<#=0P!KG7xGCCIk@zO$rKlaAg5%ivts1 z>;u#V<5JQ9!a)OQ3ko;_krERD2#1F-H-MKd)`6fy6a_2B5TSNh0Et038U&4jxj+E| zBq6+SiD?wb3mO|@Acz4&VbGX3(B?Pb5F!u-$qzIV1Qb(TAQV^|ni|x_VspxxK+Qpb zf!o5Asr~F~F#!aAu?v8ULCq*s0m`A}gv4NZxt3!53G_ta>ymIly@O{MrEIAKBAy6e4v-)$$^aw>zXuAhfm}c`u_9390g{R3K#`Rt znSkVBWq|0SJcB+U*;o#ctl*J2u!c)A0Lj6sKm;f-MJ~W@2)-ct8J1;YIfj@5)*VKP zq4bJCS?k4esLsy-$py5kQ2Z0#q$uKLU8e-6z<~fuR0as~;vNB9P)n%dqDBQEI%V&m zXj=e3FFrsD7)!wxA@vFvL&=5|H9)Z_X^LK60yY=bHZYnfC6FekSXE$Lih>7MFQ9@I z8Vo!y)2IS4^QG=pAn{sNbtRA{k6P(!`}sW?LQx(xLYSMvQP3C&dGvp*2KJ0!)iD3D z8d3~wGoT_A%gC%Cd<7`=Yc?xTjbE8759*rGD@d6aLKDdR68km5#20!0s}-f%GG`Wd zs^!n<5I;9}6-Q?WU^*vD&7AM=wM9|zs4$B5d_M!?Mco2R z(y@-39jReH=0_C2|7j*sIP*Vu_A7*I>C4OiL2JMA6t*b%i(i`}G;A#(4?57bkPe}I zeCY~2$COun=_?o9N+qEPKA?my3+x8QtCPJW+1}O+?uHQ~ilIf|He|Asiv$AU@UuwF z(b-xUh5(Lz3s-aSJ?6?*P8L>hGgEUrFv?O7peChzTSt--^d;+VB?&YVjYFc*z{iEb zV~Kl_;zCHIknjqX;qGi{1p`|p4hB53OFv+W;DF+`g8$5e{PWNcPFl)CVTk}nT9t<( zLi+LddYCop0Y@Ydrr+yfp>J3J$OHb)-}Ar|a7?bwgPfSF^3X`I6|KrcqLFLVgOnQd zTQN5@O2DEMEk8Cl7Imw!0J5kdDH4VJOOgg{oYm_3%2i7d1xdgKEKyP zqrl~m)p?jT+CpaJ>U!dTtQ!)822S=>^)Qfwcy%5T+zeQi2T1;-9u^DEL96Q#Am<|V z`^je{4hdcJSe1uaqaI`yudYV`p0rhYcqA}GR_9^=><>6xSGPq07OZXyq!ii#q2EuM zQQ!jbAN4ShdG>ofq&Rqo{>T#trqHT-L?mS4K);`3Ai=x($9%zSw7MQRWB$Px;8%y7 z!oT$guHyV5f56O!wzJ>sfz#^haZxB7~@X6 literal 0 HcmV?d00001 diff --git a/tools/figures/output/base_Precision.pdf b/tools/figures/output/base_Precision.pdf new file mode 100644 index 0000000000000000000000000000000000000000..dc165628470de81ee4d4f74ce669a92ad1493c0b GIT binary patch literal 16251 zcmd^m2{hHw_itvdp-hzwT?xr>=Sh)y2$@Av=ApQBT@oo0A(S~nAt8jwP?8WSg+h^( zD9VsTl*&8bi<10)_1=1Kt+&?yTE{ux^F6~p`|R)D`+WBID{i2sA&Zj3!o>4?;CVM- zC^!=CZhHW>eLEasa>T;{j!?Cs+PJzq!4U>FP7a=MG-zN1S5$;KknN#{#3dax+{sip zhSmUKu5Ww5!Hx-2e$H3%Tk-t_`iiYJ^jPo^@N><+F_ zP?x1lfo?ByQP+X&M0JLv7G~CPaRuv!BQ#tAGt?aH-0dA8@;#{(2OBq-PfCHgjys>` zn(CIqAu%E6!=|^-5ZCCaMcWp|sotX_owl>rwapx#{xmhByZfP{V=Ai!^M~yBR|`d~ z&M5d@|1$X5_tWRsJ>SO3#8~~OXBd+7XKtY$9j*O9P<=~(n-El&-J*KYTD6&b`T|mTJ!9{rjEv5qs9g%BF1zP7!_3D=$~y zw)eQmqr#AqjZJ4UyEke|T75@-b4-i&Y0n~#EPQ;C zuYkSd6rc3T3$;P3KfbChBkomU7um4y&7CgGT_j4suKw%3)gQ;>J$QCa(dzryy9{hE zuT`1bEMn87x{VOyFfrQ_p9g&kB0ZdVL!28y-!ZZJcg7 zhCZ{Oa9pAKc(wavN=>LZ`bt9U#E!1gdafJ7>jv5c!t$LoxwK=4!>>R3N=Z?UEHkoi z3=gy(=GYqK@-fAB+$CO_xMn7a?$EWQqZhAi+2WDPfC`dX&-l5cF!mKc59&E<`E0&6 ztIr8FHQ|i2yrvsnQq{~w4a29dCalM?(dnPvk^`^DJsq&Jkc^IswZ5`e`^~jStV~y9 
zt;5%<#~dKL6z{SXmnNB5TG^)ZoJhSGz~GY@xM^#xuSR`*>A-7OZ_(QL@Sp%mtpUbs zq$3T3UR+iA>Z57vMm4aVdmA}eM+=8K^V$%r&ksM>G)cSb<#_vvuRX5VW>TScZnfa2 zyh?&=@A{X%S)A3I!gDvj-2Qk+K<9Jzk`!K(4>CCda&J{0>PYH;TLQl*WcOtW~X@65g!#HH#w7132f>8o!=-G*g1u$Qo$xZLp7l_ z`y%ZqonP)%jx?e)+T*YBHVzaN_HVw%tEom;R&p}hJM+?dJ)6DjNE;fGFzx)^zEPdE z>#f;EO@+6bD#)JEu%vgDnD|kObrPfuKgem+4x5`pHZ+N*oC@~qtPS0hB=sROVW`v1 zS?P8!o1GHF1JcO~UJW5xw|boiL|B$Xy@}A7(?XG%XGLB#97Ij_O0Z<}U!YG-&!q?Yl%?cz_ELA_q*3~j3l25$l&Pbdv(A{biYT6lRCebyxLCsFc)Hv zvP?wIjbrWTmruQu_P`v?3^_T({H4Q*FSI+&R2UANG3;EG!WUL}YiO5NlTw!H9`UeI z`VKkStIccGPu({n-M0@_?7VWMETq`(jMDuvuJGgS{;xF+3#{Zeixcl}2zwpqqc1hu z;2u-9|4xogr9DaMNnpS9?C|WwflSTkbwN=U8>wH1@16T*WT6-pm0O-(^a0W=;`2v&FPS#rHz*{uri~#3(dBZv=>^PFTZ@XUt)z5b$C(YzCgG+(yh9Wo-dz(|zb{gc z+*bau*Z9EQhvECroGJ^N9(i+hS{p$(VIzHfQuNZ(>Ao3NBAzOIoOM09?a9%7Rg82U z#;vao3o@xY$*^zofXJp5Y$q@YEm~SGC?UCZ5n@ zoPI~Wq4nXmtOj9wgQn@(R?>;+i`V*b@b~TPt|szX*PXQNsru#Roye3jl_SBFow<~; zVi)OcQzPGYaBcTEUdtjd_2(kcOO5`by~X4HPMks%iAPnGL)C9rpI89l)%G2++R9@u zH74{Qo39>Ff1-QCRb9J1*HR?Y|J{o!zc!cZk}Z_`BRUja?ePwCOwe?2q)9*Q+pP2;bJ8_wK!Vc8;oa&ld;zQPW+ns){kPKe{InmA%KubX@g6?$zRc zHp*^a1&j1LC@vmctE_1qg^AW?DXbxr-{|$G@0;V(o&2_nVV&Zo6;=y*63cc1kNNvz zF(pVNqydu27USv3`6j1)^`mxyChK@;j{juvfYR2ob;!-oqZ~1rf8-Ij1SCt=(X%g zLa+4Jlmnl=CO9j5^70*1HtsyF7**vlIJoi18fw4_n*e#!%QgXxA^d$4aE3Z$CKT)Q zLB8sm*cnOBwA~6`ZN0}WRZUzrZ?}#R{X%5y^T`z}I4eH$7}v^lJLsG1w|qX;P_fU- z&YRF{-ue`dcN&-_O>N#UQJr`=;T};^*INC0qd98FIP6@sJE5C=UHP(eh)Al)!(0RP z^ZKTysU+cdN}n%!pc@4J3d`*bCSrTBV=>9ArGvsx-}N-utK-9@wtinn>y)YK3S5H@9m~0fL;am=3Wg*yBZ_1AJM5S>-Eism zG`16bC3EhoE-?jKT4?pQO1wnRIFD+#+CY6`M8*yJ52svSiCmemq3hzBVq*((RB86$ zRG{Y$mXeOGDE+uW@i1okz0Azz5La&%&WuxeVIyy;T4ven9X@btCk0;&8hD_6Q`|+c zkYQWisHMG_RlPCyuE!Cr15SD{ggWr)MZ#i0{?i@K+?sdUFrP$Oe?L(UnC$G!p z$+vI96RdVxSmnfgOVR64E$#L6$M^R9_|EQr-g4^-JcLea%Xo;v{(~A)kI^H)KysX% zR%Ak0D}9T$PnA?}Rrj59yPY1p*2hOkXQ26P`)7x`1Kq-kZVt~<8jpQ3_3YA~_2hSY zP|#4?^DROBS!4Y5NPWb8yVC>x;fF#MAJ`_%7WoscUQu0Uj_|!UM8o^ul)Zj8^4>FE zAFD2b*tA=X2PLZA@@~}OEzi2Qju(}!uCwOnxhutfE6nqoJ@b}O(b2mFJg-yFM%QPK zdAwoSdFC*l^qGCgbC5kVB%s+y!2Z?7-AaT)9`(oPE7IDttj~|r7ocC(-?D)#UK8RN zR!r`5-hHTd=F#3Q7oJq!ogh1gS{0*U*9E`t7(6_^e&)t8gO8&f_xyet;?iY^J#Snt}SJ)ru;Js{r2>)QOF?f_2(x3=Q$D?sM z0ru9B87397SE)KTtWrg8ofb#^eYz5k0IazSm##Y3Zst(7$IONtw#t0x1(H?U?+C1^ zP7DKzRdrf1I4VGS#xreKmLsZi4v+pKJCdzJ z#vyOk73CacjIQ$-9=#V@lNr}{8}&4?seHV-9OjA4R?IjTO5IDLM{y2lhIlDf0 zj!!c;{*@wI6jjr}&g2rjwF0XqtNv3#RzNIaj?D91nt^VH0w^i2G< zl8^6KnOx8ka?=l+yA@j) zPxE!nmUH8HFKwdJ9wZ}X z{44Y6^KT+$wY=Qk_Jz+4;%~kCMt@BinYMy#tl;>f{>g0Xlep+nYlc6;j<1qR6KnKS z3bp-m{In!%xURX_`!UyDQ}!fC;-s!paa|ABP_1`A)o(HZ``PA02vg9L8n+>|ve&6pPO*4X z{;vtg*QLqMm{MOA6}uE)mbdd_ml52_ejzwWY@024cb*<>%S+R*7Xt4_U3JLhIp?Xe zzPTw@ztc(MIL?1G+UcG>;4ITALi;|Xi9XyABR1XtJ4s@*SMN`>B=Sj zFYlT^B(#K%@*Lo)7MIy9;Y2tRQYgl4+!MS0HAY^s?(;~qZJXcS8jhz|wkW){jSYVk zb?nyGcsJi9(|Zjv`@c8jZg!stRqd3+_(ciI*sRhi>UmXXFWHi7o9)YUb?TaZ@>o=N zMQXZ)oA$o2eyvy4*ua~+pWE%;Rx9@D19RLu3}0E{rHk8N5wnFXvck4y)PBK`Lv7E} z7dS@rW^_xS_HetUt7*~i=C*jBT6tk?d~cP7sg1w^sbfmD&!R#zjk=;P7o3W^+k@}y zDzVY*vhx(n*<4zDI$UBif8PCZTNfKjSKSTq%&#TYJKZX4a}UOZ%36w>hjf;=*`{kZ zJU$pUmCZ55f6z1!)A+9W;p4yuMwT1J=-8PhcKtYB>vZ(YG@>)_*oGChXaxru^$&uG z!514K1M2uxffFDiz_5brjak8ML}AeXAmeA0xl~2L=BJzxW3_Gj^bBkDY}S3-y}6es zi7gSkm_I!V&U#(c`Rw$BrIy0S^4z3P8AYp;wr=TBMc>y$*p1%1>n9xj@`#ZvTzD_8 z^pIO_Y;#QK;HZTNim{h1l@^eJCc?lFyKLM*SmQsA4+c@0 zk7vlDkOU-L77NbiC=vnuLqWj|7;PRs8!E+x7W9xqA|V3;{(hFAfCr$^0K{TIjDr^P zfFab$AY=hWYJNqC7OP+gH5W%m2M}k00xOnqXj?r!Z0tbH2Y7iNPyoi&!4br9Xhn*P zlQRepfd%WpiQv8g9D}AG9XJ5$<3h zZ~;1jaEd1!K?REfkrc#k7kjERAb38M^ea+DW8|Vo{A;(Cu$sn_zwp)7#>o?`Y@U7p 
zmBVm^vZoysb0VR^VxVcD3t2RH1VB*DL(7520I>b$i$xI80@!%$bmT+hctpNt?Iu->&Y#_$~v<`ImPYJg49u_PVO3);W2?xdilOf>*aK<6v z(QpC=3vL3A5HYkwq)|nHVIX2aA^^d|1E7RN!hv=KI1vZ>VgO-K4jvZ*z&r6oKs*3Q zp)rBbVL$-|Zh#5UI2g1Vh-xT@sK9~Ap@atm6M*dDNT5B)@dUsY%zPaI3i!*l<8N*kDL9mL%AsWGv?M0x>TbPxw2Cd`-N7Ixl z3}|8Wf>|QN(Xq7~qqXw@&-+5@r_!e#Pe7o3iS7sXf!3S?2M->t1T^cyr1N$I zY$&ZByj8GVn}ajqFWb9_CIE5xYd}Kb>>X~%k@#R%EFKV`q|jtfnYlf8P$_n=y|DFk z?zX_QhNm!fPOqDSo)lv}-1#F8vKER>vg*;t*3{og_TgZs%eINOlY42JV459&H90cv z=`B0u4wpn6r(1X6M8TOIwYc8Q@eLy)DlElYsAPvd>`CU^52kEu&TJk$w(I#}h4lrO z+w%#OuBQs}584g(?U_CK$hq{pb{tCZ@sJwm5fAZ&#r3mOZ|-QU7O3I^Mv)I>RkNx_Km%@blCA{u$J zjvroLWk{AX&^=a$+1$zI7_n0WDw5@&@eAdd0-osl; z#wZNH`;3gjnzAm$-QVGMKMp4E*v2TsQ~AoCSC9DMrnXUpKpgtL$Yt+M{`uSQWy%lP z>pjXBOX_-}|2%0%$fJCJxBag8Myg2jOiJmaAQJzL>HaW0i(X{+(8x5e%T|QZP2oMM z$GmvrU#uTo&y(nMT;Qes+!vM~r}k`GfrBfZhJa@MI|q$)$aj?>1b^?v$K&T)ju@G~ zsG4;@sA`nEwJHqNfb_qAw(^aR70;pIrt*DPEw^7jDrdPT=}73!(~;teWj%>wqGeW8 zxniQQ<^(%Bb*oy=!z*p+Gb#zeEN6F(r?c1klYaCg0`5F5;;iUXKr0t*ek-NYKsFRe z=27Oh8I0!7xnyv`qV}cX*zh!(>41Ag7Q3C$BwHzyikntF{d=WVY9nzb1!tq2HuK)* zf8k^+WIfn+jqa*|iCTi+NMmU3qnwB;wR%N;y{}smbY;yN$i@Quzr{tgAEh#F&Bac{ zO|+Vpw6dFr9pU**`ldBA*y$=7B2Vs|2z#>;a~I9?$XbVFou1gQ`Rcia7xz`+%8EzS zz>V>p9=fH&$L>w(|DZnK^U$k`!kVGv(HjKcsf1U&COwik|4c z5prNy$-AWblSkQw9hjjEw-xquh2zsdJr#O}+TjdHj+=x0UzJ4n#a3mjCaDh2J-yEr zOk!j!jEPM+XU@JFb#kMO!c4TNC*ix`v5c;p7QQh}U32FKHyInw80E;>s#t{zwxn~9 zE2JOfu_(UvktCCugFv{7)xySlxuAt%R6Vh>=Y}5?1`YX+^ zO5S#Q6I5j{6%M?PkTcuR!Gr7(0lg{{xw4w*kNgiaV!^&4@85wm71t z@YC*g&s>dao|wafZt|XJ?>-SWbmCN*_zR{F1&mI1-`?TLIGzDUuK=@4{myYl{*`xK z;;A2<@t)G`1-!zJT)uIc0vFfF^Y@Ns3{WnQJ^G^RF|~i{w(Sej`LMvI`@*ab-y8RQ z!EejBaii(SrN*B8r^lp6Ol@2&bRMvkr`~i_V(T#`L8#3}nEdAz*1GhpyRaj`gZlf9 zL;w?nWL5LKY|VH~;xjNowY|4F*NBrVBwr}(6=qA4B4g@QnbSCALaYu2P9`AP_i7qM zrH*9UbX31*jGVe`^y$OuxU`Y{*CE4n^eP)1S0HT#MEzf2K@7gs00hLbNR}9l5T%6Z zFqLr8by}yVTK3%zTx*BM2-dp^-Nb7zngvDc>>UQzxS#<-J8>hvKReH~iCN~@1onBkQ z*CJb}*}`2e3u83oRWMIN@Na5wnH-YO@z{AGa5rl?pI}vx_Nt%*Gb|!m7`VZpQo~JK zv=PIJ&Cfht4zY$&B*Kngs!Dq)S#T+RQo$nW!OK{ktX5&YT|;zt@88wFXr=w-V$!*t zYB7n!{>3VjAvH1E6*W;!0VDV@w+tqUL6I!+=6XK)qDJ~E$K(#vA?sR^Q&JJT*-LUn zTqi7f<1Yyp^b2$$dmkW;ZPWXTnA+dkI-D4L(~m2x4IMR%wW{+NOL%ygwaxj5;M)7I z?D%5>G_#9&I9-S}p*3myhYt*`n1~Q z-O*DkY{UxgAO?aR{^p3_j21(>{_p_u^ABBct>|yJIdi-C7!Fa5;M z^-F1dHxFUputRV?TT>p-s8NVkPL-Kx$A=_zVL8r;@t#pu?ZBmhLro`k)MaI}H-0HA zE_`!O_>71-zgW6j4f<~V(1wIZV?~aZ#}Zs>&iI^6k?xdjG-omrd_f!ypH4p_dyP5t zR#HMZOxP@J?J?z#7d0dOKd(12hqsLiX?sM38AJzo2X2x&_sH zlmo?M&XMlkpiaJ=X)V6^<5>jZw*1z2mU}Gh9xDot#x=;G-}#zPQa^Pz2C-v5WPBU! z;NZbJM8-6*6m*k9GLj62h_40^iLQe!O1&^yx<7(SQ{qf8a3k$5Vp- zmnT%#P@BS_%zCp$dh(=W_Dj1&=~Dx9Pg(D?Fms+(qV9QpG2#MTW_Q(_=sf4Tml)KE zmrSRQNk4TnA5b!ZYug#%(!D?WPKrD`bdIe#wzG)CF@N>92&tZ-yL~7Dr}C!`?Q2G? z3Kf!+oAUS8i}SA2?XVGZ;@XjfmMisCAixJT%oAp;Xp0prWh_Q4muT<`^1u$K_lFUGjE1)u(MahN|mCTHCXXw-og-Du$Pn zSHG~k9++4rIv57G-69~P{4iIYcdc8CE@H3Xcu=Dud z5^NA0xM@ex-(gtz0u1}d(97{baMYH?0@naV0B#Hl4Gz2jU4@X||Bk@ImLRY|n*j>` zw-8t;s^Z~5adEeY1JD8C>p%gZsXH0q#?-^0M0KV(fU+Za6(0O|@djm2fCd8)4>1o; z;z1`h05SuV3Ibvw6byo1f5TY;cm^@%7bZ#L)FRyQYbTcQatY4*2jc)&{sY9iA`Aat zK&%*W++MPrMJMn71Bev_9I#*~u=sg+48UP95(HQRybBHBCJcC+0!oYU83wcm?{k2U z7(fp834+i5OM)%lgW5uA9s-+3S#bcsq2Y6Q0Mp}%5TXZgE`XHc0gMBEDFGp>1fc2B zBpeJR02~4k6aq(q)DwX85g-T)0#TtHx(~n|JosFL0Qy5?l868^BhJ@BIc+f*fY^cs zL4-hyf)XCQD*^SzfrZbH0osCjX=MN(p#iK0B^-f7%SiwL#6x%=!0P7PK(Hl>hErom zP(Lhy;~)$Vg5|(kpag-T5Z1WBG#cCn%?&XSVYGV-1Ca0}_+o+tVO|CxIru%%cn#zN zl8GgWCJ&HIEGC+)EXV{T4@(n77wsC10m;TzmL(EEplNae3?qI? 
zT9k>!9AXOCcNi^)%1aV$ujlVWZGI+5E}&h7@}JNuO(!qvIyE3y^F~4)PDAVfj0$33 z3)tqoM#V1Rn>1|;(0rOrp($XXt7+<;3{e01YDiH7CefNKYViexm8PZ_FxGi34%{hP zLrAgH{4HQ9X!;(!0s`GwqR_zgBFP#6zF(L@0|?gq=)amS#io=|OGAD?Kj%WIDueM5 zb~bPnGzUT%{U4hF6Y{GW=07$=$^okb8bZ0E;u6A_penyeTmlsQs$5jr5PF~oJ(2Ug z(F2db0{`{E!sogFs~4@=B4_3e*Wy=fh@WO(!=2&=oQj3eDiJOV)Fb!<)_f6!2Jg}T z)hTm3tITy0&Y?c7ppF>V)=_sDetfVZN%!^la)v`nx-YP;)!UC5{)idhl~psdnN4R` z$$;;ZvyX4+My!9^acxWEMMK%nM_$C7mR1it>gi&>%k7;*KAGp=;d^6ud@ZsN)z~4# zRd~{Gt*XRcDHQXmm+Xk=b#vTl``T+ZW<(yd+oSsZPX~&|ng6+~Um{!!A3pvM+WV#F zvw6Xv|JW0u<6;j*@`37wbO`MWNG}k?rakHlUy$IIDhXBa6(jW0!WQ7xdbqk%U0rP9 z-WWNO99jnMOr?5w$|DeNKX1voQ=Ft=2=Ib|y_X&MC~|2p4|_+rt&QD5Fw4RSpe1d1 z7k9E6^dadMHF-1=jYFc*Al!z*V@c9Tq9hV2DYZmp_)r`iVZd_YU?4cW@B^L+4(M@5 z_|H1%RRQP+Coj~YutaECzt>@i5F3`)teu((iq-NXS$Etqxp(80qpl z$e&nNheiU6ysQq1M*cA#8VWZ4-VTLD0t>dh4tfD=c^%|){9cD7KwK+?GcdxUL{yoM_OTj zNzkcac{?KT)_<=9&gvg>fI^}{*mHS1Ebv^H*Wqw~jz{`qk5FL2%ViC`5%GsiqOhc; z>!ng`TwEO}w2$PCU3?vYet{#5+}$C?NOQup$&SDagA^mp@$#hFP^dIjgu;PG2qrGB IuCD?6UnOZvDF6Tf literal 0 HcmV?d00001 diff --git a/tools/figures/output/base_Recall.pdf b/tools/figures/output/base_Recall.pdf new file mode 100644 index 0000000000000000000000000000000000000000..7164c9edda8441e259a9885ec379e3435f9e99ef GIT binary patch literal 16433 zcmd^m2{hHw_ituhLm3)m=*pA~cb-g{hYXo1lz9l3+)IWi8KO|;u|%1Q5G9dBNFxRMggiR=h9BrfWp=|!f( zF|-B<%boW7oE)fd!rV^{f2x)V)sE@}$Iew4+EJ-aK4dr!`UyjrxY{|okzL^U`KLZ! z4kk`ixD}XIT@#SvjiFFMlv268yuWu(DVj0SCvRkg_BUh6su_fEQqaU=#|RG)JZind}9A zp`b2{nF2jt;G&)r*@fx~N6pWy>E;2}4M%8t0A{E=Ie0lbLF7}YK2CO?F#oI)OMNeX z!rJQ}N}r=h;lu2$!T~XvUCPIVvNqn%c`fX{t~c~WrC;awA4(U^+j5l5_i=4RSFWkR zc-86H)b9R4IaCuE2SMqJ^K?;GW5X?u<#BYDh-?EBj zXUF;vn7WI?j}(8B*s~*#ry70=PrVo^ET>ynmsE^26E5^^^%<+bOn<7%?SfZ`O;{q= zz{Y9@yOULX6Ly!jghfrZq6#`Ls)ps23=VR0O)Gtmo$aXg?p)vGyi-(I?DnOqILv-* z)!xXq%)u8f=%Mzz1t@L8A@&>TE<^p4&o$KnnA$qNi$~+zl_ejgQM~gUat1XviQd0v zc<`hp{2=EVrCVW`Cu3*A<%Blv8M@hOmBL%?lV9OKt)_5LxHkICx45H`8v>^F@Iz^Z z`JLHKA`=Q!Z;@T%d-o!Gw2b!o@4S1LqP&)&+~vgP51Zd_e&7D1s{3V&*VNw4rgg(W?Xs57G$)WNG4lIqFOm6IvjG1STOERqvxM}4gve`x1z=OI) zRuTrY9zJ~L$Mbgv655Qesac=39n!<4Pb$B{996Xu-Q<5n z{SMySFO>8x?*&8j!7N6#vuavmf~%Qi6K%QMruTdtb&C%daBWIg*So^FS1WhR{dykJ zE#GiaC-fKxHMrR#TE*ttC0YONc#_(A=xx6o$~->ih{O@R?!j_3P*#3FgO- ztA?A)+PNf_wOvtW;AzMc55nDB?Q63i=Z0$Y^bD#*H_UJjTz^eW7c@BG+ZM827%f&9 z>FaDg%lozu%UV)B-JNc+*JVF_aJc8^7R|CTHuJcvTYpv(x5TcN&2%Qp7!*0n_r4O= zh~B%Vcld@&f*(cle)8DikM`me8ETizJ5lZeN&g1#b2~-V>5bw~DZG&S=CnIKQ{e$Q zgZHfDiMB0TpNSN5Ku(^rgGQ>G;}#^WRh+gD@z!Vf!ZZgAN1>Ij46*O=Y__rga4s-kX& z6b{5a_(k`P*Fq{Ubsv4W$xgzy6Zfz_;W;<5PBBq>qnLoeaif7%S?rfBXPo&7+ct0Q zRk+Tx(o5C!UpwRpcaeLGDZXdFc3cjLvumjaLl2t8lT&dynOaPM~sAji) zjQ6bcCFk;zJU&mH#i7@?ZywD(&tu7HJeBzVc%T4bLJV& zWABfJk5wdfNITohqI?zU485Nm+4WBAbTRhOXAg^R!^BK?{nwcnjQPwbv>$txwdNhmhp%lFPb=i;#_dE7RdNkYsu2fRTR3f6_*G6mLJxSD zNni~xA7MW9$f)dbOtr`*nOx?oeGKB^4CnXobm#EyXBmA$=RLla<+3E5EeiCk&@St>c(E;|O`?_R=?c58SRj>c)-dJAG2FyXO&SS=y4hA0;BDO7!M!mC zueBgcJa0bcZ8FW)K=ucPBN7)3HO)vtl~!<@g*r53xRw|yA{H2@zZfapU0v>aB;%^i z>n6k_+Kf}Lsn@hW+>p^EC^T!Co$MqXi9d7cB@X_ki_ODSF7L97wj*_CRaFl%>xybX zgwKv5pZDc%Qu32y-?wurcpqwD;hS8(2=ro0K5s|-((he zLjIIM|Lk5iWV(QB{R`$YD~(qhl}@qBk5|av6Uh-fl$xz5&0cQ3W$l$$b-BEU#El9| zHtBV@9wxmpjDP$v@p{_Lq<+b)jFyO?s3|9_ql%4JKYcy#f3dfAw+UDJ3s3ACL+tsP zgO*gzwpFr?*2#UQ=0e8hd1tomys^LkB-1{aBA0R)M)9_2WuTguk=Krk>(^Y#va^GHk;zb$9LRvET1p;3m>2V&WHq+}GN7oc$qtMI*f;zC(shvF?6srX24_=J#H; zH!KDtLT2fP9L%4uz%R(bShAOR%-`9CsX!Vd4UtUt7z%}nBBC3?t0rrL6iHTbcpx>1fYcAdJrU-jpjJCdILe_dKu83S>b3-x4w~DE!~)P)yK`8$zo>7Ii#h#%tt$`Wp>pm7b%#7>g%mJ7&f!2Ob7;&U}$_uysSEGOZz`-X6OIa-5Xc)=+K2%5)BCP4vP2`&>VTycpJFHZSqQv)w=s0d( 
z`h|B=>FGA+pk@{QVGXW>12+yn>&@0hKVW!&jG5FO{LwG@vtgUh{d?O+J&nAp5EI1# z9bCIu_NF$P#}`;-G`q{EW>k?!CRW>WN^ByvcV%hb*K}~#tk~MhyeX*OTkxRi*1&0< z*;!Xi%1`@)18h$-vh`%&CwLb%B-X6)vEn_&-gn_@w~ULgRp|`1CT!H!n=C2rx@{m; zt@;)B!^}7JsqbHkJtp_t%zPJkArnp*klL2D?~Cs^M|J;&V&|;QJB};G)q0PNYz|mU zJ-osuKu-3OO+aG^f8PY0kuI4D#rkA~x2`^MT7q)gLecmBz#%I&Q#Vls+Zf@mM8-k? zB9W4lV$%w1-*9Izw~2|p?UhG=w9)<)B`EEi4uCY8kbuv zQQJSl(&D`cedNn3`L2;dr-ben8ERziG&4Iz5`3-v<%~DFnLnts%E53vaRB>1A!ChH z*sNyM}4di)jYp@=uhel_KF`9dGOr-VQHcvQc? z9mmstAV`i-V4Z!Pg*w#T5gc{pWOa*)^9k2PzV=R^$mg@C2Hpk*|6J8MX=b(p*Pye= zQm*0tMdKKe$c!lVw?AMZwsdbRf1F-@gtubWOU*5zL|YrJalaZb-v5zXtxtXER%%S{ zHTn-n-JT1bAGf1>%sIJwb(phiyElg-Jy(RJRN~djPn(nuV5Z(kPv=K^_^EQ_9=#Af zHcZtvFWBhxflDVn;!N03n@)|G8-FQ-{DpT`jv_Xv{E7NM* z3Os}ka7%cI!Ty69(n!!JKSi>KPbo1WY?Z&qJD!rz=+p?D^}LajxZdAiKzFG9WY-s` z#(jN)N}f)SvRXpEno%C>%uv?3w3Reh_J2>&c+`@7Id&)Fj>GYxm&f);DYe5g;7XSS*xxE;47ytEADHg2-kSZex@ny39A#6Ef?ekSv3=yg z)W+#+Ato0-x#qN%g<~K2OzYf9EMech8yTOswXSVnl84-7jc*YHLG|S-F~Tgd%m{ix zjs6G8hOymzd0ORuHB32=$TM!mhfcgOGuQm`aQc(@`s#HT`e!FTt9WG}qDvrzyhb zP7CI!)eo4B*zHw#Gx;*qx^D8Vt!BT7T_u{dzIg?T79$Arb{>cPyQ*55BBiM0HYT1n zx(elA#8_-vb=84GY=xOuI;NBUVQP(&x(N&@R<$Xmh`7Tl)0EQ&dCsWnSv>ly%vgac z8HcQCEGu+pjBoUQ`|fsB{n@0!8>rsY)~b)~RWJ&&Kq)sZifZjckK!28EMB{=rY(AP z#aE?~nGS6K2uTuJJB-`<+- zbZ+dLuFKhgS-}ZJqyFKQ#Tiy}FaWiu7_pw!_DzM?X|^NS$2VfPN;HKkWaes*>_Io_ z3n6t5L}!iPD$6AcXWL;iUT-sv(u}eZpX!vmT+A=i6?W}g{%pQssA3RvPOapILN*&Z z)~U50?5r{af=#au5?+S-ed_Y_98+}c*JDcb9(Fv-_V7XS>)GJo3eh8%p@%H>)-B9nQQ-?0yMGfz%WNKGeS zs`&I~m1(xNfQxn?YT^d7)b+Y8k*2mQaAzeq6#EZGLF1%6WE9kbvX5-AtpC=hbdT{k z@8g*&E*#IfEp$3V6{(`)S6=2(gj3-@FLWcDZ^}(qz8<_cZRNhfpz-yM&mL7AA)FpP zdh49?u6T2_#j-~1d?z&ZI54)*^4jf;G+~#93$vSECttYlvHtOSU6Q0kP#FJ~0xvY5 zxv;(M@P#CqF88w3IDt}Er1@G?j=R@C)_&jgtkJyw_%Nn;`t{W}-FY_YZQVctV~NIPhu*P9536YWw5e=c}x8=ce{b ziMg_KO7y(N3`V2h@ObpVAA$v=47G5_?FXR!H4e()2(mLcvQ)qMbL@~+UovT zCg1l~?L+!hGbJl{$x(@AQ{ckf&FWjTL-d)IT{4J4u}MsI_lD!w#*IJT1}^ zq#R}c_0Vw%)?=>XpVny^P>bXEPkq#B@AohJae5%u`D$L6oTfLk-;05(31uL6=8Z-gzRsR#@K(ZX`HT{#`d~7^-kH2(#AA zz)l}d82xto9X0UL&)L5ED-XlO!}XNQ8~Zs&8~k2UgC?_Z&s5_#QpQ zA(Cvi?pw;C4X0(M&8W}I%H7KI zypQXQ=P~@HvC1WLK8QD{H z-#$mD`GI@%WC8o+I(M@Rn3mV=cOQhd8Cz`@p<`nfH~4v|!R6qIDMZhOkWDLW(F$%Q z>K_CVgD*Ek9&Y4S1MY_qAHxbxKR6XFbBIttb^R3s(Kw;PsU{3IKkJAHtNs1YkFYk6 zX1wHY&prz$-it9{{@f9f_oA%l(eZIBZN*PjMd_b&%hsfC+uE;&zN3$Dcz3%gNHG3c zfUyi*&>B~{-?J#OJ)tN66k66SI^^Q*2GeFeZfk2xKU_YDdjo9w9= z%|!(N+yzJLEoT#8$?wI?g>Psv8HB1G#c8gjVXmX5vQtXkX`h|BuZbO*BBSc%;Rr3) zheB0zwex`^VF*3D`7#=Z1CKz|L*0qu;N#{^_40wE=0iFrzV=jFgb11l149_td4f30 ze;gkSqB<8~kwGB|NVp6ZpaLio0sM!8qAM`kJoQ6ex7`J7l#`1w*L2IXgRnzz-DZv4TU}O7XUH06{3=<#|I9ArB{K z5LluWecW7JLA(kqSQkzNEsfz26lMx1fpfDRoJ4{n96+Q8SbirM3JJ`Z3-|>ZyMYWy zq6b(87?cc0c!BZ256~Bcc_?rM6|4X~OH~=UT zF+@0?ga`YFM8k;$G#rD(gNFoK!l5y=1{fj~{{x{vkYn&TsKb9su*J`?V4+ZgCRs>0 zFb0?m2`7L<76Ffj6EIlt5O9Qup(P@XDgq1x5d#td$R8f)BoYY++7aMHa1O&IFcz42UNRQ)1<6>*=R|8x zG9an=BZ0YTzfouyBq0llCLG`y0VqvK5LQZH!vW!7gCWUSlwb>zv5?OR#GGJ2<;4ji zZs9p301F8c4j@16aCtomaqBfJMdt zym+*ZK0tQnNtftlOGPGF_30DqZ{5uB zXJ?};u#0z)eP)(oR*-x#BldLfbqAGhw^SU5XJ6=e$%*X^xPh}DH;oCYvXpP7lAU(5 zrCTbvXUVspZ666Scyd5-V@ZBjF~R3?ucBOAm*Jk>GvOVsl|LkU5jChinJatBi0vwZtu4?WMi7;?JK1Oi;F&W>-zP29KapR& z4Qqg3P&0;i1g?pqj^*RydE^AMFx%}24srDwC=f5D<}0#@lV%!V;S_$SF>`H+=<{3R z-eBS+Pj3s#_Fpgk?l#^^B0+HoK4@$l-I|x3bZ5Kgog|o?^L<8X?&{}`Jo?188Xe;p zz9jS;p?tqB!Nm%<&&rKD>UR{2q(6SR^GW)&fOpm2K1YM(7OGJDbXH|Y7-`+Lsh81s z7X8@1(XlBWw`~aH8o}MCLVUTCpKcu4$erqPi0_%>>{phbM|W>ofrBd@kARN-I|q$* z$xX^oFv0rFgO8c_0*uX`*3Njjs~Hz6RRtkR6VtC z(Z5k%r9PHqT5>YZMU>~px~DGo0=6UfFVS7(GgVIs8f%Fv>L`qWt%c;H8D5s)^hbsHQr{}^uwoiEw;dYOQq$O%**!bvBx|piq{!?^gc9^ 
zdK!5eJU*Q2$+?8VY!g1xcP(lFH9l);19`tT(c`M zq4n`>+Q=3YqiN$pS$kERDE@mnTpty4+_`s^U;jjsKATRdW7!;Wl@XIe=Mb!et}Yx- z7UnTv_Z8sWeXilEyM5z}-S-t@`^ht!&Q9Xt948!A`!^nrg1@D|v{5<R@=s z(aN#gwzr7i==c$-q@LB(T=QG!+n-kq_xKUiWX7Kt27mhIYMyEvabI`R(tmG+Qd&{E z-}?mBz(IZaE$u?Eb7*-PsmnpP-jGATZErLF_@i}dTCjA@-S5^}D{ydyb0hFK{`yjA z#NOnphqE%ynyRdZ!TNTJCm#JW8@m>L&~}9LeK-@_l>p7nOv&PV)>*RbaOgKYp?rPN zN2ax(X^pqUp}tPjtdxfb+~Rf2B>T^C2h4K1h@UJBxbJXkip^n0L}3Lg7as_J*`R;G z-~f6F5FCJB8svyfi!dnNb4E+zr@R~55Rcm$m}1C!43omnflZs1++kWc>mj<0Io zlY1v`*gqv@Mu)cE5oEpl#-#r%UOxBQwbq~KTKbE7L!`#c?A&(gwymx@RpYF@y5EEZ z!8n^?a$ik4o_@TLB~BQd}PAz9Uf@@*MI#J>O&)G%;^W33qZs>D-8Yr)m&l4MMy z8uMv(=}4RXp%W=cHft@zxKm?i?YisUFvd>i8-M<=Ch7E8@r%f}bo8p5oR^aZ+_YuZ z@n3L448GD3#OAR`mITd6<&^kn)nmdNv{&6~e|2Q_@Znv|%A0bDEEmO(N#uCV0qqTig6~OQRnGgkv*Yvm^T+Fe_U5_1P7i7OQWCN2fo@%w*ps}* zA*;<(V&$p>C*EbW8#hg@ui(8WQ>t~3t4apOXvCvvnT}YurR(0rs9d4+o7J2`Zl;+bt1ApK|x6i)p~dgAMKgTBt9aPC)R$8SFWsu zzScRT+icXfLFlMtj0Iaop^(S86;JXx!IGDJkC6jyNE7>u){mvC6Cpq!h-jaVO5e|itGY0qaK5@&ss!>J5ub_ZAS zg21@w`p(l8d3Mb$Lc1qFuQ7dn@aPH~v4T5@fpCbwIU+dYh3IlH{4n{+hsSX3`0qD3 ziXQVa?57$h$?rDxYdacYdU=fcwKBPnn=o|1DdN`Z)(hP4j3c!RYt4nbKcu5et8gxi zw~g}}hRzM`Z#}ZTF|UBFpLX-IUVtzRcMwB*m@%gnxLJ zRF6!HC6h7#Q{ua0Q#k=Lmzbljr>7i)37SW*4^jDaMk_Y>%SKa6`2BYRI^HqShVh5} zLi@DYNkP(nMw=%vTTwitef%2zDxqQtX{4rW)QM-)o#khKK8hjSklXg!YWFUO2TJ_! zlA5K_uLCV7sGoaU!q~7Ma=(vsvvXseViTHKO8Q8Vx#>pc{#9y;9CuMI!k;-$H}b9C z@lApE#tAVV!KB;II5~0-?G-v3lZ79kAI{zGxsHYNMMdMD>d>wy1p-f5El!83uLa8Pm}asa{XZA!Spzj)NgC$L~|%gwUh?X{*~4d&=0I zi`RUQk?bFB8btBARP{P_tsS!|RZLfDEw;WT#qp6g&#;f}ee8H@MfoykL-phlX-7NW)|GM;M$1ID-hmwSB z_RC6aKokp1OyU_f7`YXXD{I&PAg`IETuOZAuRGGQ{^84@yRG95rnXI)y<}y4_lsw> zQ7>m=N=HV%erMQgYZbMEfUo3fgVXch1>A_l32e>S)ViO_qA>{now-}@{=9KF=DxYm zP3y?YB;oXN9SeZ`g_~=VMkIbWhjTQUa3WtB-BScM}^&VOS?*!p+KkB zd%w^~ZG%u}SDwk%vVKOTV^!oePaQ6Yrd|;qiH6&6<&##qTcp9W-t(TsIoyps$N2|R z-#xu;lGmKFF_f9koG8#c+WRqCai{ow>N<4S1zg+4)jS$o6u!pgck{g6jeUUT?rP>L z-p5cY{{v~TY8B7^HhRbG9Z3~Kx4g+mxa&cvde~EiWGuFw+t3s7()?&8-qfF zqc4D4A?WwNL$a_%kSqv`0t)`OAXzBP;_c+)=H&y$p^rvUSt3yQxAX=)z!xd zl%2t4eel1VA1G4*JPhDG#2ieC2c6Ub>J0EI2$_ZOFzCep2jU7a35YSjPv$gEEdUR{ zcVZDQ7a^|87zeoWAHdcXS@{0~Y{dX_FdgK-ek)08nTW4h9l{5CI$t!K6Uy2|)S?5SRtQsZb6*2T%_ld{9CF z{h={QM1Ycl z69@YI0bfD{q9FN!W`f{qS}=wNu0u%4G4ql4ct+Ozoq zNO%%_AwhyLCj*ci{F!LH266$(#G*u#2S_Fs5=~a-Wdf3i#R;N|_8W`=$;Lv0WCf4J zfjyj;0Z0y(Boe@)X>tJ!BYsg@kcovHVhY%I7%hj&ixO?G=bl4tekDjQpk0OXUmz<@ zCokwabvO-_1Td-wz^UiV1qg;M0IIaW4}kJ%W`w3~fqI_%0V!ZC4dH|V^$Hk6tA-Rc zKn-Z+c`Z&m>VOGoYI+`Hozvn#8fgt7#ZL3LfTf`6d)(aoixe7YxInTd0P^Q&(}bjr zHpcI!i+v6i)Z!T6uMf@;YAT=;!od!Xf|Mpg1^pkJ0Tc4O8RkDWL&^fH0~$iPlF}kg zya*}#P2wV`<#**`^gthaBjhuer^A<6c(2pjfqJ~0 zH@~ZA)ciIk9B zq%!<{oSb36a^hhiOg#SwJQEzy?_`QU?~kbY2j22R9{_wF8$Z{;We`(U3`i{(i3) zi(A$gg@q#Of3`z{c=6Klz-<%A*!?jc5|0O=|0Q)q=tGqyb!a4H1OM!c28&x(hg&`# z8V_0DKl@^kPz-WO9fSidt;2&*;gUKm61dSz>(I;jjD;fIOWQ%-$^fz9aat&opp({gVxbwYkA#>);Gnt2QiA!jfdCDA0NfD`x5eb=6gi1;w zr3@jWA{yR(F20HTd+Yx`|Nry+pXcS-&N=(+VePfn-h1u!S&oRFikcKk8VeJ--U~0N zf}!9@xR?Dwn5--uVHD`?1V<>_k?lOZT;K>jI~ONkI2sf%fXmCnoJfvPL75d5)VxS! 
[base85-encoded binary data]

literal 0
HcmV?d00001

diff --git a/tools/figures/output/large++_FID.pdf b/tools/figures/output/large++_FID.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..ec04806eea95f160c1040cc11d1646d5e934770a
GIT binary patch
literal 17440
[base85-encoded binary data]

literal 0
HcmV?d00001

diff --git a/tools/figures/output/large++_Precision.pdf b/tools/figures/output/large++_Precision.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..f5698ad5b2e615d020b0198d0e1748b2af0ce2ba
GIT binary patch
literal 17321
[base85-encoded binary data]

literal 0
HcmV?d00001

diff --git a/tools/figures/output/large++_Recall.pdf b/tools/figures/output/large++_Recall.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..2671f078b1cca9102541c05f5769ed1d1ebbaea0
GIT binary patch
literal 17490
[base85-encoded binary data]

literal 0
HcmV?d00001

diff --git a/tools/figures/output/logsnr.pdf b/tools/figures/output/logsnr.pdf
new file mode 100644
index 0000000000000000000000000000000000000000..3d46a632019150614bac59a91684729aef95ea84
GIT binary patch
literal 15346
[base85-encoded binary data]

literal 0
HcmV?d00001

diff --git a/tools/figures/output/mean_sim.png b/tools/figures/output/mean_sim.png
new file mode 100644
index 0000000000000000000000000000000000000000..f4b79672aa3fbc8b03963a35a2c41499ec48a177
GIT binary patch
literal 20988
[base85-encoded binary data]
z=gOxv;h#lb0s;aI>FOgMo1@loznjp%?f~Qlc&9jNzg-K!3@vtKio4<>4}PHmCXd|O za`T_#GjC-*-XK(adA+dzpbFYEF2#^XX0E7e=5@z12EEIjo9EPm!fDfmGU#{{Q!(;w z^UbX=OodM_F$z_2w>$MIitjGN1a+Pu-GMJhQJbf0;Z4q#l&%)$>(%%ZD1-OXY(hiw z2JhQ=d=|$g1%x6oISu*a#|^XDADxSxVcuq&A%`+=q)>qk@Mfal17QV8Zpa@m?x@;& zoR>xM%t_0znRssMz*xkM`qNUaHQpq)--N`i(t|WIlID>?Nc6Yl!xq3Zy{SL+5wt3& z??xatyW%1>X}g@KqCGIR-gdSv(j%x)ERP5(78ZQG;hpWeeey^A6n9mUxJ0|V49<`O zJn{w)fslFONz`$v?@8At$>MeILMAQqaZb-se||Q~__#PD&s{zkBzNs{Q2#eTQ6wa3 zzA{-^PL30;=48eMY