# Origin: git patch 6bf32b08fdab1f06c85791bcd8ca7cfdcb1da53c
#   From:    game-loader
#   Date:    Sun, 18 Jan 2026 17:10:01 +0800
#   Subject: [PATCH] feat(dataset): add HuggingFace image dataset
#   Target:  src/data/dataset/hf_image.py (new file)
"""Torch ``Dataset`` wrapper around an image dataset loaded via ``datasets``."""

from typing import Optional, Tuple

import torch
from torch.utils.data import Dataset
from torchvision.transforms import Normalize
from torchvision.transforms.functional import to_tensor
from PIL import Image
from datasets import load_dataset

from src.data.dataset.metric_dataset import CenterCrop


class HuggingFaceImageDataset(Dataset):
    """Serve images of a Hugging Face dataset as raw/normalized tensor pairs.

    Every item is a ``(raw_image, normalized_image, label)`` tuple:

    * ``raw_image`` -- the RGB image after the project ``CenterCrop``
      transform, converted with ``to_tensor``.
    * ``normalized_image`` -- ``raw_image`` passed through
      ``Normalize(mean=0.5, std=0.5)`` on each of the three channels.
    * ``label`` -- ``int`` value of ``label_column``, or ``0`` when no label
      column was configured.

    Args:
        name: Dataset identifier forwarded to ``datasets.load_dataset``.
        split: Split forwarded to ``datasets.load_dataset``.
        image_column: Column holding the image (a PIL image, or anything
            ``PIL.Image.fromarray`` accepts).
        label_column: Column holding the label; ``None`` disables labels and
            makes every item carry the label ``0``.
        resolution: Size argument handed to the project-local ``CenterCrop``.
            NOTE(review): presumably the output side length in pixels --
            confirm against ``src.data.dataset.metric_dataset``.
        cache_dir: Cache directory forwarded to ``datasets.load_dataset``.

    Raises:
        KeyError: If ``image_column`` or ``label_column`` is not a column of
            the loaded dataset (only checked when the loaded object exposes
            ``column_names`` as a list).
    """

    def __init__(
        self,
        name: str,
        split: str = "train",
        image_column: str = "image",
        label_column: Optional[str] = None,
        resolution: int = 256,
        cache_dir: Optional[str] = None,
    ) -> None:
        super().__init__()
        self.dataset = load_dataset(name, split=split, cache_dir=cache_dir)
        self.image_column = image_column
        self.label_column = label_column
        # Fail at construction time instead of inside the first __getitem__
        # call (which typically runs in a DataLoader worker and produces a
        # much less readable traceback).
        self._check_columns()
        self.transform = CenterCrop(resolution)
        self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    def _check_columns(self) -> None:
        """Raise ``KeyError`` if a configured column is absent from the dataset."""
        available = getattr(self.dataset, "column_names", None)
        # Only a flat list of names can be validated. Any other shape (the
        # attribute is missing, or is not a list) is left alone so that such
        # objects keep failing -- or working -- exactly as they did before.
        if not isinstance(available, list):
            return

        wanted = [self.image_column]
        if self.label_column is not None:
            wanted.append(self.label_column)

        missing = [column for column in wanted if column not in available]
        if missing:
            raise KeyError(
                f"column(s) {missing} not found in dataset; "
                f"available columns: {available}"
            )

    def __len__(self) -> int:
        """Return the number of examples in the loaded split."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Return ``(raw_image, normalized_image, label)`` for example ``idx``."""
        item = self.dataset[idx]

        image = item[self.image_column]
        if not isinstance(image, Image.Image):
            # Non-PIL payloads (e.g. array-like data) are wrapped first.
            image = Image.fromarray(image)
        # Force three channels so the 3-channel Normalize below always fits.
        image = self.transform(image.convert("RGB"))

        raw_image = to_tensor(image)
        normalized_image = self.normalize(raw_image)
        label = 0 if self.label_column is None else int(item[self.label_column])
        return raw_image, normalized_image, label