feat(dataset): add HuggingFace image dataset
This commit is contained in:
43
src/data/dataset/hf_image.py
Normal file
43
src/data/dataset/hf_image.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision.transforms import Normalize
|
||||||
|
from torchvision.transforms.functional import to_tensor
|
||||||
|
from PIL import Image
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
from src.data.dataset.metric_dataset import CenterCrop
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceImageDataset(Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
split: str = "train",
|
||||||
|
image_column: str = "image",
|
||||||
|
label_column: Optional[str] = None,
|
||||||
|
resolution: int = 256,
|
||||||
|
cache_dir: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
self.dataset = load_dataset(name, split=split, cache_dir=cache_dir)
|
||||||
|
self.image_column = image_column
|
||||||
|
self.label_column = label_column
|
||||||
|
self.transform = CenterCrop(resolution)
|
||||||
|
self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.dataset)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int):
|
||||||
|
item = self.dataset[idx]
|
||||||
|
image = item[self.image_column]
|
||||||
|
if not isinstance(image, Image.Image):
|
||||||
|
image = Image.fromarray(image)
|
||||||
|
image = image.convert("RGB")
|
||||||
|
image = self.transform(image)
|
||||||
|
|
||||||
|
raw_image = to_tensor(image)
|
||||||
|
normalized_image = self.normalize(raw_image)
|
||||||
|
label = 0 if self.label_column is None else int(item[self.label_column])
|
||||||
|
return raw_image, normalized_image, label
|
||||||
Reference in New Issue
Block a user