feat(dataset): add HuggingFace image dataset

This commit is contained in:
game-loader
2026-01-18 17:10:01 +08:00
parent cf45afe325
commit 6bf32b08fd

View File

@@ -0,0 +1,43 @@
from typing import Optional
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Normalize
from torchvision.transforms.functional import to_tensor
from PIL import Image
from datasets import load_dataset
from src.data.dataset.metric_dataset import CenterCrop
class HuggingFaceImageDataset(Dataset):
def __init__(
self,
name: str,
split: str = "train",
image_column: str = "image",
label_column: Optional[str] = None,
resolution: int = 256,
cache_dir: Optional[str] = None,
) -> None:
self.dataset = load_dataset(name, split=split, cache_dir=cache_dir)
self.image_column = image_column
self.label_column = label_column
self.transform = CenterCrop(resolution)
self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
def __len__(self) -> int:
return len(self.dataset)
def __getitem__(self, idx: int):
item = self.dataset[idx]
image = item[self.image_column]
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
image = image.convert("RGB")
image = self.transform(image)
raw_image = to_tensor(image)
normalized_image = self.normalize(raw_image)
label = 0 if self.label_column is None else int(item[self.label_column])
return raw_image, normalized_image, label