feat(data): add support for Hugging Face datasets

This commit is contained in:
game-loader
2026-01-16 14:25:32 +08:00
parent 1d1b4d2913
commit cf45afe325
4 changed files with 140 additions and 1 deletions

View File

@@ -29,6 +29,11 @@ class DataModule(pl.LightningDataModule):
var_transform_engine: VARTransformEngine = None,
train_prefetch_factor=2,
train_dataset: str = None,
train_hf_name: str = None,
train_hf_split: str = "train",
train_hf_image_column: str = "image",
train_hf_label_column: str = None,
train_hf_cache_dir: str = None,
eval_batch_size=32,
eval_num_workers=4,
eval_max_num_instances=50000,
@@ -45,6 +50,11 @@ class DataModule(pl.LightningDataModule):
self.train_root = train_root
self.train_image_size = train_image_size
self.train_dataset = train_dataset
self.train_hf_name = train_hf_name
self.train_hf_split = train_hf_split
self.train_hf_image_column = train_hf_image_column
self.train_hf_label_column = train_hf_label_column
self.train_hf_cache_dir = train_hf_cache_dir
# stupid data_convert override, just to make nebular happy
self.train_batch_size = train_batch_size
self.train_num_workers = train_num_workers
@@ -101,6 +111,18 @@ class DataModule(pl.LightningDataModule):
self.train_dataset = PixImageNet512(
root=self.train_root,
)
elif self.train_dataset == "hf_image":
from src.data.dataset.hf_image import HuggingFaceImageDataset
if self.train_hf_name is None:
raise ValueError("train_hf_name must be set when train_dataset=hf_image")
self.train_dataset = HuggingFaceImageDataset(
name=self.train_hf_name,
split=self.train_hf_split,
image_column=self.train_hf_image_column,
label_column=self.train_hf_label_column,
resolution=self.train_image_size,
cache_dir=self.train_hf_cache_dir,
)
else:
raise NotImplementedError("no such dataset")

Binary file not shown.