From cf45afe32569918b71769ca93ed467b4bca673ce Mon Sep 17 00:00:00 2001 From: game-loader Date: Fri, 16 Jan 2026 14:25:32 +0800 Subject: [PATCH] feat(data): add support for Hugging Face datasets --- configs/ddt_butterflies_b2_256.yaml | 116 ++++++++++++++++++ requirements.txt | 3 +- src/lightning_data.py | 22 ++++ .../__pycache__/patch_bugs.cpython-312.pyc | Bin 0 -> 945 bytes 4 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 configs/ddt_butterflies_b2_256.yaml create mode 100644 src/utils/__pycache__/patch_bugs.cpython-312.pyc diff --git a/configs/ddt_butterflies_b2_256.yaml b/configs/ddt_butterflies_b2_256.yaml new file mode 100644 index 0000000..84ccde9 --- /dev/null +++ b/configs/ddt_butterflies_b2_256.yaml @@ -0,0 +1,116 @@ +seed_everything: true +tags: + exp: &exp ddt_butterflies_b2_256 +torch_hub_dir: null +huggingface_cache_dir: null +trainer: + default_root_dir: workdirs + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: bf16-mixed + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: ddt_butterflies + name: *exp + save_dir: workdirs + num_sanity_val_steps: 0 + max_steps: 200000 + val_check_interval: 2000 + check_val_every_n_epoch: null + log_every_n_steps: 50 + deterministic: null + inference_mode: true + use_distributed_sampler: false + callbacks: + - class_path: src.callbacks.model_checkpoint.CheckpointHook + init_args: + every_n_train_steps: 10000 + save_top_k: -1 + save_last: true + - class_path: src.callbacks.save_images.SaveImagesHook + init_args: + save_dir: val + max_save_num: 64 +model: + vae: + class_path: src.models.vae.LatentVAE + init_args: + precompute: false + weight_path: stabilityai/sd-vae-ft-ema + denoiser: + class_path: src.models.denoiser.decoupled_improved_dit.DDT + init_args: + in_channels: 4 + patch_size: 2 + num_groups: 12 + hidden_size: &hidden_dim 768 + num_blocks: 12 + num_encoder_blocks: 8 + num_classes: 1 + conditioner: + class_path: src.models.conditioner.LabelConditioner + init_args: + null_class: 1 + diffusion_trainer: + class_path: src.diffusion.stateful_flow_matching.training_repa.REPATrainer + init_args: + lognorm_t: true + encoder_weight_path: dinov2_vitb14 + align_layer: 4 + proj_denoiser_dim: *hidden_dim + proj_hidden_dim: *hidden_dim + proj_encoder_dim: 768 + scheduler: &scheduler src.diffusion.stateful_flow_matching.scheduling.LinearScheduler + diffusion_sampler: + class_path: src.diffusion.stateful_flow_matching.sampling.EulerSampler + init_args: + num_steps: 250 + guidance: 1.0 + timeshift: 1.0 + state_refresh_rate: 1 + guidance_interval_min: 0.3 + guidance_interval_max: 1.0 + scheduler: *scheduler + w_scheduler: src.diffusion.stateful_flow_matching.scheduling.LinearScheduler + guidance_fn: src.diffusion.base.guidance.simple_guidance_fn + last_step: 0.04 + step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn + ema_tracker: + class_path: src.callbacks.simple_ema.SimpleEMA + init_args: + decay: 0.9999 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1e-3 + betas: + - 0.9 + - 0.95 + weight_decay: 0.0 +data: + train_dataset: hf_image + train_root: ./data/butterflies + test_nature_root: null + test_gen_root: null + train_image_size: 256 + train_batch_size: 128 + train_num_workers: 8 + train_prefetch_factor: 2 + train_hf_name: huggan/smithsonian_butterflies_subset + train_hf_split: train + train_hf_image_column: image + train_hf_label_column: null + train_hf_cache_dir: null + eval_max_num_instances: 256 + pred_batch_size: 32 + pred_num_workers: 4 + pred_seeds: null + pred_selected_classes: null + num_classes: 1 + latent_shape: + - 4 + - 32 + - 32 diff --git a/requirements.txt b/requirements.txt index 78c6303..0350c53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ jsonargparse[signatures]>=4.27.7 torchvision timm accelerate -gradio \ No newline at end of file +gradio +datasets diff --git a/src/lightning_data.py b/src/lightning_data.py index 9f75a42..434eb54 100644 --- a/src/lightning_data.py +++ b/src/lightning_data.py @@ -29,6 +29,11 @@ class DataModule(pl.LightningDataModule): var_transform_engine: VARTransformEngine = None, train_prefetch_factor=2, train_dataset: str = None, + train_hf_name: str = None, + train_hf_split: str = "train", + train_hf_image_column: str = "image", + train_hf_label_column: str = None, + train_hf_cache_dir: str = None, eval_batch_size=32, eval_num_workers=4, eval_max_num_instances=50000, @@ -45,6 +50,11 @@ class DataModule(pl.LightningDataModule): self.train_root = train_root self.train_image_size = train_image_size self.train_dataset = train_dataset + self.train_hf_name = train_hf_name + self.train_hf_split = train_hf_split + self.train_hf_image_column = train_hf_image_column + self.train_hf_label_column = train_hf_label_column + self.train_hf_cache_dir = train_hf_cache_dir # stupid data_convert override, just to make nebular happy self.train_batch_size = train_batch_size self.train_num_workers = train_num_workers @@ -101,6 +111,18 @@ class DataModule(pl.LightningDataModule): self.train_dataset = PixImageNet512( root=self.train_root, ) + elif self.train_dataset == "hf_image": + from src.data.dataset.hf_image import HuggingFaceImageDataset + if self.train_hf_name is None: + raise ValueError("train_hf_name must be set when train_dataset=hf_image") + self.train_dataset = HuggingFaceImageDataset( + name=self.train_hf_name, + split=self.train_hf_split, + image_column=self.train_hf_image_column, + label_column=self.train_hf_label_column, + resolution=self.train_image_size, + cache_dir=self.train_hf_cache_dir, + ) else: raise NotImplementedError("no such dataset") diff --git a/src/utils/__pycache__/patch_bugs.cpython-312.pyc b/src/utils/__pycache__/patch_bugs.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16aaada261e8f1bb095e6db931865e83b16baac3 GIT binary patch literal 945 zcmZ`%&ubGw6n?wOX0ypAsbW;3SQ8OjkS>$@<6vJ{OcexUYMFyhy62Baz-)7DDW@~l=`EiJ00S+%lE zV-7}E+o8gi-bCv(x>kBsrHSg?oT`;8vnvafo2ECG=e4<|CGEvxRa0LsXliA(qSBaU zHeG@zc^cCV-*THYZs+hV8il)jsp@Cr;>?%bs$V*p%9HTriwSQ zO~qiT#p(j(U5|>mxoMGL+AW;~gLzc;ET?7T4a@d1p{e<{-zIo1%+lnFhsiK_59h%{ z{(hv?T?Z?OxE4~%YqhSK&e+f!P)Z`8H*w&}5jX}svTsBPmJl4rc&Lwftn{EBcK z_YLIlpqzm61H?VyIFmnodI}RgD0GA~D0HE4`1}+e_aWVra(lPVq+D0Z9Xvac9&}5GJWSq_R2TdR8!2kdN literal 0 HcmV?d00001