feat(data): add support for Hugging Face datasets
This commit is contained in:
116
configs/ddt_butterflies_b2_256.yaml
Normal file
116
configs/ddt_butterflies_b2_256.yaml
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
seed_everything: true
|
||||||
|
tags:
|
||||||
|
exp: &exp ddt_butterflies_b2_256
|
||||||
|
torch_hub_dir: null
|
||||||
|
huggingface_cache_dir: null
|
||||||
|
trainer:
|
||||||
|
default_root_dir: workdirs
|
||||||
|
accelerator: auto
|
||||||
|
strategy: auto
|
||||||
|
devices: auto
|
||||||
|
num_nodes: 1
|
||||||
|
precision: bf16-mixed
|
||||||
|
logger:
|
||||||
|
class_path: lightning.pytorch.loggers.WandbLogger
|
||||||
|
init_args:
|
||||||
|
project: ddt_butterflies
|
||||||
|
name: *exp
|
||||||
|
save_dir: workdirs
|
||||||
|
num_sanity_val_steps: 0
|
||||||
|
max_steps: 200000
|
||||||
|
val_check_interval: 2000
|
||||||
|
check_val_every_n_epoch: null
|
||||||
|
log_every_n_steps: 50
|
||||||
|
deterministic: null
|
||||||
|
inference_mode: true
|
||||||
|
use_distributed_sampler: false
|
||||||
|
callbacks:
|
||||||
|
- class_path: src.callbacks.model_checkpoint.CheckpointHook
|
||||||
|
init_args:
|
||||||
|
every_n_train_steps: 10000
|
||||||
|
save_top_k: -1
|
||||||
|
save_last: true
|
||||||
|
- class_path: src.callbacks.save_images.SaveImagesHook
|
||||||
|
init_args:
|
||||||
|
save_dir: val
|
||||||
|
max_save_num: 64
|
||||||
|
model:
|
||||||
|
vae:
|
||||||
|
class_path: src.models.vae.LatentVAE
|
||||||
|
init_args:
|
||||||
|
precompute: false
|
||||||
|
weight_path: stabilityai/sd-vae-ft-ema
|
||||||
|
denoiser:
|
||||||
|
class_path: src.models.denoiser.decoupled_improved_dit.DDT
|
||||||
|
init_args:
|
||||||
|
in_channels: 4
|
||||||
|
patch_size: 2
|
||||||
|
num_groups: 12
|
||||||
|
hidden_size: &hidden_dim 768
|
||||||
|
num_blocks: 12
|
||||||
|
num_encoder_blocks: 8
|
||||||
|
num_classes: 1
|
||||||
|
conditioner:
|
||||||
|
class_path: src.models.conditioner.LabelConditioner
|
||||||
|
init_args:
|
||||||
|
null_class: 1
|
||||||
|
diffusion_trainer:
|
||||||
|
class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
|
||||||
|
init_args:
|
||||||
|
lognorm_t: true
|
||||||
|
encoder_weight_path: dinov2_vitb14
|
||||||
|
align_layer: 4
|
||||||
|
proj_denoiser_dim: *hidden_dim
|
||||||
|
proj_hidden_dim: *hidden_dim
|
||||||
|
proj_encoder_dim: 768
|
||||||
|
scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
|
||||||
|
diffusion_sampler:
|
||||||
|
class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
|
||||||
|
init_args:
|
||||||
|
num_steps: 250
|
||||||
|
guidance: 1.0
|
||||||
|
timeshift: 1.0
|
||||||
|
state_refresh_rate: 1
|
||||||
|
guidance_interval_min: 0.3
|
||||||
|
guidance_interval_max: 1.0
|
||||||
|
scheduler: *scheduler
|
||||||
|
w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
|
||||||
|
guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
|
||||||
|
last_step: 0.04
|
||||||
|
step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
|
||||||
|
ema_tracker:
|
||||||
|
class_path: src.callbacks.simple_ema.SimpleEMA
|
||||||
|
init_args:
|
||||||
|
decay: 0.9999
|
||||||
|
optimizer:
|
||||||
|
class_path: torch.optim.AdamW
|
||||||
|
init_args:
|
||||||
|
lr: 1e-3
|
||||||
|
betas:
|
||||||
|
- 0.9
|
||||||
|
- 0.95
|
||||||
|
weight_decay: 0.0
|
||||||
|
data:
|
||||||
|
train_dataset: hf_image
|
||||||
|
train_root: ./data/butterflies
|
||||||
|
test_nature_root: null
|
||||||
|
test_gen_root: null
|
||||||
|
train_image_size: 256
|
||||||
|
train_batch_size: 128
|
||||||
|
train_num_workers: 8
|
||||||
|
train_prefetch_factor: 2
|
||||||
|
train_hf_name: huggan/smithsonian_butterflies_subset
|
||||||
|
train_hf_split: train
|
||||||
|
train_hf_image_column: image
|
||||||
|
train_hf_label_column: null
|
||||||
|
train_hf_cache_dir: null
|
||||||
|
eval_max_num_instances: 256
|
||||||
|
pred_batch_size: 32
|
||||||
|
pred_num_workers: 4
|
||||||
|
pred_seeds: null
|
||||||
|
pred_selected_classes: null
|
||||||
|
num_classes: 1
|
||||||
|
latent_shape:
|
||||||
|
- 4
|
||||||
|
- 32
|
||||||
|
- 32
|
||||||
@@ -6,4 +6,5 @@ jsonargparse[signatures]>=4.27.7
|
|||||||
torchvision
|
torchvision
|
||||||
timm
|
timm
|
||||||
accelerate
|
accelerate
|
||||||
gradio
|
gradio
|
||||||
|
datasets
|
||||||
|
|||||||
@@ -29,6 +29,11 @@ class DataModule(pl.LightningDataModule):
|
|||||||
var_transform_engine: VARTransformEngine = None,
|
var_transform_engine: VARTransformEngine = None,
|
||||||
train_prefetch_factor=2,
|
train_prefetch_factor=2,
|
||||||
train_dataset: str = None,
|
train_dataset: str = None,
|
||||||
|
train_hf_name: str = None,
|
||||||
|
train_hf_split: str = "train",
|
||||||
|
train_hf_image_column: str = "image",
|
||||||
|
train_hf_label_column: str = None,
|
||||||
|
train_hf_cache_dir: str = None,
|
||||||
eval_batch_size=32,
|
eval_batch_size=32,
|
||||||
eval_num_workers=4,
|
eval_num_workers=4,
|
||||||
eval_max_num_instances=50000,
|
eval_max_num_instances=50000,
|
||||||
@@ -45,6 +50,11 @@ class DataModule(pl.LightningDataModule):
|
|||||||
self.train_root = train_root
|
self.train_root = train_root
|
||||||
self.train_image_size = train_image_size
|
self.train_image_size = train_image_size
|
||||||
self.train_dataset = train_dataset
|
self.train_dataset = train_dataset
|
||||||
|
self.train_hf_name = train_hf_name
|
||||||
|
self.train_hf_split = train_hf_split
|
||||||
|
self.train_hf_image_column = train_hf_image_column
|
||||||
|
self.train_hf_label_column = train_hf_label_column
|
||||||
|
self.train_hf_cache_dir = train_hf_cache_dir
|
||||||
# stupid data_convert override, just to make nebular happy
|
# stupid data_convert override, just to make nebular happy
|
||||||
self.train_batch_size = train_batch_size
|
self.train_batch_size = train_batch_size
|
||||||
self.train_num_workers = train_num_workers
|
self.train_num_workers = train_num_workers
|
||||||
@@ -101,6 +111,18 @@ class DataModule(pl.LightningDataModule):
|
|||||||
self.train_dataset = PixImageNet512(
|
self.train_dataset = PixImageNet512(
|
||||||
root=self.train_root,
|
root=self.train_root,
|
||||||
)
|
)
|
||||||
|
elif self.train_dataset == "hf_image":
|
||||||
|
from src.data.dataset.hf_image import HuggingFaceImageDataset
|
||||||
|
if self.train_hf_name is None:
|
||||||
|
raise ValueError("train_hf_name must be set when train_dataset=hf_image")
|
||||||
|
self.train_dataset = HuggingFaceImageDataset(
|
||||||
|
name=self.train_hf_name,
|
||||||
|
split=self.train_hf_split,
|
||||||
|
image_column=self.train_hf_image_column,
|
||||||
|
label_column=self.train_hf_label_column,
|
||||||
|
resolution=self.train_image_size,
|
||||||
|
cache_dir=self.train_hf_cache_dir,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("no such dataset")
|
raise NotImplementedError("no such dataset")
|
||||||
|
|
||||||
|
|||||||
BIN
src/utils/__pycache__/patch_bugs.cpython-312.pyc
Normal file
BIN
src/utils/__pycache__/patch_bugs.cpython-312.pyc
Normal file
Binary file not shown.
Reference in New Issue
Block a user