decoupled dit configs
This commit is contained in:
108
configs/repa_improved_ddt_xlen22de6_512.yaml
Normal file
108
configs/repa_improved_ddt_xlen22de6_512.yaml
Normal file
@@ -0,0 +1,108 @@
|
||||
# lightning.pytorch==2.4.0
|
||||
seed_everything: true
|
||||
tags:
|
||||
exp: &exp res512_fromscratch_repa_flatten_condit22_dit6_fixt_xl
|
||||
torch_hub_dir: /mnt/bn/wangshuai6/torch_hub
|
||||
huggingface_cache_dir: null
|
||||
trainer:
|
||||
default_root_dir: /mnt/bn/wangshuai6/universal_flow_workdirs
|
||||
accelerator: auto
|
||||
strategy: auto
|
||||
devices: auto
|
||||
num_nodes: 1
|
||||
precision: bf16-mixed
|
||||
logger:
|
||||
class_path: lightning.pytorch.loggers.WandbLogger
|
||||
init_args:
|
||||
project: universal_flow
|
||||
name: *exp
|
||||
num_sanity_val_steps: 0
|
||||
max_steps: 4000000
|
||||
val_check_interval: 4000000
|
||||
check_val_every_n_epoch: null
|
||||
log_every_n_steps: 50
|
||||
deterministic: null
|
||||
inference_mode: true
|
||||
use_distributed_sampler: false
|
||||
callbacks:
|
||||
- class_path: src.callbacks.model_checkpoint.CheckpointHook
|
||||
init_args:
|
||||
every_n_train_steps: 10000
|
||||
save_top_k: -1
|
||||
save_last: true
|
||||
- class_path: src.callbacks.save_images.SaveImagesHook
|
||||
init_args:
|
||||
save_dir: val
|
||||
plugins:
|
||||
- src.plugins.bd_env.BDEnvironment
|
||||
model:
|
||||
vae:
|
||||
class_path: src.models.vae.LatentVAE
|
||||
init_args:
|
||||
precompute: true
|
||||
weight_path: /mnt/bn/wangshuai6/models/sd-vae-ft-ema/
|
||||
denoiser:
|
||||
class_path: src.models.denoiser.decoupled_improved_dit.DDT
|
||||
init_args:
|
||||
in_channels: 4
|
||||
patch_size: 2
|
||||
num_groups: 16
|
||||
hidden_size: &hidden_dim 1152
|
||||
num_blocks: 28
|
||||
num_encoder_blocks: 22
|
||||
num_classes: 1000
|
||||
conditioner:
|
||||
class_path: src.models.conditioner.LabelConditioner
|
||||
init_args:
|
||||
null_class: 1000
|
||||
diffusion_trainer:
|
||||
class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
|
||||
init_args:
|
||||
lognorm_t: true
|
||||
encoder_weight_path: dinov2_vitb14
|
||||
align_layer: 8
|
||||
proj_denoiser_dim: *hidden_dim
|
||||
proj_hidden_dim: *hidden_dim
|
||||
proj_encoder_dim: 768
|
||||
scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
|
||||
diffusion_sampler:
|
||||
class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
|
||||
init_args:
|
||||
num_steps: 250
|
||||
guidance: 3.0
|
||||
state_refresh_rate: 1
|
||||
guidance_interval_min: 0.3
|
||||
guidance_interval_max: 1.0
|
||||
timeshift: 1.0
|
||||
last_step: 0.04
|
||||
scheduler: *scheduler
|
||||
w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
|
||||
guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
|
||||
step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
|
||||
ema_tracker:
|
||||
class_path: src.callbacks.simple_ema.SimpleEMA
|
||||
init_args:
|
||||
decay: 0.9999
|
||||
optimizer:
|
||||
class_path: torch.optim.AdamW
|
||||
init_args:
|
||||
lr: 1e-4
|
||||
betas:
|
||||
- 0.9
|
||||
- 0.95
|
||||
weight_decay: 0.0
|
||||
data:
|
||||
train_dataset: imagenet512
|
||||
train_root: /mnt/bn/wangshuai6/data/ImageNet/train
|
||||
train_image_size: 512
|
||||
train_batch_size: 16
|
||||
eval_max_num_instances: 50000
|
||||
pred_batch_size: 32
|
||||
pred_num_workers: 4
|
||||
pred_seeds: null
|
||||
pred_selected_classes: null
|
||||
num_classes: 1000
|
||||
latent_shape:
|
||||
- 4
|
||||
- 64
|
||||
- 64
|
||||
Reference in New Issue
Block a user