feat(data): add support for Hugging Face datasets

This commit is contained in:
game-loader
2026-01-16 14:25:32 +08:00
parent 1d1b4d2913
commit cf45afe325
4 changed files with 140 additions and 1 deletions

View File

@@ -0,0 +1,116 @@
# Lightning-CLI experiment config: DDT (B/2) + REPA on the HuggingFace
# "smithsonian_butterflies_subset" dataset at 256px.
# NOTE(review): the original file's indentation (which defines YAML nesting)
# was lost in extraction; the nesting below is reconstructed from Lightning
# CLI conventions and the key semantics — verify against the repo's sibling
# configs before relying on it.
seed_everything: true
tags:
  exp: &exp ddt_butterflies_b2_256
torch_hub_dir: null
huggingface_cache_dir: null
trainer:
  default_root_dir: workdirs
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: bf16-mixed
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: ddt_butterflies
      name: *exp
      save_dir: workdirs
  num_sanity_val_steps: 0
  max_steps: 200000
  # validate every 2000 optimizer steps rather than per epoch
  val_check_interval: 2000
  check_val_every_n_epoch: null
  log_every_n_steps: 50
  deterministic: null
  inference_mode: true
  # the datamodule supplies its own (distributed-aware) sampler
  use_distributed_sampler: false
  callbacks:
    - class_path: src.callbacks.model_checkpoint.CheckpointHook
      init_args:
        every_n_train_steps: 10000
        save_top_k: -1   # keep every checkpoint
        save_last: true
    - class_path: src.callbacks.save_images.SaveImagesHook
      init_args:
        save_dir: val
        max_save_num: 64
model:
  vae:
    class_path: src.models.vae.LatentVAE
    init_args:
      precompute: false
      weight_path: stabilityai/sd-vae-ft-ema
  denoiser:
    class_path: src.models.denoiser.decoupled_improved_dit.DDT
    init_args:
      in_channels: 4
      patch_size: 2
      num_groups: 12
      hidden_size: &hidden_dim 768
      num_blocks: 12
      num_encoder_blocks: 8
      # single-class (unconditional) setup
      num_classes: 1
  conditioner:
    class_path: src.models.conditioner.LabelConditioner
    init_args:
      null_class: 1
  diffusion_trainer:
    class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer
    init_args:
      lognorm_t: true
      # DINOv2 ViT-B/14 features for representation alignment
      encoder_weight_path: dinov2_vitb14
      align_layer: 4
      proj_denoiser_dim: *hidden_dim
      proj_hidden_dim: *hidden_dim
      proj_encoder_dim: 768
      scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
  diffusion_sampler:
    class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler
    init_args:
      num_steps: 250
      guidance: 1.0
      timeshift: 1.0
      state_refresh_rate: 1
      guidance_interval_min: 0.3
      guidance_interval_max: 1.0
      scheduler: *scheduler
      w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler
      guidance_fn: src.diffusion.base.guidance.simple_guidance_fn
      last_step: 0.04
      step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
  ema_tracker:
    class_path: src.callbacks.simple_ema.SimpleEMA
    init_args:
      decay: 0.9999
  optimizer:
    class_path: torch.optim.AdamW
    init_args:
      lr: 1e-3
      betas:
        - 0.9
        - 0.95
      weight_decay: 0.0
data:
  # "hf_image" selects the HuggingFaceImageDataset branch added in this commit
  train_dataset: hf_image
  train_root: ./data/butterflies
  test_nature_root: null
  test_gen_root: null
  train_image_size: 256
  train_batch_size: 128
  train_num_workers: 8
  train_prefetch_factor: 2
  train_hf_name: huggan/smithsonian_butterflies_subset
  train_hf_split: train
  train_hf_image_column: image
  train_hf_label_column: null   # null → unconditional training
  train_hf_cache_dir: null
  eval_max_num_instances: 256
  pred_batch_size: 32
  pred_num_workers: 4
  pred_seeds: null
  pred_selected_classes: null
  num_classes: 1
  # latent tensor shape for the SD VAE at 256px: 4 x 32 x 32
  latent_shape:
    - 4
    - 32
    - 32

View File

@@ -7,3 +7,4 @@ torchvision
timm
accelerate
gradio
datasets

View File

@@ -29,6 +29,11 @@ class DataModule(pl.LightningDataModule):
var_transform_engine: VARTransformEngine = None, var_transform_engine: VARTransformEngine = None,
train_prefetch_factor=2, train_prefetch_factor=2,
train_dataset: str = None, train_dataset: str = None,
train_hf_name: str = None,
train_hf_split: str = "train",
train_hf_image_column: str = "image",
train_hf_label_column: str = None,
train_hf_cache_dir: str = None,
eval_batch_size=32, eval_batch_size=32,
eval_num_workers=4, eval_num_workers=4,
eval_max_num_instances=50000, eval_max_num_instances=50000,
@@ -45,6 +50,11 @@ class DataModule(pl.LightningDataModule):
self.train_root = train_root self.train_root = train_root
self.train_image_size = train_image_size self.train_image_size = train_image_size
self.train_dataset = train_dataset self.train_dataset = train_dataset
self.train_hf_name = train_hf_name
self.train_hf_split = train_hf_split
self.train_hf_image_column = train_hf_image_column
self.train_hf_label_column = train_hf_label_column
self.train_hf_cache_dir = train_hf_cache_dir
# stupid data_convert override, just to make nebular happy # stupid data_convert override, just to make nebular happy
self.train_batch_size = train_batch_size self.train_batch_size = train_batch_size
self.train_num_workers = train_num_workers self.train_num_workers = train_num_workers
@@ -101,6 +111,18 @@ class DataModule(pl.LightningDataModule):
self.train_dataset = PixImageNet512( self.train_dataset = PixImageNet512(
root=self.train_root, root=self.train_root,
) )
elif self.train_dataset == "hf_image":
from src.data.dataset.hf_image import HuggingFaceImageDataset
if self.train_hf_name is None:
raise ValueError("train_hf_name must be set when train_dataset=hf_image")
self.train_dataset = HuggingFaceImageDataset(
name=self.train_hf_name,
split=self.train_hf_split,
image_column=self.train_hf_image_column,
label_column=self.train_hf_label_column,
resolution=self.train_image_size,
cache_dir=self.train_hf_cache_dir,
)
else: else:
raise NotImplementedError("no such dataset") raise NotImplementedError("no such dataset")

Binary file not shown.