submit code
src/data/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
src/data/dataset/__init__.py (new file, empty)
src/data/dataset/celeba.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from torchvision.datasets import CelebA


class LocalDataset(CelebA):
    """Thin wrapper around torchvision's CelebA, pinned to the train split."""

    def __init__(self, root: str):
        super().__init__(root, split="train")

    def __getitem__(self, idx):
        # Delegates to CelebA; kept as an explicit hook for later customization.
        data = super().__getitem__(idx)
        return data
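Usage sketch (not part of the commit; the root path is an assumption and expects the standard torchvision layout with the CelebA files under data/celeba):

# Hypothetical usage of LocalDataset.
from src.data.dataset.celeba import LocalDataset

ds = LocalDataset(root="data")
image, attrs = ds[0]                      # PIL image and the default 40-dim attribute target
print(len(ds), image.size, attrs.shape)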
src/data/dataset/imagenet.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import torch
from PIL import Image
from torchvision.datasets import ImageFolder
from torchvision.transforms.functional import to_tensor
from torchvision.transforms import Normalize

from src.data.dataset.metric_dataset import CenterCrop


class LocalCachedDataset(ImageFolder):
    """ImageFolder that optionally pairs each image with a pre-computed VAE latent.

    Returns (raw_image, latent, target); when no latent cache is configured,
    the raw image is returned in place of the latent.
    """

    def __init__(self, root, resolution=256):
        super().__init__(root)
        self.transform = CenterCrop(resolution)
        self.cache_root = None

    def load_latent(self, latent_path):
        # Cached files store the VAE posterior parameters; sample a latent via the
        # reparameterization trick: z = mean + eps * std.
        pk_data = torch.load(latent_path)
        mean = pk_data['mean'].to(torch.float32)
        logvar = pk_data['logvar'].to(torch.float32)
        logvar = torch.clamp(logvar, -30.0, 20.0)
        std = torch.exp(0.5 * logvar)
        latent = mean + torch.randn_like(mean) * std
        return latent

    def __getitem__(self, idx: int):
        image_path, target = self.samples[idx]

        raw_image = Image.open(image_path).convert('RGB')
        raw_image = self.transform(raw_image)
        raw_image = to_tensor(raw_image)

        if self.cache_root is not None:
            # The latent cache mirrors the image tree: <cache_root>/<class>/<name>.pt
            latent_path = image_path.replace(self.root, self.cache_root) + ".pt"
            latent = self.load_latent(latent_path)
        else:
            latent = raw_image
        return raw_image, latent, target


class ImageNet256(LocalCachedDataset):
    def __init__(self, root):
        super().__init__(root, 256)
        self.cache_root = root + "_256_latent"


class ImageNet512(LocalCachedDataset):
    def __init__(self, root):
        super().__init__(root, 512)
        self.cache_root = root + "_512_latent"


class PixImageNet(ImageFolder):
    """Pixel-space ImageFolder returning (raw_image, normalized_image, target)."""

    def __init__(self, root, resolution=256):
        super().__init__(root)
        self.transform = CenterCrop(resolution)
        self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    def __getitem__(self, idx: int):
        image_path, target = self.samples[idx]
        raw_image = Image.open(image_path).convert('RGB')
        raw_image = self.transform(raw_image)
        raw_image = to_tensor(raw_image)

        normalized_image = self.normalize(raw_image)
        return raw_image, normalized_image, target


class PixImageNet64(PixImageNet):
    def __init__(self, root):
        super().__init__(root, 64)


class PixImageNet128(PixImageNet):
    def __init__(self, root):
        super().__init__(root, 128)


class PixImageNet256(PixImageNet):
    def __init__(self, root):
        super().__init__(root, 256)


class PixImageNet512(PixImageNet):
    def __init__(self, root):
        super().__init__(root, 512)
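Usage sketch for the cached-latent path (not part of the commit; the directory names and batch size are assumptions). It assumes an ImageNet-style folder plus a mirrored latent cache at <root>_256_latent whose .pt files hold 'mean' and 'logvar' tensors, which is the format load_latent above expects:

# Hypothetical paths; only the 'mean'/'logvar' file format is taken from the code above.
from torch.utils.data import DataLoader
from src.data.dataset.imagenet import ImageNet256

ds = ImageNet256("data/imagenet/train")   # cache_root becomes data/imagenet/train_256_latent
loader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=4)

raw, latent, target = next(iter(loader))
print(raw.shape)      # (32, 3, 256, 256), values in [0, 1]
print(latent.shape)   # depends on what was cached, e.g. (32, 4, 32, 32) for an SD-style VAE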
src/data/dataset/metric_dataset.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import pathlib
import random

import numpy as np
import torch
import torchvision.transforms as tvtf
from PIL import Image
from torch.utils.data import Dataset


class CenterCrop:
    """Callable wrapper around the ADM-style center crop, parameterized by output size."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        def center_crop_arr(pil_image, image_size):
            """
            Center cropping implementation from ADM.
            https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
            """
            # Repeatedly halve the image while it is still at least twice the target size.
            while min(*pil_image.size) >= 2 * image_size:
                pil_image = pil_image.resize(
                    tuple(x // 2 for x in pil_image.size), resample=Image.BOX
                )

            # Resize so the shorter side matches the target, then crop the center.
            scale = image_size / min(*pil_image.size)
            pil_image = pil_image.resize(
                tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
            )

            arr = np.array(pil_image)
            crop_y = (arr.shape[0] - image_size) // 2
            crop_x = (arr.shape[1] - image_size) // 2
            return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])

        return center_crop_arr(image, self.size)


IMG_EXTENSIONS = (
    "*.png",
    "*.JPEG",
    "*.jpeg",
    "*.jpg",
)


def test_collate(batch):
    return torch.stack(batch)


class ImageDataset(Dataset):
    """Flat image dataset for metric computation: yields uint8 CHW tensors of a fixed size."""

    def __init__(self, root, image_size=(224, 224)):
        self.root = pathlib.Path(root)
        images = []
        for ext in IMG_EXTENSIONS:
            images.extend(self.root.rglob(ext))
        random.shuffle(images)
        self.images = [str(x) for x in images]
        self.transform = tvtf.Compose(
            [
                CenterCrop(image_size[0]),
                tvtf.ToTensor(),
                tvtf.Lambda(lambda x: (x * 255).to(torch.uint8)),
                tvtf.Lambda(lambda x: x.expand(3, -1, -1)),  # broadcast grayscale to 3 channels
            ]
        )
        self.size = image_size

    def __getitem__(self, idx):
        try:
            image = Image.open(self.images[idx])
            image = self.transform(image)
        except Exception as e:
            print(f"failed to load {self.images[idx]}: {e}")
            image = torch.zeros(3, self.size[0], self.size[1], dtype=torch.uint8)

        # Metadata (path, root) is currently not returned; only the image tensor is used.
        return image

    def __len__(self):
        return len(self.images)
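Usage sketch for the metric pipeline (not part of the commit; the sample directory and batch size are assumptions). test_collate simply stacks the uint8 tensors, the format that FID-style feature extractors typically consume:

from torch.utils.data import DataLoader
from src.data.dataset.metric_dataset import ImageDataset, test_collate

ds = ImageDataset("samples/ImageNet256", image_size=(224, 224))   # hypothetical sample directory
loader = DataLoader(ds, batch_size=64, num_workers=8, collate_fn=test_collate)

for batch in loader:
    # batch: uint8 tensor of shape (64, 3, 224, 224)
    pass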
src/data/dataset/randn.py (new file, 41 lines)
@@ -0,0 +1,41 @@
import random

import torch
from torch.utils.data import Dataset


class RandomNDataset(Dataset):
    """Gaussian latents paired with class labels, used to drive sampling runs."""

    def __init__(self, latent_shape=(4, 64, 64), num_classes=1000, selected_classes: list = None,
                 seeds=None, max_num_instances=50000):
        self.selected_classes = selected_classes
        if selected_classes is not None:
            num_classes = len(selected_classes)
            max_num_instances = 10 * num_classes
        self.num_classes = num_classes
        self.seeds = seeds
        if seeds is not None:
            self.max_num_instances = len(seeds) * num_classes
            self.num_seeds = len(seeds)
        else:
            # Spread instances evenly over the classes, rounding up.
            self.num_seeds = (max_num_instances + num_classes - 1) // num_classes
            self.max_num_instances = self.num_seeds * num_classes

        self.latent_shape = latent_shape

    def __getitem__(self, idx):
        label = idx // self.num_seeds
        if self.selected_classes:
            label = self.selected_classes[label]
        seed = random.randint(0, 1 << 31)  # fresh seed unless a fixed seed list is given
        if self.seeds is not None:
            seed = self.seeds[idx % self.num_seeds]

        filename = f"{label}_{seed}.png"
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)
        return latent, label, filename

    def __len__(self):
        return self.max_num_instances
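Usage sketch (not part of the commit; the class list and seeds are assumptions). With an explicit seed list every index maps to a reproducible latent, and the returned filename of the form "<label>_<seed>.png" is the name a sampler could write the generated image under:

from src.data.dataset.randn import RandomNDataset

# 8 classes x 2 fixed seeds -> 16 reproducible (latent, label, filename) triples
ds = RandomNDataset(latent_shape=(4, 32, 32),
                    selected_classes=[1, 7, 88, 130, 207, 388, 417, 933],
                    seeds=[0, 1])
latent, label, filename = ds[0]
print(latent.shape, label, filename)   # torch.Size([4, 32, 32]) 1 1_0.png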
src/data/var_training.py (new file, 145 lines)
@@ -0,0 +1,145 @@
import copy
import random
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import torch
import torchvision.transforms.functional as tvtf
from PIL import Image

from src.diffusion.base.training import *
from src.diffusion.base.scheduling import BaseScheduler
from src.models.vae import uint82fp


def center_crop_arr(pil_image, width, height):
    """
    Random-crop variant of the ADM cropping implementation.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # Repeatedly halve while the image is at least twice the target size in both dimensions.
    while pil_image.size[0] >= 2 * width and pil_image.size[1] >= 2 * height:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    # Scale so both sides cover the target, then take a random crop of (width, height).
    scale = max(width / pil_image.size[0], height / pil_image.size[1])
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )
    arr = np.array(pil_image)
    crop_y = random.randint(0, arr.shape[0] - height)
    crop_x = random.randint(0, arr.shape[1] - width)
    return Image.fromarray(arr[crop_y: crop_y + height, crop_x: crop_x + width])


def process_fn(width, height, data, hflip=0.5):
    image, label = data
    if random.uniform(0, 1) > hflip:  # horizontal flip with probability 1 - hflip
        image = tvtf.hflip(image)
    image = center_crop_arr(image, width, height)  # resize + random crop
    image = np.array(image).transpose(2, 0, 1)     # HWC -> CHW, still uint8
    return image, label


class VARCandidate:
    """A bucket of pending samples that share one target aspect ratio."""

    def __init__(self, aspect_ratio, width, height, buffer, max_buffer_size=1024):
        self.aspect_ratio = aspect_ratio
        self.width = int(width)
        self.height = int(height)
        self.buffer = buffer
        self.max_buffer_size = max_buffer_size

    def add_sample(self, data):
        self.buffer.append(data)
        self.buffer = self.buffer[-self.max_buffer_size:]

    def ready(self, batch_size):
        return len(self.buffer) >= batch_size

    def get_batch(self, batch_size):
        batch = self.buffer[:batch_size]
        self.buffer = self.buffer[batch_size:]
        # Buffer entries are futures from the process pool; .result() blocks until the crop is done.
        batch = [copy.deepcopy(b.result()) for b in batch]
        x, y = zip(*batch)
        x = torch.stack([torch.from_numpy(im).cuda() for im in x], dim=0)
        x = list(map(uint82fp, x))
        return x, y


class VARTransformEngine:
    """Groups samples into aspect-ratio buckets and crops them asynchronously for VAR training."""

    def __init__(self,
                 base_image_size,
                 num_aspect_ratios,
                 min_aspect_ratio,
                 max_aspect_ratio,
                 num_workers=8,
                 ):
        self.base_image_size = base_image_size
        self.num_aspect_ratios = num_aspect_ratios
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        self.aspect_ratios = np.linspace(self.min_aspect_ratio, self.max_aspect_ratio, self.num_aspect_ratios)
        self.aspect_ratios = self.aspect_ratios.tolist()
        self.candidates_pool = []
        for i in range(self.num_aspect_ratios):
            # Keep the pixel count roughly constant and snap each side to a multiple of 16.
            candidate = VARCandidate(
                aspect_ratio=self.aspect_ratios[i],
                width=int(self.base_image_size * self.aspect_ratios[i] ** 0.5 // 16 * 16),
                height=int(self.base_image_size * self.aspect_ratios[i] ** -0.5 // 16 * 16),
                buffer=[],
                max_buffer_size=1024,
            )
            self.candidates_pool.append(candidate)
        self.default_candidate = VARCandidate(
            aspect_ratio=1.0,
            width=self.base_image_size,
            height=self.base_image_size,
            buffer=[],
            max_buffer_size=1024,
        )
        self.executor_pool = ProcessPoolExecutor(max_workers=num_workers)
        # For the first batches, also feed the square default bucket so it can serve as a fallback.
        self._prefill_count = 100

    def find_candidate(self, data):
        # Pick the bucket whose aspect ratio is closest to the incoming image's.
        image = data[0]
        aspect_ratio = image.size[0] / image.size[1]
        min_distance = float("inf")
        min_candidate = None
        for candidate in self.candidates_pool:
            dis = abs(aspect_ratio - candidate.aspect_ratio)
            if dis < min_distance:
                min_distance = dis
                min_candidate = candidate
        return min_candidate

    def __call__(self, batch_data):
        self._prefill_count -= 1
        if isinstance(batch_data[0], torch.Tensor):
            batch_data[0] = batch_data[0].unbind(0)

        # Re-group from (images, labels) into per-sample (image, label) pairs.
        batch_data = list(zip(*batch_data))
        for data in batch_data:
            candidate = self.find_candidate(data)
            future = self.executor_pool.submit(process_fn, candidate.width, candidate.height, data)
            candidate.add_sample(future)
            if self._prefill_count >= 0:
                future = self.executor_pool.submit(process_fn,
                                                   self.default_candidate.width,
                                                   self.default_candidate.height,
                                                   data)
                self.default_candidate.add_sample(future)

        # Return the first aspect-ratio bucket that already holds a full batch.
        batch_size = len(batch_data)
        random.shuffle(self.candidates_pool)
        for candidate in self.candidates_pool:
            if candidate.ready(batch_size=batch_size):
                return candidate.get_batch(batch_size=batch_size)

        # Fallback: crop the current batch to the square default resolution.
        for data in batch_data:
            future = self.executor_pool.submit(process_fn,
                                               self.default_candidate.width,
                                               self.default_candidate.height,
                                               data)
            self.default_candidate.add_sample(future)
        return self.default_candidate.get_batch(batch_size=batch_size)
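A hedged sketch of how the engine might be wired up (not part of the commit; the dataset, loader, and loop are assumptions). The engine takes a CPU batch of PIL images plus labels and returns a GPU batch cropped to one shared aspect ratio, with pixels mapped to float by uint82fp:

# Hypothetical wiring; only VARTransformEngine comes from this file.
from torch.utils.data import DataLoader
from src.data.var_training import VARTransformEngine

engine = VARTransformEngine(base_image_size=256, num_aspect_ratios=9,
                            min_aspect_ratio=0.5, max_aspect_ratio=2.0)

def collate_pil(batch):
    # Keep PIL images as a list; the engine relies on .size and PIL-based transforms.
    images, labels = zip(*batch)
    return [list(images), list(labels)]

# loader yields [list_of_pil_images, list_of_labels], e.g.:
# loader = DataLoader(raw_imagenet, batch_size=32, collate_fn=collate_pil, shuffle=True)
# for batch in loader:
#     x, y = engine(batch)   # x: list of float tensors on GPU, all with one shared (H, W)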