submit code

wangshuai6
2025-04-09 11:01:16 +08:00
parent 4fbcf9bd87
commit 06499f1caa
145 changed files with 14400 additions and 0 deletions

1
src/data/__init__.py Normal file

@@ -0,0 +1 @@


@@ -0,0 +1,11 @@
from torchvision.datasets import CelebA


class LocalDataset(CelebA):
    """Thin wrapper around torchvision's CelebA, fixed to the train split."""

    def __init__(self, root: str):
        super().__init__(root, split="train")

    def __getitem__(self, idx):
        return super().__getitem__(idx)
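
A minimal usage sketch for reference; the root path is an assumption, and without a transform torchvision's CelebA returns PIL images:

dataset = LocalDataset(root="./data")  # assumed path; expects the CelebA files under ./data/celeba
image, attrs = dataset[0]  # PIL image plus the binary attribute tensor (CelebA's default target)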


@@ -0,0 +1,82 @@
import torch
from PIL import Image
from torchvision.datasets import ImageFolder
from torchvision.transforms.functional import to_tensor
from torchvision.transforms import Normalize

from src.data.dataset.metric_dataset import CenterCrop


class LocalCachedDataset(ImageFolder):
    """ImageFolder that optionally pairs each image with a pre-computed VAE latent."""

    def __init__(self, root, resolution=256):
        super().__init__(root)
        self.transform = CenterCrop(resolution)
        self.cache_root = None

    def load_latent(self, latent_path):
        # The cache stores the VAE posterior parameters; sample a latent via
        # the reparameterization trick: z = mean + eps * std.
        pk_data = torch.load(latent_path)
        mean = pk_data['mean'].to(torch.float32)
        logvar = pk_data['logvar'].to(torch.float32)
        logvar = torch.clamp(logvar, -30.0, 20.0)
        std = torch.exp(0.5 * logvar)
        latent = mean + torch.randn_like(mean) * std
        return latent

    def __getitem__(self, idx: int):
        image_path, target = self.samples[idx]
        raw_image = Image.open(image_path).convert('RGB')
        raw_image = self.transform(raw_image)
        raw_image = to_tensor(raw_image)
        if self.cache_root is not None:
            # Build the cache path only when a cache root is set; doing it
            # unconditionally would call str.replace with None and raise.
            latent_path = image_path.replace(self.root, self.cache_root) + ".pt"
            latent = self.load_latent(latent_path)
        else:
            latent = raw_image
        return raw_image, latent, target


class ImageNet256(LocalCachedDataset):
    def __init__(self, root):
        super().__init__(root, 256)
        self.cache_root = root + "_256_latent"


class ImageNet512(LocalCachedDataset):
    def __init__(self, root):
        super().__init__(root, 512)
        self.cache_root = root + "_512_latent"


class PixImageNet(ImageFolder):
    """Pixel-space ImageNet: returns the raw image and its [-1, 1]-normalized copy."""

    def __init__(self, root, resolution=256):
        super().__init__(root)
        self.transform = CenterCrop(resolution)
        self.normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    def __getitem__(self, idx: int):
        image_path, target = self.samples[idx]
        raw_image = Image.open(image_path).convert('RGB')
        raw_image = self.transform(raw_image)
        raw_image = to_tensor(raw_image)
        normalized_image = self.normalize(raw_image)
        return raw_image, normalized_image, target


class PixImageNet64(PixImageNet):
    def __init__(self, root):
        super().__init__(root, 64)


class PixImageNet128(PixImageNet):
    def __init__(self, root):
        super().__init__(root, 128)


class PixImageNet256(PixImageNet):
    def __init__(self, root):
        super().__init__(root, 256)


class PixImageNet512(PixImageNet):
    def __init__(self, root):
        super().__init__(root, 512)
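
A sketch of how the cached variants might be consumed, assuming latents were pre-extracted next to the image root (the path is hypothetical):

from torch.utils.data import DataLoader

dataset = ImageNet256(root="./data/imagenet/train")  # assumed path; reads ./data/imagenet/train_256_latent/*.pt
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)
raw_image, latent, target = next(iter(loader))  # latent is sampled from the cached (mean, logvar) posterior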

82
src/data/dataset/metric_dataset.py Normal file

@@ -0,0 +1,82 @@
import pathlib
import random

import numpy as np
import torch
import torchvision.transforms as tvtf
from PIL import Image
from torch.utils.data import Dataset


class CenterCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        def center_crop_arr(pil_image, image_size):
            """
            Center cropping implementation from ADM.
            https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
            """
            # Halve with a box filter until within 2x of the target, then
            # finish with a bicubic resize and a centered crop.
            while min(*pil_image.size) >= 2 * image_size:
                pil_image = pil_image.resize(
                    tuple(x // 2 for x in pil_image.size), resample=Image.BOX
                )
            scale = image_size / min(*pil_image.size)
            pil_image = pil_image.resize(
                tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
            )
            arr = np.array(pil_image)
            crop_y = (arr.shape[0] - image_size) // 2
            crop_x = (arr.shape[1] - image_size) // 2
            return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])

        return center_crop_arr(image, self.size)


# Glob patterns used to discover images under the dataset root.
IMG_EXTENSIONS = (
    "*.png",
    "*.JPEG",
    "*.jpeg",
    "*.jpg",
)


def test_collate(batch):
    return torch.stack(batch)


class ImageDataset(Dataset):
    def __init__(self, root, image_size=(224, 224)):
        self.root = pathlib.Path(root)
        images = []
        for ext in IMG_EXTENSIONS:
            images.extend(self.root.rglob(ext))
        random.shuffle(images)
        self.images = [str(x) for x in images]
        self.transform = tvtf.Compose(
            [
                CenterCrop(image_size[0]),
                tvtf.ToTensor(),
                # Store as uint8 to save memory; expand single-channel
                # (grayscale) images to 3 channels.
                tvtf.Lambda(lambda x: (x * 255).to(torch.uint8)),
                tvtf.Lambda(lambda x: x.expand(3, -1, -1)),
            ]
        )
        self.size = image_size

    def __getitem__(self, idx):
        try:
            image = Image.open(self.images[idx])
            image = self.transform(image)
        except Exception as e:
            # A corrupt file yields a black placeholder instead of killing the run.
            print(f"failed to load {self.images[idx]}: {e}")
            image = torch.zeros(3, self.size[0], self.size[1], dtype=torch.uint8)
        return image

    def __len__(self):
        return len(self.images)
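
A usage sketch pairing ImageDataset with test_collate so batches stay uint8 (the sample directory is an assumption):

from torch.utils.data import DataLoader

dataset = ImageDataset("./samples", image_size=(224, 224))  # assumed directory of images
loader = DataLoader(dataset, batch_size=50, collate_fn=test_collate)
batch = next(iter(loader))  # uint8 tensor of shape (50, 3, 224, 224)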

41
src/data/dataset/randn.py Normal file

@@ -0,0 +1,41 @@
import random

import torch
from torch.utils.data import Dataset


class RandomNDataset(Dataset):
    """Yields Gaussian latents keyed by (class label, seed), e.g. for sampling runs."""

    def __init__(self, latent_shape=(4, 64, 64), num_classes=1000,
                 selected_classes: list = None, seeds=None, max_num_instances=50000):
        self.selected_classes = selected_classes
        if selected_classes is not None:
            num_classes = len(selected_classes)
            max_num_instances = 10 * num_classes
        self.num_classes = num_classes
        self.seeds = seeds
        if seeds is not None:
            self.max_num_instances = len(seeds) * num_classes
            self.num_seeds = len(seeds)
        else:
            # Round up so every class gets the same number of seeds.
            self.num_seeds = (max_num_instances + num_classes - 1) // num_classes
            self.max_num_instances = self.num_seeds * num_classes
        self.latent_shape = latent_shape

    def __getitem__(self, idx):
        label = idx // self.num_seeds
        if self.selected_classes:
            label = self.selected_classes[label]
        if self.seeds is not None:
            seed = self.seeds[idx % self.num_seeds]
        else:
            seed = random.randint(0, 1 << 31)
        filename = f"{label}_{seed}.png"
        # A seeded generator makes the latent deterministic for a given (label, seed).
        generator = torch.Generator().manual_seed(seed)
        latent = torch.randn(self.latent_shape, generator=generator, dtype=torch.float32)
        return latent, label, filename

    def __len__(self):
        return self.max_num_instances
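
A sketch of the deterministic path: with fixed seeds, each (label, seed) pair maps to the same latent on every run:

dataset = RandomNDataset(latent_shape=(4, 32, 32), num_classes=1000, seeds=[0, 1, 2, 3, 4])
latent, label, filename = dataset[0]  # label 0, seed 0, filename "0_0.png"
print(len(dataset))  # 5 seeds * 1000 classes = 5000 instances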

145
src/data/var_training.py Normal file

@@ -0,0 +1,145 @@
import copy
import random
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import torch
import torchvision.transforms.functional as tvtf
from PIL import Image

from src.diffusion.base.training import *
from src.models.vae import uint82fp


def center_crop_arr(pil_image, width, height):
    """
    Cropping implementation adapted from ADM; note the crop offset here is
    randomized rather than centered.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    while pil_image.size[0] >= 2 * width and pil_image.size[1] >= 2 * height:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )
    scale = max(width / pil_image.size[0], height / pil_image.size[1])
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )
    arr = np.array(pil_image)
    crop_y = random.randint(0, (arr.shape[0] - height))
    crop_x = random.randint(0, (arr.shape[1] - width))
    return Image.fromarray(arr[crop_y: crop_y + height, crop_x: crop_x + width])


def process_fn(width, height, data, hflip=0.5):
    # Runs in a worker process: flip, crop, and convert to CHW uint8.
    image, label = data
    if random.uniform(0, 1) > hflip:  # horizontal flip
        image = tvtf.hflip(image)
    image = center_crop_arr(image, width, height)  # crop to the bucket's size
    image = np.array(image).transpose(2, 0, 1)
    return image, label


class VARCandidate:
    """One aspect-ratio bucket holding pending (future) samples."""

    def __init__(self, aspect_ratio, width, height, buffer, max_buffer_size=1024):
        self.aspect_ratio = aspect_ratio
        self.width = int(width)
        self.height = int(height)
        self.buffer = buffer
        self.max_buffer_size = max_buffer_size

    def add_sample(self, data):
        self.buffer.append(data)
        # Drop the oldest entries once the buffer overflows.
        self.buffer = self.buffer[-self.max_buffer_size:]

    def ready(self, batch_size):
        return len(self.buffer) >= batch_size

    def get_batch(self, batch_size):
        batch = self.buffer[:batch_size]
        self.buffer = self.buffer[batch_size:]
        # .result() blocks until the worker process has finished each sample.
        batch = [copy.deepcopy(b.result()) for b in batch]
        x, y = zip(*batch)
        x = torch.stack([torch.from_numpy(im).cuda() for im in x], dim=0)
        x = list(map(uint82fp, x))
        return x, y


class VARTransformEngine:
    """Buckets incoming samples by aspect ratio and crops them asynchronously
    in a process pool, so every returned batch shares a single resolution."""

    def __init__(self,
                 base_image_size,
                 num_aspect_ratios,
                 min_aspect_ratio,
                 max_aspect_ratio,
                 num_workers=8,
                 ):
        self.base_image_size = base_image_size
        self.num_aspect_ratios = num_aspect_ratios
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        self.aspect_ratios = np.linspace(self.min_aspect_ratio, self.max_aspect_ratio, self.num_aspect_ratios)
        self.aspect_ratios = self.aspect_ratios.tolist()
        self.candidates_pool = []
        for i in range(self.num_aspect_ratios):
            # Keep the pixel count roughly constant (w * h ~ base^2) and
            # round both sides down to a multiple of 16.
            candidate = VARCandidate(
                aspect_ratio=self.aspect_ratios[i],
                width=int(self.base_image_size * self.aspect_ratios[i] ** 0.5 // 16 * 16),
                height=int(self.base_image_size * self.aspect_ratios[i] ** -0.5 // 16 * 16),
                buffer=[],
                max_buffer_size=1024,
            )
            self.candidates_pool.append(candidate)
        self.default_candidate = VARCandidate(
            aspect_ratio=1.0,
            width=self.base_image_size,
            height=self.base_image_size,
            buffer=[],
            max_buffer_size=1024,
        )
        self.executor_pool = ProcessPoolExecutor(max_workers=num_workers)
        # For the first batches, also feed the square bucket so the fallback
        # path below always has samples to draw from.
        self._prefill_count = 100

    def find_candidate(self, data):
        # Pick the bucket whose aspect ratio is closest to the image's.
        image = data[0]
        aspect_ratio = image.size[0] / image.size[1]
        min_distance = float("inf")
        min_candidate = None
        for candidate in self.candidates_pool:
            dis = abs(aspect_ratio - candidate.aspect_ratio)
            if dis < min_distance:
                min_distance = dis
                min_candidate = candidate
        return min_candidate

    def __call__(self, batch_data):
        self._prefill_count -= 1
        if isinstance(batch_data[0], torch.Tensor):
            batch_data[0] = batch_data[0].unbind(0)
        batch_data = list(zip(*batch_data))
        for data in batch_data:
            candidate = self.find_candidate(data)
            future = self.executor_pool.submit(process_fn, candidate.width, candidate.height, data)
            candidate.add_sample(future)
            if self._prefill_count >= 0:
                future = self.executor_pool.submit(process_fn,
                                                   self.default_candidate.width,
                                                   self.default_candidate.height,
                                                   data)
                self.default_candidate.add_sample(future)
        batch_size = len(batch_data)
        # Shuffle so no single aspect ratio is systematically preferred.
        random.shuffle(self.candidates_pool)
        for candidate in self.candidates_pool:
            if candidate.ready(batch_size=batch_size):
                return candidate.get_batch(batch_size=batch_size)
        # Fallback: no bucket is full yet, so crop this batch to the default
        # square resolution and return it.
        for data in batch_data:
            future = self.executor_pool.submit(process_fn,
                                               self.default_candidate.width,
                                               self.default_candidate.height,
                                               data)
            self.default_candidate.add_sample(future)
        return self.default_candidate.get_batch(batch_size=batch_size)
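
A sketch of driving the engine by hand; the image paths are hypothetical, and a CUDA device is assumed since get_batch moves tensors to the GPU:

from PIL import Image

engine = VARTransformEngine(base_image_size=256, num_aspect_ratios=9,
                            min_aspect_ratio=0.5, max_aspect_ratio=2.0)
images = [Image.open(p).convert("RGB") for p in paths]  # paths: assumed list of image files
labels = list(range(len(images)))
x, y = engine([images, labels])  # one shared resolution per batch, float tensors on the GPU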