submit code
117  tools/cache_imlatent3.py  Normal file
@@ -0,0 +1,117 @@
from concurrent.futures import ThreadPoolExecutor
import os

import numpy as np
import torch
import torchvision.transforms as tvtf
from diffusers import AutoencoderKL
from PIL import Image
from torchvision.datasets import ImageFolder

IMG_EXTENSIONS = (
    "*.png",
    "*.JPEG",
    "*.jpeg",
    "*.jpg"
)


class NewImageFolder(ImageFolder):
    """ImageFolder that also returns the sample path, so each latent can be cached next to its image."""
    def __getitem__(self, item):
        path, target = self.samples[item]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, path


class CenterCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        def center_crop_arr(pil_image, image_size):
            """
            Center cropping implementation from ADM.
            https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
            """
            while min(*pil_image.size) >= 2 * image_size:
                pil_image = pil_image.resize(
                    tuple(x // 2 for x in pil_image.size), resample=Image.BOX
                )

            scale = image_size / min(*pil_image.size)
            pil_image = pil_image.resize(
                tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
            )

            arr = np.array(pil_image)
            crop_y = (arr.shape[0] - image_size) // 2
            crop_x = (arr.shape[1] - image_size) // 2
            return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])

        return center_crop_arr(image, self.size)


def save(obj, path):
    # exist_ok avoids a race when several writer threads create the same class directory.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(obj, path)


if __name__ == "__main__":
    writer_pool = ThreadPoolExecutor(8)
    for split in ['train']:
        train = split == 'train'
        transforms = tvtf.Compose([
            CenterCrop(256),
            # tvtf.RandomHorizontalFlip(p=1),
            tvtf.ToTensor(),
            tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms)
        # dataset = ImageNet(root='/tmp', split="train", transform=transforms)
        B = 256
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=False, prefetch_factor=32, num_workers=16)
        vae = AutoencoderKL.from_pretrained("/mnt/bn/wangshuai6/models/sd-vae-ft-ema")  # .to('cuda')

        from accelerate import Accelerator

        accelerator = Accelerator()

        vae, dataloader = accelerator.prepare(vae, dataloader)
        rank = accelerator.process_index
        with torch.no_grad():
            for i, (image, label, path_list) in enumerate(dataloader):
                # if i >= 128: break
                new_path_list = []
                for p in path_list:
                    p = p + ".pt"
                    p = p.replace("/mnt/bn/wangshuai6/data/ImageNet/train",
                                  "/mnt/bn/wangshuai6/data/ImageNet/train_256latent")
                    new_path_list.append(p)

                image = image.to("cuda")
                # Cache the VAE posterior parameters rather than a single latent sample.
                distribution = vae.module.encode(image).latent_dist
                mean = distribution.mean
                logvar = distribution.logvar
                # Iterate over the actual batch: the last batch may be smaller than B.
                for j in range(len(path_list)):
                    out = dict(
                        mean=mean[j].cpu(),
                        logvar=logvar[j].cpu(),
                    )
                    writer_pool.submit(save, out, new_path_list[j])
    writer_pool.shutdown(wait=True)
    accelerator.wait_for_everyone()
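A note for consumers of this cache (not part of the commit): each `.pt` file stores the VAE posterior parameters, not a sampled latent, so a training loader has to reparameterize. A minimal sketch, assuming the dict layout written by `save` above; the 0.18215 latent scaling factor is the usual SD-VAE convention and is an assumption here:

import torch

def load_cached_latent(path):
    d = torch.load(path)                           # {'mean': ..., 'logvar': ...} as written above
    std = torch.exp(0.5 * d['logvar'])
    z = d['mean'] + std * torch.randn_like(std)    # reparameterized draw from the posterior
    return z * 0.18215                             # assumed SD-VAE latent scaling factor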
123  tools/cache_imlatent4.py  Normal file
@@ -0,0 +1,123 @@
from concurrent.futures import ThreadPoolExecutor
import os

import numpy as np
import torch
import torchvision.transforms as tvtf
from diffusers import AutoencoderKL
from PIL import Image
from torchvision.datasets import ImageFolder

IMG_EXTENSIONS = (
    "*.png",
    "*.JPEG",
    "*.jpeg",
    "*.jpg"
)


class NewImageFolder(ImageFolder):
    def __getitem__(self, item):
        path, target = self.samples[item]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, path


class CenterCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        def center_crop_arr(pil_image, image_size):
            """
            Center cropping implementation from ADM.
            https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
            """
            while min(*pil_image.size) >= 2 * image_size:
                pil_image = pil_image.resize(
                    tuple(x // 2 for x in pil_image.size), resample=Image.BOX
                )

            scale = image_size / min(*pil_image.size)
            pil_image = pil_image.resize(
                tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
            )

            arr = np.array(pil_image)
            crop_y = (arr.shape[0] - image_size) // 2
            crop_x = (arr.shape[1] - image_size) // 2
            return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])

        return center_crop_arr(image, self.size)


def save(obj, path):
    # exist_ok avoids a race when several writer threads create the same class directory.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(obj, path)


if __name__ == "__main__":
    writer_pool = ThreadPoolExecutor(8)
    for split in ['train']:
        train = split == 'train'
        transforms = tvtf.Compose([
            CenterCrop(512),
            # tvtf.RandomHorizontalFlip(p=1),
            tvtf.ToTensor(),
            tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms)
        B = 8
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=False, prefetch_factor=16, num_workers=16)
        vae = AutoencoderKL.from_pretrained("/mnt/bn/wangshuai6/models/sd-vae-ft-ema")  # .to('cuda')
        vae = vae.to(torch.float16)
        from accelerate import Accelerator

        accelerator = Accelerator()

        vae, dataloader = accelerator.prepare(vae, dataloader)
        rank = accelerator.process_index
        with torch.no_grad():
            for i, (image, label, path_list) in enumerate(dataloader):
                print(i / len(dataloader))
                flag = False
                new_path_list = []
                for p in path_list:
                    p = p + ".pt"
                    p = p.replace("/mnt/bn/wangshuai6/data/ImageNet/train",
                                  "/mnt/bn/wangshuai6/data/ImageNet/train_512_latent")
                    new_path_list.append(p)
                    # Re-encode this batch only if at least one latent is missing,
                    # which makes the script resumable after an interruption.
                    if not os.path.exists(p):
                        print(p)
                        flag = True

                if flag:
                    image = image.to("cuda")
                    image = image.to(torch.float16)
                    distribution = vae.module.encode(image).latent_dist
                    mean = distribution.mean
                    logvar = distribution.logvar

                    for j in range(len(path_list)):
                        out = dict(
                            mean=mean[j].cpu(),
                            logvar=logvar[j].cpu(),
                        )
                        writer_pool.submit(save, out, new_path_list[j])
    writer_pool.shutdown(wait=True)
    accelerator.wait_for_everyone()
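Relative to cache_imlatent3.py, this variant center-crops at 512 rather than 256, runs the VAE in float16, and re-encodes a batch only when at least one of its `.pt` targets is missing, so an interrupted caching run can be resumed by relaunching the same command.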
43  tools/cat_images.py  Normal file
@@ -0,0 +1,43 @@
import cv2
import numpy as np
import os
import pathlib
import argparse


def group_images(path_list):
    path_list = sorted(path_list)  # sorted() returns a new list; the bare call was a no-op
    class_id_dict = {}
    for path in path_list:
        class_id = str(path.name).split('_')[0]
        if class_id not in class_id_dict:
            class_id_dict[class_id] = []
        class_id_dict[class_id].append(path)
    return class_id_dict


def cat_images(path_list):
    imgs = []
    for path in path_list:
        img = cv2.imread(str(path))
        # NB: destructive -- each source image is deleted once it has been read.
        os.remove(path)
        imgs.append(img)
    row_cat_images = []
    row_length = int(len(imgs) ** 0.5)
    for i in range(len(imgs) // row_length):
        row_cat_images.append(np.concatenate(imgs[i * row_length:(i + 1) * row_length], axis=1))
    cat_image = np.concatenate(row_cat_images, axis=0)
    return cat_image


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--src_dir', type=str, default=None)

    args = parser.parse_args()
    src_dir = args.src_dir
    path_list = list(pathlib.Path(src_dir).glob('*.png'))
    class_id_dict = group_images(path_list)
    for class_id, path_list in class_id_dict.items():
        cat_image = cat_images(path_list)
        cat_path = os.path.join(src_dir, f'cat_{class_id}.jpg')
        # cat_path = "cat_{}.png".format(class_id)
        cv2.imwrite(cat_path, cat_image)
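Typical invocation (the directory name is a hypothetical example): the script groups `*.png` files by the class-id prefix before the first underscore, stitches each group into a roughly square grid, and writes `cat_<class_id>.jpg` back into the same directory; note that it deletes the source PNGs as it reads them.

python3 tools/cat_images.py --src_dir output/samples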
353  tools/classifer_training.py  Normal file
@@ -0,0 +1,353 @@
import torch
import torch.nn as nn
import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
import copy

# NORMALIZE_DATA = dict(
#     dinov2_vits14a = dict(
#         mean=[-0.28,0.72,-0.64,-2.31,0.54,-0.29,-1.09,0.83,0.86,1.11,0.34,0.29,-0.32,1.02,0.58,-1.27,-1.19,-0.89,0.79,
#               -0.58,0.23,-0.19,1.31,-0.34,0.02,-0.18,-0.64,0.04,1.63,-0.58,-0.89,0.09,1.09,1.12,0.32,-0.41,0.04,0.49,
#               0.11,1.97,1.06,0.05,-1.15,0.30,0.58,-0.14,-0.26,1.32,2.04,0.50,-0.64,1.18,0.39,0.39,-1.80,0.39,-0.67,
#               0.55,-0.35,-0.41,2.23,-1.16,-0.57,0.58,-1.29,2.07,0.18,0.62,5.72,-0.55,-0.54,0.17,-0.64,-0.78,-0.25,0.12,
#               -0.58,0.36,-2.03,-2.45,-0.22,0.36,-1.02,-0.19,-0.92,-0.26,-0.27,-0.77,-1.47,-0.64,1.76,-0.03,-0.44,1.43,
#               1.14,0.67,1.27,1.54,0.88,-1.42,-0.44,3.32,0.21,1.22,1.17,1.15,-0.53,0.04,0.87,-0.76,0.94,-0.11,0.69,-0.61,
#               0.64,-1.21,-0.82,0.22,-1.12,-0.03,0.68,1.05,0.57,1.13,0.03,0.05,0.42,-0.12,-0.37,-0.76,-0.56,-0.76,-0.23,
#               1.59,0.54,0.63,-0.43,0.38,1.07,0.04,-1.87,-1.92,-0.06,0.87,-0.69,-1.09,-0.30,0.33,-0.28,0.14,2.65,-0.57,
#               -0.04,0.12,-0.49,-1.60,0.39,0.05,0.12,0.66,-0.70,-0.69,0.47,-0.67,-0.59,-1.30,-0.28,-0.52,-0.98,0.67,1.65,
#               0.72,0.55,0.05,-0.27,1.67,0.17,-0.31,-1.73,2.04,0.49,1.08,-0.37,1.75,1.31,1.03,0.65,0.43,-0.19,0.00,-1.13,
#               -0.29,-0.38,0.09,-0.24,1.49,1.01,-0.25,-0.94,0.74,0.24,-1.06,1.58,1.08,0.76,0.64,1.34,-1.09,1.54,-0.27,
#               0.77,0.19,-0.97,0.46,0.20,-0.60,1.48,-2.33,0.43,2.32,1.89,-0.31,-0.48,-0.54,1.52,1.33,0.95,0.12,-0.33,-0.94,
#               -0.67,0.16,1.49,-0.17,-0.42,-0.02,-0.32,0.49,-1.19,0.06,0.19,-0.79,-0.21,-0.38,-0.69,0.52,0.74,0.41,-2.07,
#               -1.01,0.85,-1.41,-0.17,1.11,0.53,1.47,0.66,-0.22,0.93,-0.69,-0.42,0.06,0.11,-0.87,1.58,-0.27,-1.57,-0.56,
#               0.98,-0.50,0.27,0.38,-1.06,-1.77,0.20,-0.33,-0.95,-0.62,-3.44,-0.67,-0.62,-1.20,0.04,-0.02,-1.15,0.56,-0.50,
#               0.83,-1.69,0.01,-0.42,1.15,0.22,1.55,-3.02,1.24,0.28,0.40,0.69,-0.35,2.04,0.33,0.10,-1.09,0.50,0.59,1.29,
#               0.79,0.02,-0.02,-0.49,0.07,0.84,0.55,-0.79,-0.26,-0.06,-0.91,-1.28,0.65,0.30,1.00,-0.09,0.66,-2.51,-0.78,
#               2.94,0.18,0.24,-0.08,0.76,0.06,0.26,-0.74,0.16,-0.72,0.17,0.21,0.98,0.67,0.14,0.05,0.48,0.54,2.05,1.21,-0.03,
#               -0.85,0.38,-0.11,-0.38,-0.86,0.49,-0.87,-0.29,0.23,0.79,0.05,-0.05,-0.07,0.22,0.03,0.85,-0.63,-0.44,0.02,
#               -0.10,-0.01,0.51,-1.84,0.11,1.06,0.00,1.10,-0.56,0.21,0.44,-0.65,-0.97,-1.03,-0.50,-0.67,-0.27,-1.25,
#         ],
#         std=[1.78,3.78,1.92,2.28,1.97,2.82,1.92,2.55,1.87,1.95,1.90,1.83,1.89,2.00,1.85,1.88,1.78,1.81,3.02,1.94,1.92,
#              2.26,2.17,6.16,1.84,2.00,1.85,1.88,3.30,2.14,1.85,2.87,3.01,2.05,1.80,1.84,2.20,2.00,1.97,2.02,1.94,1.90,
#              1.98,2.25,1.97,2.01,2.01,1.95,2.26,2.47,1.95,1.75,1.84,3.02,2.65,2.15,2.01,1.80,2.65,2.37,2.04,2.09,2.03,
#              1.94,1.84,2.19,1.98,1.97,4.52,2.76,2.18,2.59,1.94,2.07,1.96,1.91,2.13,3.16,1.95,2.43,1.84,2.16,2.33,2.21,
#              2.10,1.98,1.90,1.90,1.88,1.89,2.15,1.75,1.83,2.36,2.40,2.42,1.89,2.03,1.89,2.00,1.91,2.88,2.10,2.63,2.04,
#              1.88,1.93,1.74,2.02,1.84,1.96,1.98,1.90,1.80,1.86,2.05,2.21,1.97,1.99,1.77,2.04,2.59,1.85,2.14,1.91,1.68,
#              1.95,1.86,1.99,2.18,2.76,2.03,1.88,2.47,1.92,3.04,2.02,1.74,2.94,1.92,2.12,1.92,2.17,2.15,1.74,2.26,1.71,
#              2.03,2.05,1.85,3.43,1.77,1.96,1.88,1.99,2.14,2.30,2.00,1.90,2.01,1.78,1.72,2.42,1.66,1.86,2.08,2.04,1.88,
#              2.55,2.02,1.83,1.86,1.69,2.06,1.92,2.25,1.74,1.69,2.02,3.88,1.86,2.94,1.82,2.27,2.73,2.05,1.91,1.94,1.86,
#              1.77,2.16,2.16,1.86,1.88,2.08,2.19,1.94,1.90,2.09,2.57,1.75,1.90,2.05,2.13,1.74,1.99,1.83,2.35,4.48,2.44,
#              1.88,2.18,2.46,1.84,1.81,2.37,2.45,2.07,1.79,3.65,2.29,2.09,2.09,2.29,1.92,2.34,1.85,2.03,1.72,2.20,2.15,
#              2.04,2.13,2.07,1.82,1.72,2.06,1.87,2.43,1.94,1.93,1.97,1.83,1.96,2.01,1.89,1.73,2.04,2.63,2.10,2.05,2.49,
#              2.10,2.27,1.87,2.16,2.22,2.08,1.87,2.26,1.88,2.28,3.87,1.74,3.71,2.03,2.70,2.11,1.92,2.00,2.04,2.02,1.90,
#              2.61,2.10,2.37,1.96,2.50,1.17,1.95,1.88,2.06,2.22,1.87,1.93,1.88,3.59,1.89,3.66,1.87,1.95,3.13,1.84,2.87,
#              3.96,2.14,2.01,1.89,1.73,1.98,2.42,2.12,2.28,1.92,1.93,2.54,2.06,1.97,2.02,2.19,2.00,2.04,1.75,1.97,1.81,
#              1.93,1.83,2.22,2.52,1.83,1.86,2.16,2.08,2.87,3.21,2.78,2.84,2.85,1.88,1.79,1.95,1.98,1.78,1.78,2.21,1.89,
#              2.57,2.00,2.82,1.90,2.24,2.28,1.91,2.02,2.23,2.62,1.88,2.40,2.40,2.00,1.70,1.82,1.92,1.95,1.99,2.08,1.97,
#              2.12,1.87,3.65,2.26,1.83,1.96,1.83,1.64,2.07,2.04,2.57,1.85,2.21,1.83,1.90,1.97,2.16,2.12,1.80,1.73,1.96,
#              2.62,3.23,2.13,2.29,2.24,2.72
#         ]
#     ),
#     dinov2_vitb14a = dict(
#         mean=[
#             0.23, 0.44, 0.18, -0.26, -0.08, -0.80, -0.22, -0.09, -0.85, 0.44, 0.07, -0.49, 0.39, -0.12, -0.58, -0.82,
#             -0.21, -0.28, -0.40, 0.36, -0.34, 0.08, 0.31, 0.39, -0.22, -1.23, 0.50, 0.81, -0.96, 0.60, -0.45, -0.17,
#             -0.53, 0.08, 0.10, -0.32, -0.22, -0.86, 0.01, 0.19, -0.73, -0.44, -0.57, -0.45, -0.20, -0.34, -0.63, -0.31,
#             -0.80, 0.43, -0.13, 0.18, -0.11, -0.28, -0.15, 0.11, -0.74, -0.01, -0.34, 0.18, 0.37, 0.07, -0.09, -0.42, 0.15,
#             -0.24, 0.68, -0.31, -0.09, -0.62, -0.54, 0.41, -0.42, -0.08, 0.36, -0.14, 0.44, 0.12, 0.49, 0.69, 0.03,
#             -0.24, -0.41, -0.36, -0.60, 0.86, -0.76, 0.54, -0.24, 0.57, -0.40, -0.82, 0.07, 0.05, -0.24, 0.07, 0.54,
#             1.04, -0.29, 0.67, -0.36, -0.79, 0.11, -0.12, -0.22, -0.20, -0.46, 0.17, -0.15, -0.38, -0.11, 0.24, -0.43,
#             -0.91, 0.04, 0.32, 0.27, -0.58, -0.05, 0.50, -0.47, 0.31, -1.30, 0.07, -0.16, 0.77, 1.07, -0.44, -0.48, 0.26
#             , 0.06, -0.76, -0.27, -0.37, -1.43, -0.50, -0.38, -0.03, -0.43, 0.75, -0.01, -0.16, 0.67, 0.40, 0.33, -0.05,
#             -0.94, -0.40, 0.78, 0.29, -0.60, -0.76, 0.08, -0.08, 0.58, -0.91, -1.09, -0.42, -0.42, 0.29, 0.06, -0.19,
#             -0.75, -0.07, 0.48, -0.30, -0.44, 0.02, 0.11, 0.23, -0.76, -0.76, -0.51, 0.78, -0.58, 0.02, 0.17, -0.36,
#             -0.63, 0.48, 0.09, -0.32, -0.48, -0.09, 0.09, -0.36, 0.11, -0.17, 0.11, -0.80, -0.34, -0.52, 0.10, -0.00, 0.00,
#             -0.15, 0.91, -0.48, 0.64, -0.38, 0.28, 0.56, 0.04, -0.30, 0.14, -0.30, -0.82, 0.47, 0.57, -1.00, -0.14,
#             0.00, 0.10, 0.01, 0.57, -0.09, -3.56, -0.22, -0.24, -0.13, 0.36, 0.30, 0.20, 0.09, 0.08, 0.66, 0.62, 0.44,
#             0.38, 0.46, -0.27, 0.21, 0.07, -0.57, 0.93, 0.39, 0.06, -0.47, 0.34, 0.44, -0.00, -0.52, -0.35, 0.23, -0.24,
#             -0.01, -0.15, 0.11, 0.53, -0.23, 0.28, -0.22, 0.57, -0.07, 0.49, 0.74, 0.85, -0.31, -0.44, 0.22, -0.02, 0.25,
#             -0.01, -0.47, -0.23, 0.03, 0.48, -0.19, 1.55, -0.05, 0.24, 0.26, -0.25, 0.38, -0.44, -0.51, 0.34, -0.12,
#             -0.76, -0.13, 0.57, 0.01, 0.63, 0.40, 0.20, -0.33, -0.31, -0.89, 0.65, -0.46, -0.88, -0.22, 0.34, 0.36,
#             0.95, 0.33, 0.62, -0.49, 0.40, -0.12, -0.07, -0.65, -0.05, -0.58, 0.65, 0.18, -0.81, -0.64, 0.26, -0.10,
#             -0.71, 0.47, -0.05, 0.12, -0.18, 0.77, 0.47, 0.50, 0.48, -0.45, 0.03, 0.16, 0.66, -0.42, -0.05, 0.23, -0.22,
#             -0.46, 0.25, 0.28, 0.18, -0.20, -0.14, -0.93, -0.27, -0.23, 0.15, -0.10, -0.39, -0.20, -0.05, -0.09, 0.28,
#             -0.58, -0.54, 0.09, -0.89, -0.09, 0.03, -0.86, -0.46, -0.70, 0.48, -0.59, -0.56, -0.55, -0.27, -0.50, 0.23,
#             0.63, -1.45, -0.27, -0.04, -0.17, 0.38, -0.02, 0.28, 0.53, -0.81, -0.60, -0.07, 0.22, 0.23, 0.33, -0.62,
#             0.09, -0.19, -0.09, -0.28, -0.13, 0.66, 0.37, -0.17, -0.52, -0.15, -0.60, 0.15, -0.25, 0.42, -0.06, 0.26,
#             0.55, 0.72, 0.48, 0.39, -0.41, -0.76, -0.62, 0.53, 0.18, 0.35, -0.27, -0.20, -0.71, -0.55, 0.16, -0.24, -0.12,
#             0.38, -0.53, -0.43, 0.21, -0.60, -0.24, -0.11, 1.29, 0.02, -0.05, 0.13, 0.48, 0.39, -0.43, -0.05, 0.07,
#             -0.92, 0.89, -0.21, 0.30, -0.44, 0.04, -0.30, 0.11, -0.36, -0.46, -0.20, 0.10, 0.88, -0.15, 0.28, 0.57,
#             -0.10, 0.48, 0.77, -0.12, 0.17, -0.43, -0.20, 0.22, 0.36, -0.49, -0.54, -0.07, 0.67, 0.40, -0.94, -0.62,
#             0.46, 0.75, -0.16, -0.32, 0.30, 0.41, 0.03, -0.31, -0.17, -0.47, 0.53, 0.24, -0.77, 0.32, 0.58, -0.08, -0.71, 0.10,
#             -0.14, 0.39, 0.64, -0.08, -0.38, 0.60, 0.02, 0.61, 0.47, 0.32, 0.35, -0.01, -0.03, -0.15, -0.01, 0.51,
#             -0.52, 0.51, -0.82, 0.58, -0.13, 0.07, 0.46, -2.86, 0.36, -0.27, 0.70, 0.54, 0.31, 0.08, -0.67, 0.58, 0.22,
#             -0.40, 1.05, 0.02, 0.41, -0.66, -0.29, 0.68, 0.40, 0.53, 0.09, -0.31, -0.28, 0.20, 0.01, -0.07, -0.25, 0.36,
#             0.10, -0.79, 0.27, -0.18, 0.18, -1.13, 0.40, -1.07, 0.84, -0.26, -0.09, -0.99, -0.55, 0.20, -0.11, -0.10,
#             0.49, 0.49, -0.08, -0.13, 1.00, 0.48, -0.17, -0.37, -0.31, -0.24, 0.27, -0.11, 0.21, 0.01, -0.17, -0.02,
#             -0.48, 0.25, -0.44, 0.64, 0.53, -1.02, -0.20, -0.13, -0.19, 0.07, -0.17, 0.66, 1.34, -0.40, -1.09, 0.42,
#             0.07, -0.02, 0.50, 0.32, -0.03, 0.30, -0.53, 0.19, 0.01, -0.26, -0.54, -0.04, -0.64, -0.31, 0.85, -0.12,
#             -0.07, -0.08, -0.22, 0.27, -0.50, 0.25, 0.40, -0.60, -0.18, 0.36, 0.66, -0.16, 0.91, -0.61, 0.43, 0.31, 0.23, -0.60,
#             -0.13, -0.07, -0.44, -0.03, 0.25, 0.41, 0.08, 0.89, -1.09, -0.12, -0.12, -0.09, 0.13, 0.01, -0.55, -0.35,
#             -0.44, 0.07, -0.19, 0.35, 0.99, 0.01, 0.11, -0.04, 0.50, -0.10, 0.49, 0.61, 0.23, -0.41, 0.11, -0.36, 0.64,
#             -0.97, 0.68, -0.27, 0.30, 0.85, 0.03, 1.84, -0.15, -0.05, 0.46, -0.41, -0.01, 0.03, -0.32, 0.33, 0.14, 0.31
#             , -0.18, -0.30, 0.07, 0.70, -0.64, -0.59, 0.36, 0.39, -0.33, 0.79, 0.47, 0.44, -0.05, -0.03, -0.29, -1.00,
#             -0.04, 1.25, 0.74, 0.08, -0.53, -0.65, 0.17, -0.57, -0.39, 0.34, -0.12, -0.04, -0.63, 0.27, -0.25, -0.73,
#             -4.08, -0.09, -0.64, 0.38, -0.47, -0.36, -0.34, 0.05, 0.12, 0.37, -0.43, -0.39, 0.11, -0.32, -0.81, -0.05,
#             -0.40, -0.31, 2.64, 0.14, -2.08, 0.70, -0.52, -0.55, -0.40, -0.75, -0.20, 0.42, 0.99, -0.27, 0.35, -0.35,
#             -0.46, 0.48, 0.03, 0.64, 0.56, -0.77, -0.37, 0.02, 0.02, -0.60, -0.47, -0.49, -0.19, 0.29, 0.05, 0.17, 0.05,
#             1.01, 0.05, 0.06, -0.00, -0.64, 0.72, 1.39, -0.45, -0.46, 0.49, -0.58, 0.36, 0.01, -0.14, -0.01, -0.54,
#             -0.46, -1.21, 0.94, -1.31, 0.61, 0.63, -0.53, 0.05, 0.37, -0.18, 1.08, -0.10, -0.80, -0.38, -0.03,
#         ],
#         std=[
#             1.48, 1.58, 1.56, 1.49, 1.57, 1.96, 1.50, 1.34, 1.46, 1.66, 1.63, 1.44, 1.48, 1.53, 1.49, 1.39, 1.45, 1.40,
#             1.47, 1.43, 1.65, 1.69, 1.72, 1.56, 1.50, 3.06, 1.48, 1.58, 1.63, 1.41, 1.78, 1.48, 1.64, 1.41, 1.46, 1.39,
#             1.57, 3.80, 0.16, 1.46, 1.49, 1.51, 1.55, 1.57, 1.43, 1.69, 1.50, 1.53, 1.51, 1.49, 1.42, 1.48, 1.62, 1.56,
#             1.52, 1.39, 1.95, 1.47, 1.33, 1.42, 1.96, 1.46, 1.54, 1.47, 1.41, 1.41, 1.50, 1.53, 1.55, 2.24, 1.52, 1.73,
#             1.54, 1.46, 1.47, 1.55, 1.56, 1.46, 1.40, 1.49, 1.42, 1.54, 1.43, 1.48, 1.41, 1.49, 1.56, 1.59, 1.40, 1.49,
#             1.58, 2.29, 1.58, 1.35, 1.41, 1.45, 1.43, 1.51, 1.48, 1.52, 1.51, 1.52, 1.56, 1.42, 1.44, 1.45, 1.47, 1.42,
#             1.43, 1.49, 1.54, 1.45, 1.66, 1.48, 1.35, 1.53, 1.45, 2.38, 1.38, 1.32, 1.37, 1.49, 2.00, 1.47, 1.45, 1.47,
#             1.63, 1.49, 1.59, 2.58, 1.70, 1.52, 1.40, 1.41, 2.57, 1.61, 1.54, 1.47, 1.62, 1.54, 1.41, 1.45, 1.57, 1.49,
#             1.42, 1.50, 1.67, 1.45, 1.47, 1.43, 1.55, 1.47, 1.53, 1.49, 1.56, 1.58, 2.03, 2.03, 1.57, 1.44, 1.46, 1.05,
#             1.61, 1.39, 1.47, 1.41, 1.43, 1.38, 1.34, 1.42, 1.41, 1.47, 1.79, 1.44, 1.43, 1.38, 1.39, 1.44, 1.38, 1.46,
#             1.45, 1.51, 1.52, 1.49, 5.31, 1.41, 1.45, 1.49, 1.43, 1.94, 1.38, 1.35, 1.56, 1.45, 1.37, 1.47, 1.48, 1.67,
#             1.46, 1.50, 1.40, 1.50, 1.62, 1.48, 1.53, 1.45, 1.51, 1.50, 1.51, 1.52, 1.55, 1.42, 1.84, 1.39, 1.54, 1.42, 4.91,
#             1.42, 1.47, 1.51, 1.57, 1.37, 1.50, 1.39, 2.40, 1.51, 1.59, 1.44, 1.42, 1.59, 1.73, 1.44, 1.53, 1.61, 1.48,
#             1.29, 1.47, 1.39, 1.54, 1.44, 1.43, 1.55, 1.45, 1.31, 1.43, 1.44, 1.41, 1.35, 1.62, 1.49, 1.45, 1.50, 1.76,
#             1.44, 1.80, 1.60, 1.49, 1.43, 1.47, 1.40, 1.40, 1.50, 1.42, 1.51, 1.61, 1.47, 1.45, 1.70, 2.90, 1.51, 1.37,
#             1.50, 1.55, 1.32, 1.42, 1.76, 1.36, 1.41, 1.61, 1.44, 1.44, 1.44, 1.47, 1.48, 1.45, 1.48, 1.56, 1.58, 1.52,
#             1.33, 1.37, 1.64, 1.47, 2.49, 1.51, 1.60, 1.58, 1.45, 1.48, 1.81, 1.38, 1.37, 1.53, 1.72, 1.49, 1.47, 1.49, 1.42,
#             1.44, 1.43, 1.54, 1.59, 1.40, 1.57, 1.45, 1.45, 1.45, 1.55, 1.38, 1.41, 1.46, 2.13, 1.58, 1.46, 1.35, 1.56,
#             1.47, 1.33, 1.53, 1.62, 1.47, 1.44, 1.45, 1.49, 1.82, 1.51, 1.38, 1.54, 1.38, 1.38, 1.40, 1.40, 1.46, 1.43,
#             1.45, 1.42, 1.67, 1.37, 1.50, 1.60, 1.42, 1.46, 1.45, 3.29, 1.45, 1.50, 1.49, 1.38, 1.48, 1.52, 2.45, 1.47,
#             1.50, 1.47, 1.48, 1.44, 1.62, 1.48, 1.52, 1.52, 1.45, 1.51, 1.71, 1.54, 1.59, 1.40, 3.29, 1.45, 1.65, 1.37, 1.54,
#             1.49, 2.38, 1.62, 1.39, 1.38, 1.41, 1.46, 1.57, 1.38, 2.07, 1.54, 1.40, 1.64, 1.46, 1.45, 1.40, 1.57, 1.49,
#             1.39, 1.55, 1.67, 1.54, 1.57, 1.55, 1.41, 1.37, 1.44, 1.40, 1.46, 1.59, 1.56, 1.61, 1.44, 1.35, 1.62, 1.59,
#             1.52, 1.41, 1.44, 1.74, 1.40, 1.40, 1.89, 1.44, 1.46, 1.62, 1.43, 1.42, 1.39, 1.37, 1.43, 1.44, 1.60, 1.52,
#             1.44, 1.41, 1.43, 1.34, 1.54, 1.46, 1.57, 1.53, 1.40, 1.41, 1.36, 1.45, 1.42, 1.37, 1.47, 1.37, 1.40, 1.55,
#             1.48, 1.91, 1.44, 1.54, 1.49, 1.42, 1.48, 1.54, 1.49, 1.39, 1.47, 1.50, 1.43, 1.59, 1.58, 1.78, 1.49, 1.55,
#             1.56, 1.52, 1.56, 1.49, 1.61, 1.51, 1.35, 1.46, 1.69, 1.35, 1.38, 1.48, 1.39, 1.40, 1.35, 1.45, 1.34, 1.38,
#             1.44, 1.46, 1.45, 1.63, 1.52, 1.44, 1.39, 1.46, 1.70, 1.41, 1.49, 1.64, 1.54, 1.33, 1.45, 1.54, 1.49, 1.38,
#             1.42, 1.75, 1.28, 1.52, 1.62, 1.47, 1.66, 1.51, 1.50, 1.51, 1.42, 1.42, 1.60, 1.24, 1.54, 1.42, 1.44, 1.34, 1.53,
#             1.46, 1.46, 1.65, 1.56, 1.52, 2.12, 1.58, 1.44, 1.60, 1.48, 1.51, 1.41, 1.51, 1.68, 2.10, 1.50, 1.39, 1.49,
#             1.43, 1.53, 1.46, 1.53, 1.43, 1.78, 1.32, 1.54, 1.47, 1.55, 1.58, 1.41, 1.57, 1.39, 1.36, 1.74, 1.50, 4.41,
#             1.50, 1.45, 1.34, 1.44, 1.50, 1.50, 1.82, 1.28, 1.76, 1.38, 1.58, 1.56, 3.73, 1.48, 1.53, 1.48, 1.63, 1.43,
#             1.57, 3.43, 1.75, 1.45, 1.45, 1.48, 1.93, 1.47, 1.47, 1.38, 1.42, 1.56, 1.66, 1.39, 1.74, 4.76, 1.53, 1.68,
#             1.55, 1.47, 1.57, 1.53, 1.50, 1.40, 1.57, 1.48, 1.44, 1.36, 1.32, 1.71, 1.44, 1.46, 1.47, 1.54, 1.51, 1.47,
#             1.36, 1.29, 1.44, 1.43, 1.46, 1.40, 1.64, 1.48, 1.42, 1.32, 1.52, 1.49, 3.04, 1.52, 1.38, 1.43, 1.42, 1.43,
#             1.48, 1.49, 1.59, 1.55, 1.62, 2.04, 1.53, 1.42, 1.89, 1.43, 1.41, 3.84, 1.48, 1.51, 1.48, 1.58, 1.54, 1.54,
#             1.54, 1.55, 1.45, 1.49, 1.46, 2.25, 1.43, 1.62, 1.66, 1.80, 1.37, 1.64, 1.49, 1.50, 1.39, 1.41, 1.41, 1.46, 1.44,
#             1.69, 1.47, 1.56, 1.65, 1.51, 1.52, 1.43, 1.53, 1.51, 1.46, 1.62, 1.46, 1.53, 1.68, 1.61, 1.56, 1.42, 4.69,
#             1.31, 1.48, 1.50, 1.82, 1.45, 1.54, 1.56, 1.53, 1.58, 1.59, 1.82, 1.45, 1.54, 1.58, 1.45, 1.40, 1.49, 2.50,
#             1.52, 2.54, 1.51, 1.41, 1.48, 1.46, 1.55, 1.63, 1.42, 1.53, 1.47, 1.47, 1.62, 1.49, 2.09, 1.42, 1.48, 1.33,
#             1.62, 1.41, 1.41, 1.45, 1.50, 1.78, 1.53, 1.56, 1.49, 1.51, 2.31, 1.40, 1.58, 1.39, 1.49, 1.51, 1.55, 1.58,
#             1.93, 1.47, 1.41, 1.47, 1.52, 1.52, 1.39, 1.48, 1.64, 1.49, 1.47, 1.53, 1.50, 3.58, 1.54, 1.70, 1.50, 1.47,
#             1.35, 1.51, 1.70, 1.59, 1.60, 1.56, 1.29
#         ]
#     )
# )

class DINOv2a(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2a, self).__init__()
        self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()
        # self.shifts = nn.Parameter(torch.tensor(NORMALIZE_DATA[weight_path+'a']["mean"]), requires_grad=False)
        # self.scales = nn.Parameter(torch.tensor(NORMALIZE_DATA[weight_path+'a']["std"]), requires_grad=False)

    def fetch_pos(self, h, w):
        # Cache resampled absolute position embeddings per (h, w) patch grid.
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224 * h / 256), int(224 * w / 256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        # feature = (feature - self.shifts.view(1, 1, -1)) / self.scales.view(1, 1, -1)
        feature = feature.transpose(1, 2)
        # Fold the 768-dim tokens into a 192-channel map at twice the patch-grid resolution.
        feature = torch.nn.functional.fold(feature, (patch_num_h * 2, patch_num_w * 2), kernel_size=2, stride=2)
        return feature


from torchvision.datasets import ImageFolder

import os
import numpy as np

from PIL import Image
import torchvision.transforms as tvtf

IMG_EXTENSIONS = (
    "*.png",
    "*.JPEG",
    "*.jpeg",
    "*.jpg"
)


class NewImageFolder(ImageFolder):
    def __getitem__(self, item):
        path, target = self.samples[item]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, path


class CenterCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        def center_crop_arr(pil_image, image_size):
            """
            Center cropping implementation from ADM.
            https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
            """
            while min(*pil_image.size) >= 2 * image_size:
                pil_image = pil_image.resize(
                    tuple(x // 2 for x in pil_image.size), resample=Image.BOX
                )

            scale = image_size / min(*pil_image.size)
            pil_image = pil_image.resize(
                tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
            )

            arr = np.array(pil_image)
            crop_y = (arr.shape[0] - image_size) // 2
            crop_x = (arr.shape[1] - image_size) // 2
            return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])

        return center_crop_arr(image, self.size)


def save(obj, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(obj, path)


import math


class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10):
        # NB: max_period is 10 here, much smaller than the usual 10000.
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        )
        args = t[..., None].float() * freqs[None, ...]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class Classifer(nn.Module):
    def __init__(self, in_channels=192, hidden_size=256, num_classes=1000):
        super(Classifer, self).__init__()
        self.in_channels = in_channels
        # A 2x2 strided conv maps features to per-class maps, followed by global average pooling.
        self.feature_x = nn.Sequential(
            nn.Conv2d(kernel_size=2, in_channels=in_channels, out_channels=num_classes, stride=2, padding=0),
            nn.AdaptiveAvgPool2d(1),
        )

    def forward(self, xt):
        # Only the first in_channels channels of the input are used.
        xt = xt[:, :self.in_channels]
        score = self.feature_x(xt).squeeze(-1).squeeze(-1)
        # score = (feature_xt).clamp(-5, 5)
        score = torch.softmax(score, dim=1)
        return score


if __name__ == "__main__":
    torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
    os.environ["TORCH_HOME"] = torch_hub_dir
    torch.hub.set_dir(torch_hub_dir)

    transforms = tvtf.Compose([
        CenterCrop(256),
        tvtf.ToTensor(),
    ])
    dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms)
    B = 64
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=4, num_workers=4, drop_last=True)
    dino = DINOv2a("dinov2_vitb14")
    from accelerate import Accelerator

    accelerator = Accelerator()

    rank = accelerator.process_index

    classifer = Classifer(in_channels=32)
    classifer.train()
    optimizer = torch.optim.Adam(classifer.parameters(), lr=0.0001)

    dino, dataloader, classifer, optimizer = accelerator.prepare(dino, dataloader, classifer, optimizer)

    # fake_file_dir = "/mnt/bn/wangshuai6/data/gan_guidance"
    # fake_file_names = os.listdir(fake_file_dir)

    for epoch in range(100):
        for i, (true_images, true_labels, path_list) in enumerate(dataloader):
            batch_size = true_images.shape[0]
            true_labels = true_labels.to(accelerator.device)
            true_labels = torch.nn.functional.one_hot(true_labels, num_classes=1000)
            with torch.no_grad():
                # The DINO encoder is frozen; only the classifier head trains.
                true_dino_feature = dino(true_images)
            # t = torch.rand((batch_size, 1, 1, 1), device=accelerator.device)
            # true_x_t = t * true_dino_feature + (1-t) * noise

            true_x_t = true_dino_feature
            true_score = classifer(true_x_t)

            # ind = i % len(fake_file_names)
            # fake_file = torch.load(os.path.join(fake_file_dir, fake_file_names[ind]))
            # import pdb; pdb.set_trace()
            # ind = torch.randint(0, 50, size=(4,))
            # fake_x_t = fake_file['trajs'][ind].view(-1, 196, 32, 32)[:, 4:, :, :]
            # fake_labels = fake_file['condition'].repeat(4)
            # fake_score = classifer(fake_x_t)

            # Cross-entropy written out by hand: -log(softmax) weighted by the one-hot target.
            loss_true = -torch.log(true_score) * true_labels
            loss = loss_true.sum() / batch_size
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            acc = torch.sum(torch.argmax(true_score, dim=1) == torch.argmax(true_labels, dim=1)) / batch_size
            if accelerator.is_main_process:
                print("epoch:{}".format(epoch), "iter:{}".format(i), "loss:{}".format(loss.item()), "acc:{}".format(acc.item()))
        if accelerator.is_main_process:
            torch.save(classifer.state_dict(), f'{epoch}.pth')

    accelerator.wait_for_everyone()
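A side note on the loss (a sketch with hypothetical logits, not part of the commit): the hand-rolled objective above, -log(softmax) weighted by a one-hot target, is ordinary cross-entropy computed in probability space; torch's log-space form gives the same value while avoiding underflow in log(softmax):

import torch
import torch.nn as nn

logits = torch.randn(4, 1000)                       # hypothetical raw scores before softmax
labels = torch.randint(0, 1000, (4,))
ce = nn.CrossEntropyLoss()(logits, labels)          # uses log-softmax internally
manual = (-torch.log(torch.softmax(logits, dim=1))
          * nn.functional.one_hot(labels, 1000)).sum() / 4
assert torch.allclose(ce, manual, atol=1e-5)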
4  tools/debug_env.sh  Normal file
@@ -0,0 +1,4 @@
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/compat
pip3 install -r requirements.txt
git branch --set-upstream-to=origin/master master
git pull
173  tools/dino_scale.py  Normal file
@@ -0,0 +1,173 @@
import copy
import os

import numpy as np
import timm
import torch
import torch.nn as nn
import torchvision.transforms as tvtf
from PIL import Image
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.datasets import ImageFolder
from torchvision.transforms import Normalize


class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        # Cache resampled absolute position embeddings per (h, w) patch grid.
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224 * h / 256), int(224 * w / 256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        return feature


class MAE(nn.Module):
    def __init__(self, model_id, weight_path: str):
        super(MAE, self).__init__()
        if os.path.isdir(weight_path):
            weight_path = os.path.join(weight_path, "pytorch_model.bin")
        self.encoder = timm.create_model(
            model_id,
            checkpoint_path=weight_path,
            num_classes=0,
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.shifts = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
        self.scales = nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def forward(self, x):
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
        return feature


IMG_EXTENSIONS = (
    "*.png",
    "*.JPEG",
    "*.jpeg",
    "*.jpg"
)


class NewImageFolder(ImageFolder):
    def __getitem__(self, item):
        path, target = self.samples[item]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, path


class CenterCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        def center_crop_arr(pil_image, image_size):
            """
            Center cropping implementation from ADM.
            https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
            """
            while min(*pil_image.size) >= 2 * image_size:
                pil_image = pil_image.resize(
                    tuple(x // 2 for x in pil_image.size), resample=Image.BOX
                )

            scale = image_size / min(*pil_image.size)
            pil_image = pil_image.resize(
                tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
            )

            arr = np.array(pil_image)
            crop_y = (arr.shape[0] - image_size) // 2
            crop_x = (arr.shape[1] - image_size) // 2
            return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])

        return center_crop_arr(image, self.size)


def save(obj, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(obj, path)


if __name__ == "__main__":
    torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
    os.environ["TORCH_HOME"] = torch_hub_dir
    torch.hub.set_dir(torch_hub_dir)

    for split in ['train']:
        train = split == 'train'
        transforms = tvtf.Compose([
            CenterCrop(256),
            tvtf.ToTensor(),
            # tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms)
        B = 4096
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0)
        dino = DINOv2("dinov2_vitb14")
        # dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae")
        # dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai")
        from accelerate import Accelerator

        accelerator = Accelerator()
        dino, dataloader = accelerator.prepare(dino, dataloader)
        rank = accelerator.process_index

        acc_mean = torch.zeros((768, ), device=accelerator.device)
        acc_num = 0
        with torch.no_grad():
            for i, (images, labels, path_list) in enumerate(dataloader):
                acc_num += len(images)
                feature = dino(images)
                # Print per-channel std, then mean, of the DINO features, comma-separated.
                stds = torch.std(feature, dim=[0, 2, 3]).tolist()
                for std in stds:
                    print("{:.2f},".format(std), end='')
                print()
                means = torch.mean(feature, dim=[0, 2, 3]).tolist()
                for mean in means:
                    print("{:.2f},".format(mean), end='')
                break

    accelerator.wait_for_everyone()
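Because of the break, the printed statistics come from a single 4096-image batch; they appear to be the source of the commented NORMALIZE_DATA tables in tools/classifer_training.py. A sketch of a full-dataset variant (assuming the `dino` and `dataloader` defined above), accumulating per-channel sums of x and x squared:

n = 0
s = torch.zeros(768, device=accelerator.device)
s2 = torch.zeros(768, device=accelerator.device)
with torch.no_grad():
    for images, labels, path_list in dataloader:
        f = dino(images)                            # (B, 768, h, w)
        x = f.permute(0, 2, 3, 1).reshape(-1, 768)  # one row per spatial token
        n += x.shape[0]
        s += x.sum(0)
        s2 += (x * x).sum(0)
mean = s / n
std = (s2 / n - mean ** 2).sqrt()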
168  tools/dino_scale2.py  Normal file
@@ -0,0 +1,168 @@
import copy
import os

import numpy as np
import timm
import torch
import torch.nn as nn
import torchvision.transforms as tvtf
from PIL import Image
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.datasets import ImageFolder
from torchvision.transforms import Normalize


class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224 * h / 256), int(224 * w / 256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        return feature


class MAE(nn.Module):
    def __init__(self, model_id, weight_path: str):
        super(MAE, self).__init__()
        if os.path.isdir(weight_path):
            weight_path = os.path.join(weight_path, "pytorch_model.bin")
        self.encoder = timm.create_model(
            model_id,
            checkpoint_path=weight_path,
            num_classes=0,
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.shifts = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
        self.scales = nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def forward(self, x):
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
        return feature


IMG_EXTENSIONS = (
    "*.png",
    "*.JPEG",
    "*.jpeg",
    "*.jpg"
)


class NewImageFolder(ImageFolder):
    def __getitem__(self, item):
        path, target = self.samples[item]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, path


class CenterCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        def center_crop_arr(pil_image, image_size):
            """
            Center cropping implementation from ADM.
            https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
            """
            while min(*pil_image.size) >= 2 * image_size:
                pil_image = pil_image.resize(
                    tuple(x // 2 for x in pil_image.size), resample=Image.BOX
                )

            scale = image_size / min(*pil_image.size)
            pil_image = pil_image.resize(
                tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
            )

            arr = np.array(pil_image)
            crop_y = (arr.shape[0] - image_size) // 2
            crop_x = (arr.shape[1] - image_size) // 2
            return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])

        return center_crop_arr(image, self.size)


def save(obj, path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(obj, path)


if __name__ == "__main__":
    torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
    os.environ["TORCH_HOME"] = torch_hub_dir
    torch.hub.set_dir(torch_hub_dir)

    for split in ['train']:
        train = split == 'train'
        transforms = tvtf.Compose([
            CenterCrop(256),
            tvtf.ToTensor(),
            tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms)
        B = 2048
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0)
        dino = DINOv2("dinov2_vitb14")
        # dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae")
        # dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai")
        from accelerate import Accelerator

        accelerator = Accelerator()
        dino, dataloader = accelerator.prepare(dino, dataloader)
        rank = accelerator.process_index

        with torch.no_grad():
            for i, (images, labels, path_list) in enumerate(dataloader):
                feature = dino(images)
                b, c, h, w = feature.shape
                feature = feature.view(b, c, h * w).transpose(1, 2)
                feature = feature.reshape(-1, c)
                # Low-rank PCA over the pooled patch features; V projects 768 -> 64 dims.
                U, S, V = torch.pca_lowrank(feature, 64)
                import pdb; pdb.set_trace()
                feature = torch.matmul(feature, V)
                break
    accelerator.wait_for_everyone()
64  tools/dp.py  Normal file
@@ -0,0 +1,64 @@
import matplotlib.pyplot as plt
print(len([0, 3, 6, 9, 12, 16, 20, 24, 28, 33, 38, 43, 48, 53, 57, 62, 67, 72, 78, 83, 87, 91, 95, 98, 102, 106, 110, 115, 120, 125, 130, 135, 141, 146, 152, 158, 164, 171, 179, 185, 191, 197, 203, 209, 216, 223, 229, 234, 240, 245, 250]))
print(len(list(range(0, 251, 5))))
exit()
plt.plot()
plt.plot()
plt.show()
exit()


import torch

num_steps = 10
num_recompute_timesteps = 4
# Toy similarity matrix; (100 - sim)[i][j] is the error of reusing step j's result at step i.
sim = torch.randint(0, 100, (num_steps, num_steps))
sim[:5, :5] = 100
for i in range(num_steps):
    sim[i, i] = 100

error_map = (100 - sim).tolist()

# init: turn per-step errors into cumulative errors, so that error_map[i][j]
# is the total error of reusing step j's result for all steps up to step i.
for i in range(1, num_steps):
    for j in range(0, i):
        error_map[i][j] = error_map[i-1][j] + error_map[i][j]

# C[k][i]: minimal total error over the first i steps with k recomputed timesteps.
# P[k][i]: index of the k-th (last) recomputed timestep achieving C[k][i], for traceback.
C = [[0, ] * (num_steps + 1) for _ in range(num_recompute_timesteps + 1)]
P = [[-1, ] * (num_steps + 1) for _ in range(num_recompute_timesteps + 1)]

for i in range(1, num_steps + 1):
    C[1][i] = error_map[i-1][0]
    P[1][i] = 0


# dp
for step in range(2, num_recompute_timesteps + 1):
    for i in range(step, num_steps + 1):
        min_value = 99999
        min_index = -1
        for j in range(step - 1, i):
            value = C[step-1][j] + error_map[i-1][j]
            if value < min_value:
                min_value = value
                min_index = j
        C[step][i] = min_value
        P[step][i] = min_index

# trace back
tracback_end_index = num_steps
# min_value = 99999
# for i in range(num_recompute_timesteps-1, num_steps):
#     if C[-1][i] < min_value:
#         min_value = C[-1][i]
#         tracback_end_index = i

timesteps = [tracback_end_index, ]
for i in range(num_recompute_timesteps, 0, -1):
    idx = timesteps[-1]
    timesteps.append(P[i][idx])
timesteps.reverse()
print(timesteps)
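In equation form, with E[i-1][j] the accumulated error of reusing step j's output through step i-1 (built by the prefix-sum "init" loop), the table solves

C[k][i] = min over j in [k-1, i) of ( C[k-1][j] + E[i-1][j] ),

and P[k][i] records the minimizing j. The traceback walks P from i = num_steps down to k = 1, recovering the num_recompute_timesteps recomputed step indices; every other step reuses the nearest earlier recomputed result.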
64  tools/figures/base++.py  Normal file
@@ -0,0 +1,64 @@
import numpy as np
import matplotlib.pyplot as plt

is_data = {
    "4encoder8decoder": [46.01, 61.47, 69.73, 74.26],
    "6encoder6decoder": [53.11, 71.04, 79.83, 83.85],
    "8encoder4decoder": [54.06, 72.96, 80.49, 85.94],
    "10encoder2decoder": [49.25, 67.59, 76.00, 81.12],
}

fid_data = {
    "4encoder8decoder": [31.40, 22.80, 20.13, 18.61],
    "6encoder6decoder": [27.61, 20.42, 17.95, 16.86],
    "8encoder4decoder": [27.12, 19.90, 17.78, 16.32],
    "10encoder2decoder": [29.70, 21.75, 18.95, 17.65],
}

sfid_data = {
    "4encoder8decoder": [6.88, 6.44, 6.56, 6.56],
    "6encoder6decoder": [6.83, 6.50, 6.49, 6.63],
    "8encoder4decoder": [6.76, 6.70, 6.83, 6.63],
    "10encoder2decoder": [6.81, 6.61, 6.53, 6.60],
}

pr_data = {
    "4encoder8decoder": [0.55006, 0.59538, 0.6063, 0.60922],
    "6encoder6decoder": [0.56436, 0.60246, 0.61668, 0.61702],
    "8encoder4decoder": [0.56636, 0.6038, 0.61832, 0.62132],
    "10encoder2decoder": [0.55612, 0.59846, 0.61092, 0.61686],
}

recall_data = {
    "4encoder8decoder": [0.6347, 0.6495, 0.6559, 0.662],
    "6encoder6decoder": [0.6477, 0.6497, 0.6594, 0.6589],
    "8encoder4decoder": [0.6403, 0.653, 0.6505, 0.6618],
    "10encoder2decoder": [0.6342, 0.6492, 0.6536, 0.6569],
}

x = [100, 200, 300, 400]
# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]

colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]

metric_data = {
    "FID50K": fid_data,
    # "SFID": sfid_data,
    "InceptionScore": is_data,
    "Precision": pr_data,
    "Recall": recall_data,
}

for key, data in metric_data.items():
    # plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
    for i, (name, v) in enumerate(data.items()):
        name = name.replace("encoder", "En")
        name = name.replace("decoder", "De")
        plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=10)
    plt.legend(fontsize="14")
    plt.xticks([100, 150, 200, 250, 300, 350, 400])
    plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
    plt.ylabel(key, weight="bold")
    plt.xlabel("Training iterations (K steps)", weight="bold")
    plt.savefig("output/base++_{}.pdf".format(key), bbox_inches='tight')
    plt.close()
57  tools/figures/base.py  Normal file
@@ -0,0 +1,57 @@
import numpy as np
import matplotlib.pyplot as plt


fid_data = {
    "4encoder8decoder": [64.16, 48.04, 39.88, 35.41],
    "6encoder4decoder": [67.71, 48.26, 39.30, 34.91],
    "8encoder4decoder": [69.4, 49.7, 41.56, 36.76],
}

sfid_data = {
    "4encoder8decoder": [7.86, 7.48, 7.15, 7.07],
    "6encoder4decoder": [8.54, 8.11, 7.40, 7.40],
    "8encoder4decoder": [8.42, 8.27, 8.10, 7.69],
}

is_data = {
    "4encoder8decoder": [20.37, 29.41, 36.88, 41.32],
    "6encoder4decoder": [20.04, 30.13, 38.17, 43.84],
    "8encoder4decoder": [19.98, 29.54, 35.93, 42.025],
}

pr_data = {
    "4encoder8decoder": [0.3935, 0.4687, 0.5047, 0.5271],
    "6encoder4decoder": [0.3767, 0.4686, 0.50876, 0.5266],
    "8encoder4decoder": [0.37, 0.45676, 0.49602, 0.5162],
}

recall_data = {
    "4encoder8decoder": [0.5604, 0.5941, 0.6244, 0.6338],
    "6encoder4decoder": [0.5295, 0.595, 0.6287, 0.6378],
    "8encoder4decoder": [0.51, 0.596, 0.6242, 0.6333],
}

x = [100, 200, 300, 400]
colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
metric_data = {
    "FID": fid_data,
    # "SFID": sfid_data,
    "InceptionScore": is_data,
    "Precision": pr_data,
    "Recall": recall_data,
}

for key, data in metric_data.items():
    for i, (name, v) in enumerate(data.items()):
        name = name.replace("encoder", "En")
        name = name.replace("decoder", "De")
        plt.plot(x, v, label=name, color=colors[i], linewidth=3, marker="o")
    plt.legend()
    plt.xticks(x)
    plt.ylabel(key, weight="bold")
    plt.xlabel("Training iterations (K steps)", weight="bold")
    plt.savefig("output/base_{}.pdf".format(key), bbox_inches='tight')
    plt.close()
32  tools/figures/cfg.py  Normal file
@@ -0,0 +1,32 @@
import numpy as np
import matplotlib.pyplot as plt

cfg_data = {
    "[0, 1]": {
        "cfg": [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
        "FID": [9.23, 6.61, 5.08, 4.46, 4.32, 4.52, 4.86, 5.38, 5.97, 6.57, 7.13],
    },
    "[0.2, 1]": {
        "cfg": [1.2, 1.4, 1.6, 1.8, 2.0],
        "FID": [5.87, 4.44, 3.96, 4.01, 4.26]
    },
    "[0.3, 1]": {
        "cfg": [1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4],
        "FID": [4.31, 4.11, 3.98, 3.89, 3.87, 3.88, 3.91, 3.96, 4.03]
    },
    "[0.35, 1]": {
        "cfg": [1.6, 1.8, 2.0, 2.1, 2.2, 2.3, 2.4, 2.6],
        "FID": [4.68, 4.22, 3.98, 3.92, 3.90, 3.88, 3.88, 3.94]
    }
}

colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]

for i, (name, data) in enumerate(cfg_data.items()):
    plt.plot(data["cfg"], data["FID"], label="Interval: " + name, color=colors[i], linewidth=3.5, marker="o")

plt.title("Classifier-free guidance with intervals", weight="bold")
plt.ylabel("FID10K", weight="bold")
plt.xlabel("CFG values", weight="bold")
plt.legend()
plt.savefig("./output/cfg.pdf", bbox_inches="tight")
42  tools/figures/feat_vis.py  Normal file
@@ -0,0 +1,42 @@
import torch

states = torch.load("./output/state.pt", map_location="cpu").to(dtype=torch.float32)
states = states.permute(1, 2, 0, 3)
print(states.shape)
states = states.view(-1, 49, 1152)
states = torch.nn.functional.normalize(states, dim=-1)
sim = torch.bmm(states, states.transpose(1, 2))
mean_sim = torch.mean(sim, dim=0, keepdim=False)

mean_sim = mean_sim.numpy()
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
timesteps = np.linspace(0, 1, 5)
# plt.rc('axes.spines', **{'bottom':False, 'left':False, 'right':False, 'top':False})
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["#7400b8", "#5e60ce", "#4ea8de", "#64dfdf", "#80ffdb"])
plt.imshow(mean_sim, cmap="inferno")
plt.xticks([])
plt.yticks([])
# plt.show()
plt.colorbar()
plt.savefig("./output/mean_sim.png", pad_inches=0, bbox_inches="tight")
# cos_sim = torch.nn.functional.cosine_similarity(states, states)


# for i in range(49):
#     cos_sim = torch.nn.functional.cosine_similarity(states[i], states[i + 1])
#     cos_sim = cos_sim.min()
#     print(cos_sim)
# state = torch.max(states, dim=-1)[1]
# # state = torch.softmax(state, dim=-1)
# state = state.view(-1, 16, 16)
#
# state = state.numpy()
#
# import numpy as np
# import matplotlib.pyplot as plt
# for i in range(0, 49):
#     print(i)
#     plt.imshow(state[i])
#     plt.savefig("./output2/{}.png".format(i))
63  tools/figures/large++.py  Normal file
@@ -0,0 +1,63 @@
import numpy as np
import matplotlib.pyplot as plt

is_data = {
    "10encoder14decoder": [80.48, 104.48, 113.01, 117.29],
    "12encoder12decoder": [85.52, 109.91, 118.18, 121.77],
    "16encoder8decoder": [92.72, 116.30, 124.32, 126.37],
    "20encoder4decoder": [94.95, 117.84, 125.66, 128.30],
}

fid_data = {
    "10encoder14decoder": [15.17, 10.40, 9.32, 8.66],
    "12encoder12decoder": [13.79, 9.67, 8.64, 8.21],
    "16encoder8decoder": [12.41, 8.99, 8.18, 8.03],
    "20encoder4decoder": [12.04, 8.94, 8.03, 7.98],
}

sfid_data = {
    "10encoder14decoder": [5.49, 5.00, 5.09, 5.14],
    "12encoder12decoder": [5.37, 5.01, 5.07, 5.09],
    "16encoder8decoder": [5.43, 5.11, 5.20, 5.31],
    "20encoder4decoder": [5.36, 5.23, 5.21, 5.50],
}

pr_data = {
    "10encoder14decoder": [0.6517, 0.67914, 0.68274, 0.68104],
    "12encoder12decoder": [0.66144, 0.68146, 0.68564, 0.6823],
    "16encoder8decoder": [0.6659, 0.68342, 0.68338, 0.67912],
    "20encoder4decoder": [0.6716, 0.68088, 0.68798, 0.68098],
}

recall_data = {
    "10encoder14decoder": [0.6427, 0.6512, 0.6572, 0.6679],
    "12encoder12decoder": [0.6429, 0.6561, 0.6622, 0.6693],
    "16encoder8decoder": [0.6457, 0.6547, 0.6665, 0.6773],
    "20encoder4decoder": [0.6483, 0.6612, 0.6684, 0.6711],
}

x = [100, 200, 300, 400]
# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]

metric_data = {
    "FID50K": fid_data,
    # "SFID": sfid_data,
    "InceptionScore": is_data,
    "Precision": pr_data,
    "Recall": recall_data,
}

# one figure per metric, one curve per encoder/decoder split
for key, data in metric_data.items():
    # plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
    for i, (name, v) in enumerate(data.items()):
        name = name.replace("encoder", "En")
        name = name.replace("decoder", "De")
        plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=8)
    plt.legend(fontsize="14")
    plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
    plt.xticks([100, 150, 200, 250, 300, 350, 400])
    plt.ylabel(key, weight="bold")
    plt.xlabel("Training iterations (K steps)", weight="bold")
    plt.savefig("output/large++_{}.pdf".format(key), bbox_inches='tight')
    plt.close()
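For quick reference, a small helper (not in the commit) that prints the last-checkpoint FID50K per split, reusing fid_data above:

# hypothetical helper: final FID50K for each encoder/decoder split
for name, fids in fid_data.items():
    print(f"{name}: final FID50K = {fids[-1]:.2f}")
# 20encoder4decoder gives the lowest final FID50K (7.98) in this data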
18
tools/figures/log_snr.py
Normal file
@@ -0,0 +1,18 @@
import numpy as np
import matplotlib.pyplot as plt

t = np.linspace(0.001, 0.999, 100)

def snr(t):
    return np.log((1 - t) / t)

def pds(t):
    return np.clip(((1 - t) / t) ** 2, a_max=0.5, a_min=0.0)

print(pds(t))
plt.figure(figsize=(16, 4))
plt.plot(t, snr(t), color="#ff70a6", linewidth=3, marker="o")
# plt.plot(t, pds(t), color="#ff9770", linewidth=3, marker="o")
plt.ylabel("log-SNR", weight="bold")
plt.xlabel("Timesteps", weight="bold")
plt.xticks([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])
plt.gca().invert_xaxis()
plt.show()
# plt.savefig("output/logsnr.pdf", bbox_inches='tight')
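For reference, under the interpolation x_t = (1 - t) * x0 + t * eps (an assumption about the convention used here), SNR(t) = ((1 - t) / t)^2, so the plotted log((1 - t) / t) equals the log-SNR up to a factor of 2. A quick check:

import numpy as np
t = 0.25
snr = ((1 - t) / t) ** 2  # = 9.0 under the assumed interpolation
assert np.isclose(np.log(snr), 2 * np.log((1 - t) / t))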
BIN
tools/figures/output/base++_FID.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base++_FID50K.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base++_InceptionScore.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base++_Precision.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base++_Recall.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base_FID.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base_InceptionScore.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base_Precision.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base_Recall.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/cfg.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/large++_FID.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/large++_FID50K.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/large++_InceptionScore.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/large++_Precision.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/large++_Recall.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/logsnr.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/mean_sim.png
Normal file
Binary file not shown.
BIN
tools/figures/output/sota.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/timeshift.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/timeshift_fid.pdf
Normal file
Binary file not shown.
95
tools/figures/sota.py
Normal file
@@ -0,0 +1,95 @@
import numpy as np
import matplotlib.pyplot as plt

data = {
    "SiT-XL/2": {
        "size": 675,
        "epochs": 1400,
        "FID": 2.06,
        "color": "#ff99c8",
    },
    "DiT-XL/2": {
        "size": 675,
        "epochs": 1400,
        "FID": 2.27,
        "color": "#fcf6bd",
    },
    "REPA-XL/2": {
        "size": 675,
        "epochs": 800,
        "FID": 1.42,
        "color": "#d0f4de",
    },
    # "MAR-H": {
    #     "size": 973,
    #     "epochs": 800,
    #     "FID": 1.55,
    # },
    "MDTv2": {
        "size": 675,
        "epochs": 920,
        "FID": 1.58,
        "color": "#e4c1f9",
    },
    # "VAVAE+LightningDiT": {
    #     "size": 675,
    #     "epochs": [64, 800],
    #     "FID": [2.11, 1.35],
    # },
    "DDT-XL/2": {
        "size": 675,
        "epochs": [80, 256],
        "FID": [1.52, 1.31],
        "color": "#38a3a5",
    },
    "DDT-L/2": {
        "size": 400,
        "epochs": 80,
        "FID": 1.64,
        "color": "#5bc0be",
    },
}

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
# one bubble per method: x = training epochs, y = FID, marker area = model size
for k, spec in data.items():
    plt.scatter(
        # spec["size"],
        spec["epochs"],
        spec["FID"],
        label=k,
        marker="o",
        s=spec["size"],
        color=spec["color"],
    )
    x = spec["epochs"]
    y = spec["FID"]
    if isinstance(spec["FID"], list):
        # methods reported at several checkpoints: label the last one and
        # connect the checkpoints with a dotted line
        x = spec["epochs"][-1]
        y = spec["FID"][-1]
        plt.plot(
            spec["epochs"],
            spec["FID"],
            color=spec["color"],
            linestyle="dotted",
            linewidth=4,
        )
        # plt.annotate("",
        #              xytext=(spec["epochs"][0], spec["FID"][0]),
        #              xy=(spec["epochs"][1], spec["FID"][1]), arrowprops=dict(arrowstyle="--"), weight="bold")
    plt.text(x + 80, y - 0.05, k, fontsize=13)

plt.text(200, 1.45, "4x Training Acc", fontsize=12, color="#38a3a5", weight="bold")
# plt.arrow(200, 1.42, 520, 0, linewidth=2, fc='black', ec='black', hatch="x", head_width=0.05, head_length=0.05)

plt.annotate(
    "",
    xy=(700, 1.42), xytext=(200, 1.42),
    arrowprops=dict(arrowstyle='<->', color='black', linewidth=2),
)
ax.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
plt.gca().set_xlim(0, 1800)
plt.gca().set_ylim(1.15, 2.5)
plt.xticks([80, 256, 800, 1000, 1200, 1400, 1600])
plt.xlabel("Training Epochs", weight="bold")
plt.ylabel("FID50K on ImageNet256x256", weight="bold")
plt.savefig("output/sota.pdf", bbox_inches="tight")
26
tools/figures/timeshift.py
Normal file
@@ -0,0 +1,26 @@
import scipy
import numpy as np
import matplotlib.pyplot as plt

def timeshift(t, s=1.0):
    # remap uniform timesteps; s > 1 biases the respaced schedule toward smaller t
    return t / (t + (1 - t) * s)

# colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
# plt.rc('axes.spines', **{'bottom':True, 'left':True, 'right':False, 'top':False})
t = np.linspace(0, 1, 100)
shifts = [1.0, 1.5, 2, 3]
for i, shift in enumerate(shifts):
    plt.plot(t, timeshift(t, shift), color=colors[i], label=f"shift {shift}", linewidth=4)

# plt.annotate("", xytext=(0, 0), xy=(0.0, 1.05), arrowprops=dict(arrowstyle="->"), weight="bold")
# plt.annotate("", xytext=(0, 0), xy=(1.05, 0.0), arrowprops=dict(arrowstyle="->"), weight="bold")
# plt.title("Respaced timesteps with various shift values", weight="bold")
# plt.gca().set_xlim(0, 1.0)
# plt.gca().set_ylim(0, 1.0)
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)

plt.ylabel("Respaced Timesteps", weight="bold")
plt.xlabel("Uniform Timesteps", weight="bold")
plt.legend(loc="upper left", fontsize="12")
plt.savefig("output/timeshift.pdf", bbox_inches="tight", pad_inches=0)
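As a sanity check on the mapping, a few concrete values using the timeshift definition above:

# timeshift(t, s) = t / (t + (1 - t) * s); identity at s = 1
for t in [0.25, 0.5, 0.75]:
    print(t, timeshift(t, s=2.0))
# 0.25 -> 0.1429, 0.5 -> 0.3333, 0.75 -> 0.6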
29
tools/figures/timeshift_fid.py
Normal file
@@ -0,0 +1,29 @@
import scipy
import numpy as np
import matplotlib.pyplot as plt

def timeshift(t, s=1.0):
    return t / (t + (1 - t) * s)

# FID50K versus number of inference steps for each timeshift value
data = {
    "shift 1.0": [8.99, 6.36, 5.03, 4.21, 3.6, 3.23, 2.80],
    "shift 1.5": [6.08, 4.26, 3.43, 2.99, 2.73, 2.54, 2.33],
    "shift 2.0": [5.57, 3.81, 3.11, 2.75, 2.54, 2.43, 2.26],
    "shift 3.0": [7.26, 4.48, 3.43, 2.97, 2.72, 2.57, 2.38],
}
# plt.rc('axes.spines', **{'bottom':True, 'left':True, 'right':False, 'top':False})

# colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
steps = [5, 6, 7, 8, 9, 10, 12]
for i, (k, v) in enumerate(data.items()):
    plt.plot(steps, v, color=colors[i], label=k, linewidth=4, marker="o")

# plt.title("FID50K at different numbers of steps for each timeshift", weight="bold")
plt.ylabel("FID50K", weight="bold")
plt.xlabel("Number of inference steps", weight="bold")
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
# plt.legend()
plt.savefig("output/timeshift_fid.pdf", bbox_inches="tight", pad_inches=0)
21
tools/fm_images.py
Normal file
@@ -0,0 +1,21 @@
# recursively delete generated sample images under a workdir
IMG_EXTENSIONS = (
    "*.png",
    "*.JPEG",
    "*.jpeg",
    "*.jpg",
)

PATH = "/mnt/bn/wangshuai6/neural_sampling_workdirs/expbaseline_adam2_timeshift1.5"

import os
import pathlib

PATH = pathlib.Path(PATH)
images = []

# find images
for ext in IMG_EXTENSIONS:
    images.extend(PATH.rglob(ext))

for image in images:
    os.system(f"rm -f {image}")
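A pathlib-only variant (assuming Python 3.8+ for missing_ok) would avoid spawning a shell per file:

# equivalent sketch: delete in-process instead of shelling out to `rm -f`
for image in images:
    image.unlink(missing_ok=True)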
23
tools/mm.py
Normal file
@@ -0,0 +1,23 @@
import torch
import time
import torch.nn as nn
import accelerate


if __name__ == "__main__":
    # keep otherwise-idle GPUs busy: run a frozen linear layer in a loop
    # whenever measured utilization drops low
    model = nn.Linear(512, 512)
    for p in model.parameters():
        p.requires_grad = False
    accelerator = accelerate.Accelerator()
    model = accelerator.prepare_model(model)
    model.to(accelerator.device)
    data = torch.randn(1024, 512).to(accelerator.device)
    while True:
        time.sleep(0.01)
        accelerator.wait_for_everyone()
        if torch.cuda.utilization() < 1.5:
            with torch.no_grad():
                model(data)
        else:
            time.sleep(1)
        # print(f"rank:{accelerator.process_index}->usage:{torch.cuda.utilization()}")
20
tools/sigmoid.py
Normal file
@@ -0,0 +1,20 @@
import numpy as np
import matplotlib.pyplot as plt


def lw(x, b=0):
    # sigmoid loss weight: sigmoid(log-SNR + b), with log-SNR = log(x / (1 - x))
    x = np.clip(x, a_min=0.001, a_max=0.999)
    snr = x / (1 - x)
    logsnr = np.log(snr)
    # print(logsnr)
    # return logsnr
    weight = 1 / (1 + np.exp(-logsnr - b))  # *(1-x)**2
    return weight  # /weight.max()


x = np.arange(0.2, 0.8, 0.001)
print(1 / (x * (1 - x)))
for b in [0, 1, 2, 3]:
    y = lw(x, b)
    plt.plot(x, y, label=f"b={b}")
plt.legend()
plt.show()
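Since logsnr = log(x / (1 - x)) vanishes at x = 0.5, the weight there reduces to sigmoid(b), so b simply shifts the curve. A quick check of lw at its midpoint:

import numpy as np
sigmoid = lambda z: 1 / (1 + np.exp(-z))
for b in [0, 1, 2, 3]:
    # lw(0.5, b) == sigmoid(b): 0.500, 0.731, 0.881, 0.953
    print(b, sigmoid(b))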
173
tools/vae2dino.py
Normal file
@@ -0,0 +1,173 @@
import copy
import os
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
import torchvision.transforms as tvtf
from torchvision.datasets import ImageFolder, ImageNet
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from diffusers import AutoencoderKL
from PIL import Image


class DINOv2(nn.Module):
    def __init__(self, weight_path: str):
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        # cache resampled positional embeddings per (h, w) token-grid size
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (int(224 * h / 256), int(224 * w / 256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        return feature


class MAE(nn.Module):
    def __init__(self, model_id, weight_path: str):
        super(MAE, self).__init__()
        if os.path.isdir(weight_path):
            weight_path = os.path.join(weight_path, "pytorch_model.bin")
        self.encoder = timm.create_model(
            model_id,
            checkpoint_path=weight_path,
            num_classes=0,
        )
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        # channel-wise statistics for standardizing features; identity by default
        self.shifts = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
        self.scales = nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def forward(self, x):
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h // self.patch_size[0], w // self.patch_size[1]
        feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
        return feature


IMG_EXTENSIONS = (
    "*.png",
    "*.JPEG",
    "*.jpeg",
    "*.jpg",
)


class NewImageFolder(ImageFolder):
    def __getitem__(self, item):
        path, target = self.samples[item]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target, path


class CenterCrop:
    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        def center_crop_arr(pil_image, image_size):
            """
            Center cropping implementation from ADM.
            https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
            """
            while min(*pil_image.size) >= 2 * image_size:
                pil_image = pil_image.resize(
                    tuple(x // 2 for x in pil_image.size), resample=Image.BOX
                )

            scale = image_size / min(*pil_image.size)
            pil_image = pil_image.resize(
                tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
            )

            arr = np.array(pil_image)
            crop_y = (arr.shape[0] - image_size) // 2
            crop_x = (arr.shape[1] - image_size) // 2
            return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])

        return center_crop_arr(image, self.size)


def save(obj, path):
    dirname = os.path.dirname(path)
    if not os.path.exists(dirname):
        os.makedirs(dirname)
    torch.save(obj, f'{path}')


if __name__ == "__main__":
    torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
    os.environ["TORCH_HOME"] = torch_hub_dir
    torch.hub.set_dir(torch_hub_dir)

    for split in ['train']:
        train = split == 'train'
        transforms = tvtf.Compose([
            CenterCrop(256),
            tvtf.ToTensor(),
            tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
        B = 4
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0)
        # dino = DINOv2("dinov2_vitb14")
        dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae")
        # dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai")
        from accelerate import Accelerator

        accelerator = Accelerator()
        dino, dataloader = accelerator.prepare(dino, dataloader)
        rank = accelerator.process_index

        acc_mean = torch.zeros((768,), device=accelerator.device)
        acc_num = 0
        with torch.no_grad():
            for i, (images, labels, path_list) in enumerate(dataloader):
                acc_num += len(images)
                feature = dino(images)
                # print per-channel std and mean of the first batch of features
                stds = torch.std(feature, dim=[0, 2, 3]).tolist()
                for std in stds:
                    print("{:.2f},".format(std), end='')
                print()
                means = torch.mean(feature, dim=[0, 2, 3]).tolist()
                for mean in means:
                    print("{:.2f},".format(mean), end='')
                break

        accelerator.wait_for_everyone()
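The per-channel means and stds printed above look like they are meant to seed MAE.shifts and MAE.scales so that forward() standardizes features channel-wise. A hypothetical wiring, with placeholder values, done before accelerator.prepare wraps the module:

# hypothetical: feed measured statistics back into the wrapper so its
# forward() standardizes features channel-wise; values are placeholders
measured_means = [0.01] * 768  # replace with the printed means
measured_stds = [0.85] * 768   # replace with the printed stds
dino.shifts.data = torch.tensor(measured_means)
dino.scales.data = torch.tensor(measured_stds)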
23
tools/vis_timeshift.py
Normal file
@@ -0,0 +1,23 @@
import scipy
import numpy as np
import matplotlib.pyplot as plt

def timeshift(t, s=1.0):
    return t / (t + (1 - t) * s)

def gaussian(t):
    gs = 1 + scipy.special.erf((t - t.mean()) / t.std())
    return gs  # the original computed gs but returned nothing

def rs2(t, s=2.0):
    factor1 = 1.0  # s/(s+(1-s)*t)**2
    factor2 = np.log(t.clip(0.001, 0.999) / (1 - t).clip(0.001, 0.999))
    return factor1 * factor2


t = np.linspace(0, 1, 100)
# plt.plot(t, timeshift(t, 1.0))
respaced_t = timeshift(t, s=5)
deltas = respaced_t[1:] - respaced_t[:-1]
# plt.plot(t, timeshift(t, 1.5))
plt.plot(rs2(t))  # plotted against sample index
plt.show()