feat(dataset): add HuggingFace image dataset
This commit is contained in:
43
src/data/dataset/hf_image.py
Normal file
43
src/data/dataset/hf_image.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.transforms import Normalize
|
||||
from torchvision.transforms.functional import to_tensor
|
||||
from PIL import Image
|
||||
from datasets import load_dataset
|
||||
|
||||
from src.data.dataset.metric_dataset import CenterCrop
|
||||
|
||||
|
||||
class HuggingFaceImageDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
split: str = "train",
|
||||
image_column: str = "image",
|
||||
label_column: Optional[str] = None,
|
||||
resolution: int = 256,
|
||||
cache_dir: Optional[str] = None,
|
||||
) -> None:
|
||||
self.dataset = load_dataset(name, split=split, cache_dir=cache_dir)
|
||||
self.image_column = image_column
|
||||
self.label_column = label_column
|
||||
self.transform = CenterCrop(resolution)
|
||||
self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
item = self.dataset[idx]
|
||||
image = item[self.image_column]
|
||||
if not isinstance(image, Image.Image):
|
||||
image = Image.fromarray(image)
|
||||
image = image.convert("RGB")
|
||||
image = self.transform(image)
|
||||
|
||||
raw_image = to_tensor(image)
|
||||
normalized_image = self.normalize(raw_image)
|
||||
label = 0 if self.label_column is None else int(item[self.label_column])
|
||||
return raw_image, normalized_image, label
|
||||
Reference in New Issue
Block a user