submit code

Author: wangshuai6
Date: 2025-04-09 11:01:16 +08:00
parent 4fbcf9bd87
commit 06499f1caa
145 changed files with 14400 additions and 0 deletions

tools/cache_imlatent3.py (new file, 117 lines)

@@ -0,0 +1,117 @@
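# Caches SD-VAE (sd-vae-ft-ema) latents for 256x256 center-cropped ImageNet-train images:
# each image's posterior mean/logvar is saved as a .pt file mirroring the source path under
# train_256latent. Work is sharded across GPUs with Accelerate; writes go through a thread pool.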
from diffusers import AutoencoderKL
import torch
import os
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
import torchvision.transforms as tvtf
from torchvision.datasets import ImageFolder, ImageNet
IMG_EXTENSIONS = (
"*.png",
"*.JPEG",
"*.jpeg",
"*.jpg"
)
class NewImageFolder(ImageFolder):
def __getitem__(self, item):
path, target = self.samples[item]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target, path
import time
class CenterCrop:
def __init__(self, size):
self.size = size
def __call__(self, image):
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
return center_crop_arr(image, self.size)
def save(obj, path):
dirname = os.path.dirname(path)
if not os.path.exists(dirname):
os.makedirs(dirname)
torch.save(obj, f'{path}')
if __name__ == "__main__":
writer_pool = ThreadPoolExecutor(8)
for split in ['train']:
train = split == 'train'
transforms = tvtf.Compose([
CenterCrop(256),
# tvtf.RandomHorizontalFlip(p=1),
tvtf.ToTensor(),
tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
# dataset = ImageNet(root='/tmp', split="train", transform=transforms, )
B = 256
dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=False, prefetch_factor=32, num_workers=16)
vae = AutoencoderKL.from_pretrained("/mnt/bn/wangshuai6/models/sd-vae-ft-ema")#.to('cuda')
from accelerate import Accelerator
accelerator = Accelerator()
vae, dataloader = accelerator.prepare(vae, dataloader)
rank = accelerator.process_index
with torch.no_grad():
for i, (image, label, path_list) in enumerate(dataloader):
# if i >= 128: break
new_path_list = []
for p in path_list:
p = p + ".pt"
p = p.replace("/mnt/bn/wangshuai6/data/ImageNet/train",
"/mnt/bn/wangshuai6/data/ImageNet/train_256latent")
new_path_list.append(p)
image = image.to("cuda")
distribution = vae.module.encode(image).latent_dist
mean = distribution.mean
logvar = distribution.logvar
                for j in range(len(path_list)):  # use the actual batch size; the final batch can be smaller than B
out = dict(
mean=mean[j].cpu(),
logvar=logvar[j].cpu(),
)
writer_pool.submit(save, out, new_path_list[j])
writer_pool.shutdown(wait=True)
accelerator.wait_for_everyone()

tools/cache_imlatent4.py (new file, 123 lines)

@@ -0,0 +1,123 @@
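# Variant of cache_imlatent3.py for 512x512 crops: runs the VAE in float16, writes latents to
# train_512_latent, and skips a batch entirely when all of its output .pt files already exist.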
from diffusers import AutoencoderKL
import torch
import os
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
import torchvision.transforms as tvtf
from torchvision.datasets import ImageFolder
IMG_EXTENSIONS = (
"*.png",
"*.JPEG",
"*.jpeg",
"*.jpg"
)
class NewImageFolder(ImageFolder):
def __getitem__(self, item):
path, target = self.samples[item]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target, path
import time
class CenterCrop:
def __init__(self, size):
self.size = size
def __call__(self, image):
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
return center_crop_arr(image, self.size)
def save(obj, path):
dirname = os.path.dirname(path)
if not os.path.exists(dirname):
os.makedirs(dirname)
torch.save(obj, f'{path}')
if __name__ == "__main__":
writer_pool = ThreadPoolExecutor(8)
for split in ['train']:
train = split == 'train'
transforms = tvtf.Compose([
CenterCrop(512),
# tvtf.RandomHorizontalFlip(p=1),
tvtf.ToTensor(),
tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
B = 8
dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=False, prefetch_factor=16, num_workers=16)
vae = AutoencoderKL.from_pretrained("/mnt/bn/wangshuai6/models/sd-vae-ft-ema")#.to('cuda')
vae = vae.to(torch.float16)
from accelerate import Accelerator
accelerator = Accelerator()
vae, dataloader = accelerator.prepare(vae, dataloader)
rank = accelerator.process_index
with torch.no_grad():
for i, (image, label, path_list) in enumerate(dataloader):
print(i/len(dataloader))
flag = False
new_path_list = []
for p in path_list:
p = p + ".pt"
p = p.replace("/mnt/bn/wangshuai6/data/ImageNet/train",
"/mnt/bn/wangshuai6/data/ImageNet/train_512_latent")
new_path_list.append(p)
if not os.path.exists(p):
print(p)
flag = True
if flag:
image = image.to("cuda")
image = image.to(torch.float16)
distribution = vae.module.encode(image).latent_dist
mean = distribution.mean
logvar = distribution.logvar
for j in range(len(path_list)):
out = dict(
mean=mean[j].cpu(),
logvar=logvar[j].cpu(),
)
writer_pool.submit(save, out, new_path_list[j])
writer_pool.shutdown(wait=True)
accelerator.wait_for_everyone()

tools/cat_images.py (new file, 43 lines)

@@ -0,0 +1,43 @@
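# Tiles per-class sample images into one grid: PNGs in --src_dir are grouped by the class-id
# prefix of their filename, concatenated into a roughly square mosaic saved as cat_<class_id>.jpg,
# and the individual source PNGs are deleted as they are read.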
import cv2
import numpy as np
import os
import pathlib
import argparse
def group_images(path_list):
    path_list = sorted(path_list)
class_id_dict = {}
for path in path_list:
class_id = str(path.name).split('_')[0]
if class_id not in class_id_dict:
class_id_dict[class_id] = []
class_id_dict[class_id].append(path)
return class_id_dict
def cat_images(path_list):
imgs = []
for path in path_list:
img = cv2.imread(str(path))
os.remove(path)
imgs.append(img)
row_cat_images = []
row_length = int(len(imgs)**0.5)
for i in range(len(imgs)//row_length):
row_cat_images.append(np.concatenate(imgs[i*row_length:(i+1)*row_length], axis=1))
cat_image = np.concatenate(row_cat_images, axis=0)
return cat_image
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--src_dir', type=str, default=None)
args = parser.parse_args()
src_dir = args.src_dir
path_list = list(pathlib.Path(src_dir).glob('*.png'))
class_id_dict = group_images(path_list)
for class_id, path_list in class_id_dict.items():
cat_image = cat_images(path_list)
cat_path = os.path.join(src_dir, f'cat_{class_id}.jpg')
# cat_path = "cat_{}.png".format(class_id)
cv2.imwrite(cat_path, cat_image)

tools/classifer_training.py (new file, 353 lines)

@@ -0,0 +1,353 @@
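# Trains a small convolutional classifier head ("Classifer") on frozen DINOv2 (ViT-B/14) patch
# features of center-cropped ImageNet images; only the first `in_channels` feature channels are
# fed to the head. The large commented-out blocks appear to be earlier experiments: per-channel
# feature normalization statistics and a variant that mixed in noised or generated samples.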
import torch
import torch.nn as nn
import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
import copy
# NORMALIZE_DATA = dict(
# dinov2_vits14a = dict(
# mean=[-0.28,0.72,-0.64,-2.31,0.54,-0.29,-1.09,0.83,0.86,1.11,0.34,0.29,-0.32,1.02,0.58,-1.27,-1.19,-0.89,0.79,
# -0.58,0.23,-0.19,1.31,-0.34,0.02,-0.18,-0.64,0.04,1.63,-0.58,-0.89,0.09,1.09,1.12,0.32,-0.41,0.04,0.49,
# 0.11,1.97,1.06,0.05,-1.15,0.30,0.58,-0.14,-0.26,1.32,2.04,0.50,-0.64,1.18,0.39,0.39,-1.80,0.39,-0.67,
# 0.55,-0.35,-0.41,2.23,-1.16,-0.57,0.58,-1.29,2.07,0.18,0.62,5.72,-0.55,-0.54,0.17,-0.64,-0.78,-0.25,0.12,
# -0.58,0.36,-2.03,-2.45,-0.22,0.36,-1.02,-0.19,-0.92,-0.26,-0.27,-0.77,-1.47,-0.64,1.76,-0.03,-0.44,1.43,
# 1.14,0.67,1.27,1.54,0.88,-1.42,-0.44,3.32,0.21,1.22,1.17,1.15,-0.53,0.04,0.87,-0.76,0.94,-0.11,0.69,-0.61,
# 0.64,-1.21,-0.82,0.22,-1.12,-0.03,0.68,1.05,0.57,1.13,0.03,0.05,0.42,-0.12,-0.37,-0.76,-0.56,-0.76,-0.23,
# 1.59,0.54,0.63,-0.43,0.38,1.07,0.04,-1.87,-1.92,-0.06,0.87,-0.69,-1.09,-0.30,0.33,-0.28,0.14,2.65,-0.57,
# -0.04,0.12,-0.49,-1.60,0.39,0.05,0.12,0.66,-0.70,-0.69,0.47,-0.67,-0.59,-1.30,-0.28,-0.52,-0.98,0.67,1.65,
# 0.72,0.55,0.05,-0.27,1.67,0.17,-0.31,-1.73,2.04,0.49,1.08,-0.37,1.75,1.31,1.03,0.65,0.43,-0.19,0.00,-1.13,
# -0.29,-0.38,0.09,-0.24,1.49,1.01,-0.25,-0.94,0.74,0.24,-1.06,1.58,1.08,0.76,0.64,1.34,-1.09,1.54,-0.27,
# 0.77,0.19,-0.97,0.46,0.20,-0.60,1.48,-2.33,0.43,2.32,1.89,-0.31,-0.48,-0.54,1.52,1.33,0.95,0.12,-0.33,-0.94,
# -0.67,0.16,1.49,-0.17,-0.42,-0.02,-0.32,0.49,-1.19,0.06,0.19,-0.79,-0.21,-0.38,-0.69,0.52,0.74,0.41,-2.07,
# -1.01,0.85,-1.41,-0.17,1.11,0.53,1.47,0.66,-0.22,0.93,-0.69,-0.42,0.06,0.11,-0.87,1.58,-0.27,-1.57,-0.56,
# 0.98,-0.50,0.27,0.38,-1.06,-1.77,0.20,-0.33,-0.95,-0.62,-3.44,-0.67,-0.62,-1.20,0.04,-0.02,-1.15,0.56,-0.50,
# 0.83,-1.69,0.01,-0.42,1.15,0.22,1.55,-3.02,1.24,0.28,0.40,0.69,-0.35,2.04,0.33,0.10,-1.09,0.50,0.59,1.29,
# 0.79,0.02,-0.02,-0.49,0.07,0.84,0.55,-0.79,-0.26,-0.06,-0.91,-1.28,0.65,0.30,1.00,-0.09,0.66,-2.51,-0.78,
# 2.94,0.18,0.24,-0.08,0.76,0.06,0.26,-0.74,0.16,-0.72,0.17,0.21,0.98,0.67,0.14,0.05,0.48,0.54,2.05,1.21,-0.03,
# -0.85,0.38,-0.11,-0.38,-0.86,0.49,-0.87,-0.29,0.23,0.79,0.05,-0.05,-0.07,0.22,0.03,0.85,-0.63,-0.44,0.02,
# -0.10,-0.01,0.51,-1.84,0.11,1.06,0.00,1.10,-0.56,0.21,0.44,-0.65,-0.97,-1.03,-0.50,-0.67,-0.27,-1.25,
# ],
# std=[1.78,3.78,1.92,2.28,1.97,2.82,1.92,2.55,1.87,1.95,1.90,1.83,1.89,2.00,1.85,1.88,1.78,1.81,3.02,1.94,1.92,
# 2.26,2.17,6.16,1.84,2.00,1.85,1.88,3.30,2.14,1.85,2.87,3.01,2.05,1.80,1.84,2.20,2.00,1.97,2.02,1.94,1.90,
# 1.98,2.25,1.97,2.01,2.01,1.95,2.26,2.47,1.95,1.75,1.84,3.02,2.65,2.15,2.01,1.80,2.65,2.37,2.04,2.09,2.03,
# 1.94,1.84,2.19,1.98,1.97,4.52,2.76,2.18,2.59,1.94,2.07,1.96,1.91,2.13,3.16,1.95,2.43,1.84,2.16,2.33,2.21,
# 2.10,1.98,1.90,1.90,1.88,1.89,2.15,1.75,1.83,2.36,2.40,2.42,1.89,2.03,1.89,2.00,1.91,2.88,2.10,2.63,2.04,
# 1.88,1.93,1.74,2.02,1.84,1.96,1.98,1.90,1.80,1.86,2.05,2.21,1.97,1.99,1.77,2.04,2.59,1.85,2.14,1.91,1.68,
# 1.95,1.86,1.99,2.18,2.76,2.03,1.88,2.47,1.92,3.04,2.02,1.74,2.94,1.92,2.12,1.92,2.17,2.15,1.74,2.26,1.71,
# 2.03,2.05,1.85,3.43,1.77,1.96,1.88,1.99,2.14,2.30,2.00,1.90,2.01,1.78,1.72,2.42,1.66,1.86,2.08,2.04,1.88,
# 2.55,2.02,1.83,1.86,1.69,2.06,1.92,2.25,1.74,1.69,2.02,3.88,1.86,2.94,1.82,2.27,2.73,2.05,1.91,1.94,1.86,
# 1.77,2.16,2.16,1.86,1.88,2.08,2.19,1.94,1.90,2.09,2.57,1.75,1.90,2.05,2.13,1.74,1.99,1.83,2.35,4.48,2.44,
# 1.88,2.18,2.46,1.84,1.81,2.37,2.45,2.07,1.79,3.65,2.29,2.09,2.09,2.29,1.92,2.34,1.85,2.03,1.72,2.20,2.15,
# 2.04,2.13,2.07,1.82,1.72,2.06,1.87,2.43,1.94,1.93,1.97,1.83,1.96,2.01,1.89,1.73,2.04,2.63,2.10,2.05,2.49,
# 2.10,2.27,1.87,2.16,2.22,2.08,1.87,2.26,1.88,2.28,3.87,1.74,3.71,2.03,2.70,2.11,1.92,2.00,2.04,2.02,1.90,
# 2.61,2.10,2.37,1.96,2.50,1.17,1.95,1.88,2.06,2.22,1.87,1.93,1.88,3.59,1.89,3.66,1.87,1.95,3.13,1.84,2.87,
# 3.96,2.14,2.01,1.89,1.73,1.98,2.42,2.12,2.28,1.92,1.93,2.54,2.06,1.97,2.02,2.19,2.00,2.04,1.75,1.97,1.81,
# 1.93,1.83,2.22,2.52,1.83,1.86,2.16,2.08,2.87,3.21,2.78,2.84,2.85,1.88,1.79,1.95,1.98,1.78,1.78,2.21,1.89,
# 2.57,2.00,2.82,1.90,2.24,2.28,1.91,2.02,2.23,2.62,1.88,2.40,2.40,2.00,1.70,1.82,1.92,1.95,1.99,2.08,1.97,
# 2.12,1.87,3.65,2.26,1.83,1.96,1.83,1.64,2.07,2.04,2.57,1.85,2.21,1.83,1.90,1.97,2.16,2.12,1.80,1.73,1.96,
# 2.62,3.23,2.13,2.29,2.24,2.72
# ]
# ),
# dinov2_vitb14a = dict(
# mean=[
# 0.23, 0.44, 0.18, -0.26, -0.08, -0.80, -0.22, -0.09, -0.85, 0.44, 0.07, -0.49, 0.39, -0.12, -0.58, -0.82,
# -0.21, -0.28, -0.40, 0.36, -0.34, 0.08, 0.31, 0.39, -0.22, -1.23, 0.50, 0.81, -0.96, 0.60, -0.45, -0.17,
# -0.53, 0.08, 0.10, -0.32, -0.22, -0.86, 0.01, 0.19, -0.73, -0.44, -0.57, -0.45, -0.20, -0.34, -0.63, -0.31,
# -0.80, 0.43, -0.13, 0.18, -0.11, -0.28, -0.15, 0.11, -0.74, -0.01, -0.34, 0.18, 0.37, 0.07, -0.09, -0.42, 0.15,
# -0.24, 0.68, -0.31, -0.09, -0.62, -0.54, 0.41, -0.42, -0.08, 0.36, -0.14, 0.44, 0.12, 0.49, 0.69, 0.03,
# -0.24, -0.41, -0.36, -0.60, 0.86, -0.76, 0.54, -0.24, 0.57, -0.40, -0.82, 0.07, 0.05, -0.24, 0.07, 0.54,
# 1.04, -0.29, 0.67, -0.36, -0.79, 0.11, -0.12, -0.22, -0.20, -0.46, 0.17, -0.15, -0.38, -0.11, 0.24, -0.43,
# -0.91, 0.04, 0.32, 0.27, -0.58, -0.05, 0.50, -0.47, 0.31, -1.30, 0.07, -0.16, 0.77, 1.07, -0.44, -0.48, 0.26
# , 0.06, -0.76, -0.27, -0.37, -1.43, -0.50, -0.38, -0.03, -0.43, 0.75, -0.01, -0.16, 0.67, 0.40, 0.33, -0.05,
# -0.94, -0.40, 0.78, 0.29, -0.60, -0.76, 0.08, -0.08, 0.58, -0.91, -1.09, -0.42, -0.42, 0.29, 0.06, -0.19,
# -0.75, -0.07, 0.48, -0.30, -0.44, 0.02, 0.11, 0.23, -0.76, -0.76, -0.51, 0.78, -0.58, 0.02, 0.17, -0.36,
# -0.63, 0.48, 0.09, -0.32, -0.48, -0.09, 0.09, -0.36, 0.11, -0.17, 0.11, -0.80, -0.34, -0.52, 0.10, -0.00, 0.00,
# -0.15, 0.91, -0.48, 0.64, -0.38, 0.28, 0.56, 0.04, -0.30, 0.14, -0.30, -0.82, 0.47, 0.57, -1.00, -0.14,
# 0.00, 0.10, 0.01, 0.57, -0.09, -3.56, -0.22, -0.24, -0.13, 0.36, 0.30, 0.20, 0.09, 0.08, 0.66, 0.62, 0.44,
# 0.38, 0.46, -0.27, 0.21, 0.07, -0.57, 0.93, 0.39, 0.06, -0.47, 0.34, 0.44, -0.00, -0.52, -0.35, 0.23, -0.24,
# -0.01, -0.15, 0.11, 0.53, -0.23, 0.28, -0.22, 0.57, -0.07, 0.49, 0.74, 0.85, -0.31, -0.44, 0.22, -0.02, 0.25,
# -0.01, -0.47, -0.23, 0.03, 0.48, -0.19, 1.55, -0.05, 0.24, 0.26, -0.25, 0.38, -0.44, -0.51, 0.34, -0.12,
# -0.76, -0.13, 0.57, 0.01, 0.63, 0.40, 0.20, -0.33, -0.31, -0.89, 0.65, -0.46, -0.88, -0.22, 0.34, 0.36,
# 0.95, 0.33, 0.62, -0.49, 0.40, -0.12, -0.07, -0.65, -0.05, -0.58, 0.65, 0.18, -0.81, -0.64, 0.26, -0.10,
# -0.71, 0.47, -0.05, 0.12, -0.18, 0.77, 0.47, 0.50, 0.48, -0.45, 0.03, 0.16, 0.66, -0.42, -0.05, 0.23, -0.22,
# -0.46, 0.25, 0.28, 0.18, -0.20, -0.14, -0.93, -0.27, -0.23, 0.15, -0.10, -0.39, -0.20, -0.05, -0.09, 0.28,
# -0.58, -0.54, 0.09, -0.89, -0.09, 0.03, -0.86, -0.46, -0.70, 0.48, -0.59, -0.56, -0.55, -0.27, -0.50, 0.23,
# 0.63, -1.45, -0.27, -0.04, -0.17, 0.38, -0.02, 0.28, 0.53, -0.81, -0.60, -0.07, 0.22, 0.23, 0.33, -0.62,
# 0.09, -0.19, -0.09, -0.28, -0.13, 0.66, 0.37, -0.17, -0.52, -0.15, -0.60, 0.15, -0.25, 0.42, -0.06, 0.26,
# 0.55, 0.72, 0.48, 0.39, -0.41, -0.76, -0.62, 0.53, 0.18, 0.35, -0.27, -0.20, -0.71, -0.55, 0.16, -0.24, -0.12,
# 0.38, -0.53, -0.43, 0.21, -0.60, -0.24, -0.11, 1.29, 0.02, -0.05, 0.13, 0.48, 0.39, -0.43, -0.05, 0.07,
# -0.92, 0.89, -0.21, 0.30, -0.44, 0.04, -0.30, 0.11, -0.36, -0.46, -0.20, 0.10, 0.88, -0.15, 0.28, 0.57,
# -0.10, 0.48, 0.77, -0.12, 0.17, -0.43, -0.20, 0.22, 0.36, -0.49, -0.54, -0.07, 0.67, 0.40, -0.94, -0.62,
# 0.46, 0.75, -0.16, -0.32, 0.30, 0.41, 0.03, -0.31, -0.17, -0.47, 0.53, 0.24, -0.77, 0.32, 0.58, -0.08, -0.71, 0.10,
# -0.14, 0.39, 0.64, -0.08, -0.38, 0.60, 0.02, 0.61, 0.47, 0.32, 0.35, -0.01, -0.03, -0.15, -0.01, 0.51,
# -0.52, 0.51, -0.82, 0.58, -0.13, 0.07, 0.46, -2.86, 0.36, -0.27, 0.70, 0.54, 0.31, 0.08, -0.67, 0.58, 0.22,
# -0.40, 1.05, 0.02, 0.41, -0.66, -0.29, 0.68, 0.40, 0.53, 0.09, -0.31, -0.28, 0.20, 0.01, -0.07, -0.25, 0.36,
# 0.10, -0.79, 0.27, -0.18, 0.18, -1.13, 0.40, -1.07, 0.84, -0.26, -0.09, -0.99, -0.55, 0.20, -0.11, -0.10,
# 0.49, 0.49, -0.08, -0.13, 1.00, 0.48, -0.17, -0.37, -0.31, -0.24, 0.27, -0.11, 0.21, 0.01, -0.17, -0.02,
# -0.48, 0.25, -0.44, 0.64, 0.53, -1.02, -0.20, -0.13, -0.19, 0.07, -0.17, 0.66, 1.34, -0.40, -1.09, 0.42,
# 0.07, -0.02, 0.50, 0.32, -0.03, 0.30, -0.53, 0.19, 0.01, -0.26, -0.54, -0.04, -0.64, -0.31, 0.85, -0.12,
# -0.07, -0.08, -0.22, 0.27, -0.50, 0.25, 0.40, -0.60, -0.18, 0.36, 0.66, -0.16, 0.91, -0.61, 0.43, 0.31, 0.23, -0.60,
# -0.13, -0.07, -0.44, -0.03, 0.25, 0.41, 0.08, 0.89, -1.09, -0.12, -0.12, -0.09, 0.13, 0.01, -0.55, -0.35,
# -0.44, 0.07, -0.19, 0.35, 0.99, 0.01, 0.11, -0.04, 0.50, -0.10, 0.49, 0.61, 0.23, -0.41, 0.11, -0.36, 0.64,
# -0.97, 0.68, -0.27, 0.30, 0.85, 0.03, 1.84, -0.15, -0.05, 0.46, -0.41, -0.01, 0.03, -0.32, 0.33, 0.14, 0.31
# , -0.18, -0.30, 0.07, 0.70, -0.64, -0.59, 0.36, 0.39, -0.33, 0.79, 0.47, 0.44, -0.05, -0.03, -0.29, -1.00,
# -0.04, 1.25, 0.74, 0.08, -0.53, -0.65, 0.17, -0.57, -0.39, 0.34, -0.12, -0.04, -0.63, 0.27, -0.25, -0.73,
# -4.08, -0.09, -0.64, 0.38, -0.47, -0.36, -0.34, 0.05, 0.12, 0.37, -0.43, -0.39, 0.11, -0.32, -0.81, -0.05,
# -0.40, -0.31, 2.64, 0.14, -2.08, 0.70, -0.52, -0.55, -0.40, -0.75, -0.20, 0.42, 0.99, -0.27, 0.35, -0.35,
# -0.46, 0.48, 0.03, 0.64, 0.56, -0.77, -0.37, 0.02, 0.02, -0.60, -0.47, -0.49, -0.19, 0.29, 0.05, 0.17, 0.05,
# 1.01, 0.05, 0.06, -0.00, -0.64, 0.72, 1.39, -0.45, -0.46, 0.49, -0.58, 0.36, 0.01, -0.14, -0.01, -0.54,
# -0.46, -1.21, 0.94, -1.31, 0.61, 0.63, -0.53, 0.05, 0.37, -0.18, 1.08, -0.10, -0.80, -0.38, -0.03,
# ],
# std=[
# 1.48, 1.58, 1.56, 1.49, 1.57, 1.96, 1.50, 1.34, 1.46, 1.66, 1.63, 1.44, 1.48, 1.53, 1.49, 1.39, 1.45, 1.40,
# 1.47, 1.43, 1.65, 1.69, 1.72, 1.56, 1.50, 3.06, 1.48, 1.58, 1.63, 1.41, 1.78, 1.48, 1.64, 1.41, 1.46, 1.39,
# 1.57, 3.80, 0.16, 1.46, 1.49, 1.51, 1.55, 1.57, 1.43, 1.69, 1.50, 1.53, 1.51, 1.49, 1.42, 1.48, 1.62, 1.56,
# 1.52, 1.39, 1.95, 1.47, 1.33, 1.42, 1.96, 1.46, 1.54, 1.47, 1.41, 1.41, 1.50, 1.53, 1.55, 2.24, 1.52, 1.73,
# 1.54, 1.46, 1.47, 1.55, 1.56, 1.46, 1.40, 1.49, 1.42, 1.54, 1.43, 1.48, 1.41, 1.49, 1.56, 1.59, 1.40, 1.49,
# 1.58, 2.29, 1.58, 1.35, 1.41, 1.45, 1.43, 1.51, 1.48, 1.52, 1.51, 1.52, 1.56, 1.42, 1.44, 1.45, 1.47, 1.42,
# 1.43, 1.49, 1.54, 1.45, 1.66, 1.48, 1.35, 1.53, 1.45, 2.38, 1.38, 1.32, 1.37, 1.49, 2.00, 1.47, 1.45, 1.47,
# 1.63, 1.49, 1.59, 2.58, 1.70, 1.52, 1.40, 1.41, 2.57, 1.61, 1.54, 1.47, 1.62, 1.54, 1.41, 1.45, 1.57, 1.49,
# 1.42, 1.50, 1.67, 1.45, 1.47, 1.43, 1.55, 1.47, 1.53, 1.49, 1.56, 1.58, 2.03, 2.03, 1.57, 1.44, 1.46, 1.05,
# 1.61, 1.39, 1.47, 1.41, 1.43, 1.38, 1.34, 1.42, 1.41, 1.47, 1.79, 1.44, 1.43, 1.38, 1.39, 1.44, 1.38, 1.46,
# 1.45, 1.51, 1.52, 1.49, 5.31, 1.41, 1.45, 1.49, 1.43, 1.94, 1.38, 1.35, 1.56, 1.45, 1.37, 1.47, 1.48, 1.67,
# 1.46, 1.50, 1.40, 1.50, 1.62, 1.48, 1.53, 1.45, 1.51, 1.50, 1.51, 1.52, 1.55, 1.42, 1.84, 1.39, 1.54, 1.42, 4.91,
# 1.42, 1.47, 1.51, 1.57, 1.37, 1.50, 1.39, 2.40, 1.51, 1.59, 1.44, 1.42, 1.59, 1.73, 1.44, 1.53, 1.61, 1.48,
# 1.29, 1.47, 1.39, 1.54, 1.44, 1.43, 1.55, 1.45, 1.31, 1.43, 1.44, 1.41, 1.35, 1.62, 1.49, 1.45, 1.50, 1.76,
# 1.44, 1.80, 1.60, 1.49, 1.43, 1.47, 1.40, 1.40, 1.50, 1.42, 1.51, 1.61, 1.47, 1.45, 1.70, 2.90, 1.51, 1.37,
# 1.50, 1.55, 1.32, 1.42, 1.76, 1.36, 1.41, 1.61, 1.44, 1.44, 1.44, 1.47, 1.48, 1.45, 1.48, 1.56, 1.58, 1.52,
# 1.33, 1.37, 1.64, 1.47, 2.49, 1.51, 1.60, 1.58, 1.45, 1.48, 1.81, 1.38, 1.37, 1.53, 1.72, 1.49, 1.47, 1.49, 1.42,
# 1.44, 1.43, 1.54, 1.59, 1.40, 1.57, 1.45, 1.45, 1.45, 1.55, 1.38, 1.41, 1.46, 2.13, 1.58, 1.46, 1.35, 1.56,
# 1.47, 1.33, 1.53, 1.62, 1.47, 1.44, 1.45, 1.49, 1.82, 1.51, 1.38, 1.54, 1.38, 1.38, 1.40, 1.40, 1.46, 1.43,
# 1.45, 1.42, 1.67, 1.37, 1.50, 1.60, 1.42, 1.46, 1.45, 3.29, 1.45, 1.50, 1.49, 1.38, 1.48, 1.52, 2.45, 1.47,
# 1.50, 1.47, 1.48, 1.44, 1.62, 1.48, 1.52, 1.52, 1.45, 1.51, 1.71, 1.54, 1.59, 1.40, 3.29, 1.45, 1.65, 1.37, 1.54,
# 1.49, 2.38, 1.62, 1.39, 1.38, 1.41, 1.46, 1.57, 1.38, 2.07, 1.54, 1.40, 1.64, 1.46, 1.45, 1.40, 1.57, 1.49,
# 1.39, 1.55, 1.67, 1.54, 1.57, 1.55, 1.41, 1.37, 1.44, 1.40, 1.46, 1.59, 1.56, 1.61, 1.44, 1.35, 1.62, 1.59,
# 1.52, 1.41, 1.44, 1.74, 1.40, 1.40, 1.89, 1.44, 1.46, 1.62, 1.43, 1.42, 1.39, 1.37, 1.43, 1.44, 1.60, 1.52,
# 1.44, 1.41, 1.43, 1.34, 1.54, 1.46, 1.57, 1.53, 1.40, 1.41, 1.36, 1.45, 1.42, 1.37, 1.47, 1.37, 1.40, 1.55,
# 1.48, 1.91, 1.44, 1.54, 1.49, 1.42, 1.48, 1.54, 1.49, 1.39, 1.47, 1.50, 1.43, 1.59, 1.58, 1.78, 1.49, 1.55,
# 1.56, 1.52, 1.56, 1.49, 1.61, 1.51, 1.35, 1.46, 1.69, 1.35, 1.38, 1.48, 1.39, 1.40, 1.35, 1.45, 1.34, 1.38,
# 1.44, 1.46, 1.45, 1.63, 1.52, 1.44, 1.39, 1.46, 1.70, 1.41, 1.49, 1.64, 1.54, 1.33, 1.45, 1.54, 1.49, 1.38,
# 1.42, 1.75, 1.28, 1.52, 1.62, 1.47, 1.66, 1.51, 1.50, 1.51, 1.42, 1.42, 1.60, 1.24, 1.54, 1.42, 1.44, 1.34, 1.53,
# 1.46, 1.46, 1.65, 1.56, 1.52, 2.12, 1.58, 1.44, 1.60, 1.48, 1.51, 1.41, 1.51, 1.68, 2.10, 1.50, 1.39, 1.49,
# 1.43, 1.53, 1.46, 1.53, 1.43, 1.78, 1.32, 1.54, 1.47, 1.55, 1.58, 1.41, 1.57, 1.39, 1.36, 1.74, 1.50, 4.41,
# 1.50, 1.45, 1.34, 1.44, 1.50, 1.50, 1.82, 1.28, 1.76, 1.38, 1.58, 1.56, 3.73, 1.48, 1.53, 1.48, 1.63, 1.43,
# 1.57, 3.43, 1.75, 1.45, 1.45, 1.48, 1.93, 1.47, 1.47, 1.38, 1.42, 1.56, 1.66, 1.39, 1.74, 4.76, 1.53, 1.68,
# 1.55, 1.47, 1.57, 1.53, 1.50, 1.40, 1.57, 1.48, 1.44, 1.36, 1.32, 1.71, 1.44, 1.46, 1.47, 1.54, 1.51, 1.47,
# 1.36, 1.29, 1.44, 1.43, 1.46, 1.40, 1.64, 1.48, 1.42, 1.32, 1.52, 1.49, 3.04, 1.52, 1.38, 1.43, 1.42, 1.43,
# 1.48, 1.49, 1.59, 1.55, 1.62, 2.04, 1.53, 1.42, 1.89, 1.43, 1.41, 3.84, 1.48, 1.51, 1.48, 1.58, 1.54, 1.54,
# 1.54, 1.55, 1.45, 1.49, 1.46, 2.25, 1.43, 1.62, 1.66, 1.80, 1.37, 1.64, 1.49, 1.50, 1.39, 1.41, 1.41, 1.46, 1.44,
# 1.69, 1.47, 1.56, 1.65, 1.51, 1.52, 1.43, 1.53, 1.51, 1.46, 1.62, 1.46, 1.53, 1.68, 1.61, 1.56, 1.42, 4.69,
# 1.31, 1.48, 1.50, 1.82, 1.45, 1.54, 1.56, 1.53, 1.58, 1.59, 1.82, 1.45, 1.54, 1.58, 1.45, 1.40, 1.49, 2.50,
# 1.52, 2.54, 1.51, 1.41, 1.48, 1.46, 1.55, 1.63, 1.42, 1.53, 1.47, 1.47, 1.62, 1.49, 2.09, 1.42, 1.48, 1.33,
# 1.62, 1.41, 1.41, 1.45, 1.50, 1.78, 1.53, 1.56, 1.49, 1.51, 2.31, 1.40, 1.58, 1.39, 1.49, 1.51, 1.55, 1.58,
# 1.93, 1.47, 1.41, 1.47, 1.52, 1.52, 1.39, 1.48, 1.64, 1.49, 1.47, 1.53, 1.50, 3.58, 1.54, 1.70, 1.50, 1.47,
# 1.35, 1.51, 1.70, 1.59, 1.60, 1.56, 1.29
# ]
# )
# )
class DINOv2a(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2a, self).__init__()
self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
# self.shifts = nn.Parameter(torch.tensor(NORMALIZE_DATA[weight_path+'a']["mean"]), requires_grad=False)
# self.scales = nn.Parameter(torch.tensor(NORMALIZE_DATA[weight_path+'a']["std"]), requires_grad=False)
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
# feature = (feature - self.shifts.view(1, 1, -1)) / self.scales.view(1, 1, -1)
feature = feature.transpose(1, 2)
feature = torch.nn.functional.fold(feature, (patch_num_h*2, patch_num_w*2), kernel_size=2, stride=2)
return feature
from torchvision.datasets import ImageFolder, ImageNet
import os
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as tvtf
IMG_EXTENSIONS = (
"*.png",
"*.JPEG",
"*.jpeg",
"*.jpg"
)
class NewImageFolder(ImageFolder):
def __getitem__(self, item):
path, target = self.samples[item]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target, path
import time
class CenterCrop:
def __init__(self, size):
self.size = size
def __call__(self, image):
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
return center_crop_arr(image, self.size)
def save(obj, path):
dirname = os.path.dirname(path)
if not os.path.exists(dirname):
os.makedirs(dirname)
torch.save(obj, f'{path}')
import math
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[..., None].float() * freqs[None, ...]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class Classifer(nn.Module):
def __init__(self, in_channels=192, hidden_size=256, num_classes=1000):
super(Classifer, self).__init__()
self.in_channels = in_channels
self.feature_x = nn.Sequential(
nn.Conv2d(kernel_size=2, in_channels=in_channels, out_channels=num_classes, stride=2, padding=0),
nn.AdaptiveAvgPool2d(1),
)
def forward(self, xt):
xt = xt[:, :self.in_channels]
score = self.feature_x(xt).squeeze(-1).squeeze(-1)
# score = (feature_xt).clamp(-5, 5)
score = torch.softmax(score, dim=1)
return score
if __name__ == "__main__":
torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
os.environ["TORCH_HOME"] = torch_hub_dir
torch.hub.set_dir(torch_hub_dir)
transforms = tvtf.Compose([
CenterCrop(256),
tvtf.ToTensor(),
])
dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
B = 64
dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=4, num_workers=4, drop_last=True)
dino = DINOv2a("dinov2_vitb14")
from accelerate import Accelerator
accelerator = Accelerator()
rank = accelerator.process_index
classifer = Classifer(in_channels=32)
classifer.train()
optimizer = torch.optim.Adam(classifer.parameters(), lr=0.0001)
dino, dataloader, classifer, optimizer = accelerator.prepare(dino, dataloader, classifer, optimizer)
# fake_file_dir = "/mnt/bn/wangshuai6/data/gan_guidance"
# fake_file_names = os.listdir(fake_file_dir)
for epoch in range(100):
for i, (true_images, true_labels, path_list) in enumerate(dataloader):
batch_size = true_images.shape[0]
true_labels = true_labels.to(accelerator.device)
true_labels = torch.nn.functional.one_hot(true_labels, num_classes=1000)
with torch.no_grad():
true_dino_feature = dino(true_images)
# t = torch.rand((batch_size, 1, 1, 1), device=accelerator.device)
# true_x_t = t * true_dino_feature + (1-t) * noise
true_x_t = true_dino_feature
true_score = classifer(true_x_t)
# ind = i % len(fake_file_names)
# fake_file = torch.load(os.path.join(fake_file_dir, fake_file_names[ind]))
# import pdb; pdb.set_trace()
# ind = torch.randint(0, 50, size=(4,))
# fake_x_t = fake_file['trajs'][ind].view(-1, 196, 32, 32)[:, 4:, :, :]
# fake_labels = fake_file['condition'].repeat(4)
# fake_score = classifer(fake_x_t)
loss_true = -torch.log(true_score)*true_labels
loss = loss_true.sum()/batch_size
loss.backward()
optimizer.step()
optimizer.zero_grad()
acc = torch.sum(torch.argmax(true_score, dim=1) == torch.argmax(true_labels, dim=1))/batch_size
if accelerator.is_main_process:
print("epoch:{}".format(epoch), "iter:{}".format(i), "loss:{}".format(loss.item()), "acc:{}".format(acc.item()))
if accelerator.is_main_process:
torch.save(classifer.state_dict(), f'{epoch}.pth')
accelerator.wait_for_everyone()

tools/debug_env.sh (new file, 4 lines)

@@ -0,0 +1,4 @@
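# Debug-environment setup: expose the CUDA compat libraries, install Python requirements,
# and point the local master branch at origin/master before pulling.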
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/compat
pip3 install -r requirements.txt
git branch --set-upstream-to=origin/master master
git pull

tools/dino_scale.py (new file, 173 lines)

@@ -0,0 +1,173 @@
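# Estimates per-channel statistics of DINOv2 (ViT-B/14) patch features: one large ImageNet batch
# (B = 4096) is encoded and the per-channel std and mean are printed as comma-separated lists,
# presumably to populate normalization tables like the commented-out ones in classifer_training.py.
# MAE and CLIP encoders are left as commented-out alternatives.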
import torch
import torch.nn as nn
import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
import copy
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
feature = feature.transpose(1, 2)
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
return feature
class MAE(nn.Module):
def __init__(self, model_id, weight_path:str):
super(MAE, self).__init__()
if os.path.isdir(weight_path):
weight_path = os.path.join(weight_path, "pytorch_model.bin")
self.encoder = timm.create_model(
model_id,
checkpoint_path=weight_path,
num_classes=0,
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.shifts = nn.Parameter(torch.tensor([0.0
]), requires_grad=False)
self.scales = nn.Parameter(torch.tensor([1.0
]), requires_grad=False)
def forward(self, x):
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
feature = feature.transpose(1, 2)
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
return feature
from diffusers import AutoencoderKL
from torchvision.datasets import ImageFolder, ImageNet
import cv2
import os
import numpy as np
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from PIL import Image
import torch
import torchvision.transforms as tvtf
IMG_EXTENSIONS = (
"*.png",
"*.JPEG",
"*.jpeg",
"*.jpg"
)
class NewImageFolder(ImageFolder):
def __getitem__(self, item):
path, target = self.samples[item]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target, path
import time
class CenterCrop:
def __init__(self, size):
self.size = size
def __call__(self, image):
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
return center_crop_arr(image, self.size)
def save(obj, path):
dirname = os.path.dirname(path)
if not os.path.exists(dirname):
os.makedirs(dirname)
torch.save(obj, f'{path}')
if __name__ == "__main__":
torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
os.environ["TORCH_HOME"] = torch_hub_dir
torch.hub.set_dir(torch_hub_dir)
for split in ['train']:
train = split == 'train'
transforms = tvtf.Compose([
CenterCrop(256),
tvtf.ToTensor(),
# tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
B = 4096
dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0)
dino = DINOv2("dinov2_vitb14")
# dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae")
# dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai")
from accelerate import Accelerator
accelerator = Accelerator()
dino, dataloader = accelerator.prepare(dino, dataloader)
rank = accelerator.process_index
acc_mean = torch.zeros((768, ), device=accelerator.device)
acc_num = 0
with torch.no_grad():
for i, (images, labels, path_list) in enumerate(dataloader):
acc_num += len(images)
feature = dino(images)
stds = torch.std(feature, dim=[0, 2, 3]).tolist()
for std in stds:
print("{:.2f},".format(std), end='')
print()
means = torch.mean(feature, dim=[0, 2, 3]).tolist()
for mean in means:
print("{:.2f},".format(mean), end='')
break
accelerator.wait_for_everyone()

tools/dino_scale2.py (new file, 168 lines)

@@ -0,0 +1,168 @@
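# Exploratory PCA on DINOv2 patch features: encodes one ImageNet batch (B = 2048), flattens the
# tokens to shape (N, 768), runs torch.pca_lowrank with 64 components, and drops into pdb to
# inspect the projection before exiting after the first batch.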
import torch
import torch.nn as nn
import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
import copy
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
feature = feature.transpose(1, 2)
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
return feature
class MAE(nn.Module):
def __init__(self, model_id, weight_path:str):
super(MAE, self).__init__()
if os.path.isdir(weight_path):
weight_path = os.path.join(weight_path, "pytorch_model.bin")
self.encoder = timm.create_model(
model_id,
checkpoint_path=weight_path,
num_classes=0,
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.shifts = nn.Parameter(torch.tensor([0.0
]), requires_grad=False)
self.scales = nn.Parameter(torch.tensor([1.0
]), requires_grad=False)
def forward(self, x):
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
feature = feature.transpose(1, 2)
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
return feature
from diffusers import AutoencoderKL
from torchvision.datasets import ImageFolder, ImageNet
import cv2
import os
import numpy as np
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from PIL import Image
import torch
import torchvision.transforms as tvtf
IMG_EXTENSIONS = (
"*.png",
"*.JPEG",
"*.jpeg",
"*.jpg"
)
class NewImageFolder(ImageFolder):
def __getitem__(self, item):
path, target = self.samples[item]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target, path
import time
class CenterCrop:
def __init__(self, size):
self.size = size
def __call__(self, image):
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
return center_crop_arr(image, self.size)
def save(obj, path):
dirname = os.path.dirname(path)
if not os.path.exists(dirname):
os.makedirs(dirname)
torch.save(obj, f'{path}')
if __name__ == "__main__":
torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
os.environ["TORCH_HOME"] = torch_hub_dir
torch.hub.set_dir(torch_hub_dir)
for split in ['train']:
train = split == 'train'
transforms = tvtf.Compose([
CenterCrop(256),
tvtf.ToTensor(),
tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
B = 2048
dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0)
dino = DINOv2("dinov2_vitb14")
# dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae")
# dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai")
from accelerate import Accelerator
accelerator = Accelerator()
dino, dataloader = accelerator.prepare(dino, dataloader)
rank = accelerator.process_index
with torch.no_grad():
for i, (images, labels, path_list) in enumerate(dataloader):
feature = dino(images)
b, c, h, w = feature.shape
feature = feature.view(b, c, h*w).transpose(1, 2)
feature = feature.reshape(-1, c)
U, S, V = torch.pca_lowrank(feature, 64, )
import pdb; pdb.set_trace()
feature = torch.matmul(feature, V)
break
accelerator.wait_for_everyone()

tools/dp.py (new file, 64 lines)

@@ -0,0 +1,64 @@
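# Scratch script. The prints and exit() calls at the top short-circuit everything below; the
# remainder is a dynamic-programming search that, given a step-to-step similarity matrix, selects
# the recomputation timesteps that minimize accumulated error and traces the chosen indices back.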
import matplotlib.pyplot as plt
print(len([0, 3, 6, 9, 12, 16, 20, 24, 28, 33, 38, 43, 48, 53, 57, 62, 67, 72, 78, 83, 87, 91, 95, 98, 102, 106, 110, 115, 120, 125, 130, 135, 141, 146, 152, 158, 164, 171, 179, 185, 191, 197, 203, 209, 216, 223, 229, 234, 240, 245, 250]))
print(len(list(range(0, 251, 5))))
exit()
plt.plot()
plt.plot()
plt.show()
exit()
import torch
num_steps = 10
num_recompute_timesteps = 4
sim = torch.randint(0, 100, (num_steps, num_steps))
sim[:5, :5] = 100
for i in range(num_steps):
sim[i, i] = 100
error_map = (100-sim).tolist()
# init
for i in range(1, num_steps):
for j in range(0, i):
error_map[i][j] = error_map[i-1][j] + error_map[i][j]
C = [[0, ] * (num_steps + 1) for _ in range(num_recompute_timesteps+1)]
P = [[-1, ] * (num_steps + 1) for _ in range(num_recompute_timesteps+1)]
for i in range(1, num_steps+1):
C[1][i] = error_map[i-1][0]
P[1][i] = 0
# dp
for step in range(2, num_recompute_timesteps+1):
for i in range(step, num_steps+1):
min_value = 99999
min_index = -1
for j in range(step-1, i):
value = C[step-1][j] + error_map[i-1][j]
if value < min_value:
min_value = value
min_index = j
C[step][i] = min_value
P[step][i] = min_index
# trace back
tracback_end_index = num_steps
# min_value = 99999
# for i in range(num_recompute_timesteps-1, num_steps):
# if C[-1][i] < min_value:
# min_value = C[-1][i]
# tracback_end_index = i
timesteps = [tracback_end_index, ]
for i in range(num_recompute_timesteps, 0, -1):
idx = timesteps[-1]
timesteps.append(P[i][idx])
timesteps.reverse()
print(timesteps)

tools/figures/base++.py (new file, 64 lines)

@@ -0,0 +1,64 @@
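# Plots FID50K, Inception Score, Precision, and Recall against training iterations for several
# encoder/decoder depth splits of the Base++ configuration; one PDF per metric is written to
# output/base++_<metric>.pdf.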
import numpy as np
import matplotlib.pyplot as plt
is_data = {
"4encoder8decoder":[46.01, 61.47, 69.73, 74.26],
"6encoder6decoder":[53.11, 71.04, 79.83, 83.85],
"8encoder4decoder":[54.06, 72.96, 80.49, 85.94],
"10encoder2decoder": [49.25, 67.59, 76.00, 81.12],
}
fid_data = {
"4encoder8decoder":[31.40, 22.80, 20.13, 18.61],
"6encoder6decoder":[27.61, 20.42, 17.95, 16.86],
"8encoder4decoder":[27.12, 19.90, 17.78, 16.32],
"10encoder2decoder": [29.70, 21.75, 18.95, 17.65],
}
sfid_data = {
"4encoder8decoder":[6.88, 6.44, 6.56, 6.56],
"6encoder4decoder":[6.83, 6.50, 6.49, 6.63],
"8encoder4decoder":[6.76, 6.70, 6.83, 6.63],
"10encoder2decoder": [6.81, 6.61, 6.53, 6.60],
}
pr_data = {
"4encoder8decoder":[0.55006, 0.59538, 0.6063, 0.60922],
"6encoder6decoder":[0.56436, 0.60246, 0.61668, 0.61702],
"8encoder4decoder":[0.56636, 0.6038, 0.61832, 0.62132],
"10encoder2decoder": [0.55612, 0.59846, 0.61092, 0.61686],
}
recall_data = {
"4encoder8decoder":[0.6347, 0.6495, 0.6559, 0.662],
"6encoder6decoder":[0.6477, 0.6497, 0.6594, 0.6589],
"8encoder4decoder":[0.6403, 0.653, 0.6505, 0.6618],
"10encoder2decoder": [0.6342, 0.6492, 0.6536, 0.6569],
}
x = [100, 200, 300, 400]
# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
metric_data = {
"FID50K" : fid_data,
# "SFID" : sfid_data,
"InceptionScore" : is_data,
"Precision" : pr_data,
"Recall" : recall_data,
}
for key, data in metric_data.items():
# plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
for i, (name, v) in enumerate(data.items()):
name = name.replace("encoder", "En")
name = name.replace("decoder", "De")
plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=10)
plt.legend(fontsize="14")
plt.xticks([100, 150, 200, 250, 300, 350, 400])
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
plt.ylabel(key, weight="bold")
plt.xlabel("Training iterations(K steps)", weight="bold")
plt.savefig("output/base++_{}.pdf".format(key), bbox_inches='tight',)
plt.close()

tools/figures/base.py (new file, 57 lines)

@@ -0,0 +1,57 @@
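# Same metric-versus-iterations plots as base++.py, for the Base configuration; figures are saved
# as output/base_<metric>.pdf.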
import numpy as np
import matplotlib.pyplot as plt
fid_data = {
"4encoder8decoder":[64.16, 48.04, 39.88, 35.41],
"6encoder4decoder":[67.71, 48.26, 39.30, 34.91],
"8encoder4decoder":[69.4, 49.7, 41.56, 36.76],
}
sfid_data = {
"4encoder8decoder":[7.86, 7.48, 7.15, 7.07],
"6encoder4decoder":[8.54, 8.11, 7.40, 7.40],
"8encoder4decoder":[8.42, 8.27, 8.10, 7.69],
}
is_data = {
"4encoder8decoder":[20.37, 29.41, 36.88, 41.32],
"6encoder4decoder":[20.04, 30.13, 38.17, 43.84],
"8encoder4decoder":[19.98, 29.54, 35.93, 42.025],
}
pr_data = {
"4encoder8decoder":[0.3935, 0.4687, 0.5047, 0.5271],
"6encoder4decoder":[0.3767, 0.4686, 0.50876, 0.5266],
"8encoder4decoder":[0.37, 0.45676, 0.49602, 0.5162],
}
recall_data = {
"4encoder8decoder":[0.5604, 0.5941, 0.6244, 0.6338],
"6encoder4decoder":[0.5295, 0.595, 0.6287, 0.6378],
"8encoder4decoder":[0.51, 0.596, 0.6242, 0.6333],
}
x = [100, 200, 300, 400]
colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
metric_data = {
"FID" : fid_data,
# "SFID" : sfid_data,
"InceptionScore" : is_data,
"Precision" : pr_data,
"Recall" : recall_data,
}
for key, data in metric_data.items():
for i, (name, v) in enumerate(data.items()):
name = name.replace("encoder", "En")
name = name.replace("decoder", "De")
plt.plot(x, v, label=name, color=colors[i], linewidth=3, marker="o")
plt.legend()
plt.xticks(x)
plt.ylabel(key, weight="bold")
plt.xlabel("Training iterations(K steps)", weight="bold")
plt.savefig("output/base_{}.pdf".format(key), bbox_inches='tight')
plt.close()

tools/figures/cfg.py (new file, 32 lines)

@@ -0,0 +1,32 @@
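# Plots FID10K against the classifier-free guidance scale for several guidance intervals and
# saves the figure as output/cfg.pdf.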
import numpy as np
import matplotlib.pyplot as plt
cfg_data = {
"[0, 1]":{
"cfg":[1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
"FID":[9.23, 6.61, 5.08, 4.46, 4.32, 4.52, 4.86, 5.38, 5.97, 6.57, 7.13],
},
"[0.2, 1]":{
"cfg": [1.2, 1.4, 1.6, 1.8, 2.0],
"FID": [5.87, 4.44, 3.96, 4.01, 4.26]
},
"[0.3, 1]":{
"cfg": [1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4],
"FID": [4.31, 4.11, 3.98, 3.89, 3.87, 3.88, 3.91, 3.96, 4.03]
},
"[0.35, 1]":{
"cfg": [1.6, 1.8, 2.0, 2.1, 2.2, 2.3, 2.4, 2.6],
"FID": [4.68, 4.22, 3.98, 3.92, 3.90, 3.88, 3.88, 3.94]
}
}
colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
for i, (name, data) in enumerate(cfg_data.items()):
plt.plot(data["cfg"], data["FID"], label="Interval: " +name, color=colors[i], linewidth=3.5, marker="o")
plt.title("Classifer-free guidance with intervals", weight="bold")
plt.ylabel("FID10K", weight="bold")
plt.xlabel("CFG values", weight="bold")
plt.legend()
plt.savefig("./output/cfg.pdf", bbox_inches="tight")

tools/figures/feat_vis.py (new file, 42 lines)

@@ -0,0 +1,42 @@
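# Visualizes cached hidden states (output/state.pt): computes the pairwise cosine-similarity
# matrix over the 49 cached state vectors, averages it across the batch, and saves the heatmap
# as output/mean_sim.png. The commented-out code below is an earlier per-token argmax visualization.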
import torch
states = torch.load("./output/state.pt", map_location="cpu").to(dtype=torch.float32)
states = states.permute(1, 2, 0, 3)
print(states.shape)
states = states.view(-1, 49, 1152)
states = torch.nn.functional.normalize(states, dim=-1)
sim = torch.bmm(states, states.transpose(1, 2))
mean_sim = torch.mean(sim, dim=0, keepdim=False)
mean_sim = mean_sim.numpy()
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
timesteps = np.linspace(0, 1, 5)
# plt.rc('axes.spines', **{'bottom':False, 'left':False, 'right':False, 'top':False})
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["#7400b8","#5e60ce","#4ea8de", "#64dfdf", "#80ffdb"])
plt.imshow(mean_sim, cmap="inferno")
plt.xticks([])
plt.yticks([])
# plt.show()
plt.colorbar()
plt.savefig("./output/mean_sim.png", pad_inches=0, bbox_inches="tight")
# cos_sim = torch.nn.functional.cosine_similarity(states, states)
# for i in range(49):
# cos_sim = torch.nn.functional.cosine_similarity(states[i], states[i + 1])
# cos_sim = cos_sim.min()
# print(cos_sim)
# state = torch.max(states, dim=-1)[1]
# # state = torch.softmax(state, dim=-1)
# state = state.view(-1, 16, 16)
#
# state = state.numpy()
#
# import numpy as np
# import matplotlib.pyplot as plt
# for i in range(0, 49):
# print(i)
# plt.imshow(state[i])
# plt.savefig("./output2/{}.png".format(i))

tools/figures/large++.py (new file, 63 lines)

@@ -0,0 +1,63 @@
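# Metric-versus-iterations plots (FID50K, Inception Score, Precision, Recall) for encoder/decoder
# splits of the Large++ configuration, saved as output/large++_<metric>.pdf.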
import numpy as np
import matplotlib.pyplot as plt
is_data = {
"10encoder14decoder":[80.48, 104.48, 113.01, 117.29],
"12encoder12decoder":[85.52, 109.91, 118.18, 121.77],
"16encoder8decoder":[92.72, 116.30, 124.32, 126.37],
"20encoder4decoder":[94.95, 117.84, 125.66, 128.30],
}
fid_data = {
"10encoder14decoder":[15.17, 10.40, 9.32, 8.66],
"12encoder12decoder":[13.79, 9.67, 8.64, 8.21],
"16encoder8decoder":[12.41, 8.99, 8.18, 8.03],
"20encoder4decoder":[12.04, 8.94, 8.03, 7.98],
}
sfid_data = {
"10encoder14decoder":[5.49, 5.00, 5.09, 5.14],
"12encoder12decoder":[5.37, 5.01, 5.07, 5.09],
"16encoder8decoder":[5.43, 5.11, 5.20, 5.31],
"20encoder4decoder":[5.36, 5.23, 5.21, 5.50],
}
pr_data = {
"10encoder14decoder":[0.6517, 0.67914, 0.68274, 0.68104],
"12encoder12decoder":[0.66144, 0.68146, 0.68564, 0.6823],
"16encoder8decoder":[0.6659, 0.68342, 0.68338, 0.67912],
"20encoder4decoder":[0.6716, 0.68088, 0.68798, 0.68098],
}
recall_data = {
"10encoder14decoder":[0.6427, 0.6512, 0.6572, 0.6679],
"12encoder12decoder":[0.6429, 0.6561, 0.6622, 0.6693],
"16encoder8decoder":[0.6457, 0.6547, 0.6665, 0.6773],
"20encoder4decoder":[0.6483, 0.6612, 0.6684, 0.6711],
}
x = [100, 200, 300, 400]
# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
metric_data = {
"FID50K" : fid_data,
# "SFID" : sfid_data,
"InceptionScore" : is_data,
"Precision" : pr_data,
"Recall" : recall_data,
}
for key, data in metric_data.items():
# plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
for i, (name, v) in enumerate(data.items()):
name = name.replace("encoder", "En")
name = name.replace("decoder", "De")
plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=8)
plt.legend(fontsize="14")
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
plt.xticks([100, 150, 200, 250, 300, 350, 400])
plt.ylabel(key, weight="bold")
plt.xlabel("Training iterations(K steps)", weight="bold")
plt.savefig("output/large++_{}.pdf".format(key), bbox_inches='tight')
plt.close()

tools/figures/log_snr.py (new file, 18 lines)

@@ -0,0 +1,18 @@
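# Plots the log-SNR of the linear interpolation schedule, log((1 - t) / t), over t in (0, 1);
# the clipped weighting pds() is printed but not plotted.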
import numpy as np
import matplotlib.pyplot as plt
t = np.linspace(0.001, 0.999, 100)
def snr(t):
return np.log((1-t)/t)
def pds(t):
return np.clip(((1-t)/t)**2, a_max=0.5, a_min=0.0)
print(pds(t))
plt.figure(figsize=(16, 4))
plt.plot(t, snr(t), color="#ff70a6", linewidth=3, marker="o")
# plt.plot(t, pds(t), color="#ff9770", linewidth=3, marker="o")
plt.ylabel("log-SNR", weight="bold")
plt.xlabel("Timesteps", weight="bold")
plt.xticks([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])
plt.gca().invert_xaxis()
plt.show()
# plt.savefig("output/logsnr.pdf", bbox_inches='tight')

About 20 binary files not shown (one image preview reported at 20 KiB).

tools/figures/sota.py (new file, 95 lines)

@@ -0,0 +1,95 @@
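# Scatter/line comparison of FID50K versus training epochs for DDT and prior models (SiT, DiT,
# REPA, MDTv2); marker size encodes model size and an arrow annotates the "4x Training Acc" gap.
# Saves output/sota.pdf.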
import numpy as np
import matplotlib.pyplot as plt
data = {
"SiT-XL/2" : {
"size": 675,
"epochs": 1400,
"FID": 2.06,
"color": "#ff99c8"
},
"DiT-XL/2" : {
"size": 675,
"epochs": 1400,
"FID": 2.27,
"color": "#fcf6bd"
},
"REPA-XL/2" : {
"size": 675,
"epochs": 800,
"FID": 1.42,
"color": "#d0f4de"
},
# "MAR-H" : {
# "size": 973,
# "epochs": 800,
# "FID": 1.55,
# },
"MDTv2" : {
"size": 675,
"epochs": 920,
"FID": 1.58,
"color": "#e4c1f9"
},
# "VAVAE+LightningDiT" : {
# "size": 675,
# "epochs": [64, 800],
# "FID": [2.11, 1.35],
# },
"DDT-XL/2": {
"size": 675,
"epochs": [80, 256],
"FID": [1.52, 1.31],
"color": "#38a3a5"
},
"DDT-L/2": {
"size": 400,
"epochs": 80,
"FID": 1.64,
"color": "#5bc0be"
},
}
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
for k, spec in data.items():
plt.scatter(
# spec["size"],
spec["epochs"],
spec["FID"],
label=k,
marker="o",
s=spec["size"],
color=spec["color"],
)
x = spec["epochs"]
y = spec["FID"]
if isinstance(spec["FID"], list):
x = spec["epochs"][-1]
y = spec["FID"][-1]
plt.plot(
spec["epochs"],
spec["FID"],
color=spec["color"],
linestyle="dotted",
linewidth=4
)
# plt.annotate("",
# xytext=(spec["epochs"][0], spec["FID"][0]),
# xy=(spec["epochs"][1], spec["FID"][1]), arrowprops=dict(arrowstyle="--"), weight="bold")
plt.text(x+80, y-0.05, k, fontsize=13)
plt.text(200, 1.45, "4x Training Acc", fontsize=12, color="#38a3a5", weight="bold")
# plt.arrow(200, 1.42, 520, 0, linewidth=2, fc='black', ec='black', hatch="x", head_width=0.05, head_length=0.05)
plt.annotate("",
xy=(700, 1.42), xytext=(200, 1.42),
arrowprops=dict(arrowstyle='<->', color='black', linewidth=2),
)
ax.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
plt.gca().set_xlim(0, 1800)
plt.gca().set_ylim(1.15, 2.5)
plt.xticks([80, 256, 800, 1000, 1200, 1400, 1600, ])
plt.xlabel("Training Epochs", weight="bold")
plt.ylabel("FID50K on ImageNet256x256", weight="bold")
plt.savefig("output/sota.pdf", bbox_inches="tight")

(file path not shown; new file, 26 lines)

@@ -0,0 +1,26 @@
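# Plots the timeshift respacing function t / (t + (1 - t) * s) for several shift values and
# saves the figure as output/timeshift.pdf.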
import scipy
import numpy as np
import matplotlib.pyplot as plt
def timeshift(t, s=1.0):
return t/(t+(1-t)*s)
# colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
# plt.rc('axes.spines', **{'bottom':True, 'left':True, 'right':False, 'top':False})
t = np.linspace(0, 1, 100)
shifts = [1.0, 1.5, 2, 3]
for i , shift in enumerate(shifts):
plt.plot(t, timeshift(t, shift), color=colors[i], label=f"shift {shift}", linewidth=4)
# plt.annotate("", xytext=(0, 0), xy=(0.0, 1.05), arrowprops=dict(arrowstyle="->"), weight="bold")
# plt.annotate("", xytext=(0, 0), xy=(1.05, 0.0), arrowprops=dict(arrowstyle="->"), weight="bold")
# plt.title("Respaced timesteps with various shift value", weight="bold")
# plt.gca().set_xlim(0, 1.0)
# plt.gca().set_ylim(0, 1.0)
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
plt.ylabel("Respaced Timesteps", weight="bold")
plt.xlabel("Uniform Timesteps", weight="bold")
plt.legend(loc="upper left", fontsize="12")
plt.savefig("output/timeshift.pdf", bbox_inches="tight", pad_inches=0)

(file path not shown; new file, 29 lines)

@@ -0,0 +1,29 @@
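# Plots FID50K against the number of inference steps for several timeshift values and saves the
# figure as output/timeshift_fid.pdf.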
import scipy
import numpy as np
import matplotlib.pyplot as plt
def timeshift(t, s=1.0):
return t/(t+(1-t)*s)
data = {
"shift 1.0": [8.99, 6.36, 5.03, 4.21, 3.6, 3.23, 2.80],
"shift 1.5": [6.08, 4.26, 3.43, 2.99, 2.73, 2.54, 2.33],
"shift 2.0": [5.57, 3.81, 3.11, 2.75, 2.54, 2.43, 2.26],
"shift 3.0": [7.26, 4.48, 3.43, 2.97, 2.72, 2.57, 2.38],
}
# plt.rc('axes.spines', **{'bottom':True, 'left':True, 'right':False, 'top':False})
# colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
steps = [5, 6, 7, 8, 9, 10, 12]
for i ,(k, v)in enumerate(data.items()):
plt.plot(steps, v, color=colors[i], label=k, linewidth=4, marker="o")
# plt.title("FID50K of different steps of different timeshift", weight="bold")
plt.ylabel("FID50K", weight="bold")
plt.xlabel("Num of inference steps", weight="bold")
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
# plt.legend()
# plt.legend()
plt.savefig("output/timeshift_fid.pdf", bbox_inches="tight", pad_inches=0)

tools/fm_images.py (new file, 21 lines)

@@ -0,0 +1,21 @@
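# Cleanup utility: recursively finds all image files (png/jpeg/jpg) under the given work
# directory and deletes them via rm -f.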
IMG_EXTENSIONS = (
"*.png",
"*.JPEG",
"*.jpeg",
"*.jpg"
)
PATH = "/mnt/bn/wangshuai6/neural_sampling_workdirs/expbaseline_adam2_timeshift1.5"
import os
import pathlib
PATH = pathlib.Path(PATH)
images = []
# find images
for ext in IMG_EXTENSIONS:
images.extend(PATH.rglob(ext))
for image in images:
os.system(f"rm -f {image}")

tools/mm.py (new file, 23 lines)

@@ -0,0 +1,23 @@
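# Runs a throwaway forward pass through a small frozen Linear layer on every rank whenever the
# reported CUDA utilization drops below 1.5, keeping otherwise idle GPUs busy.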
import torch
import time
import torch.nn as nn
import accelerate
if __name__ == "__main__":
model = nn.Linear(512, 512)
for p in model.parameters():
p.requires_grad = False
accelerator = accelerate.Accelerator()
model = accelerator.prepare_model(model)
model.to(accelerator.device)
data = torch.randn(1024, 512).to(accelerator.device)
while True:
time.sleep(0.01)
accelerator.wait_for_everyone()
if torch.cuda.utilization() < 1.5:
with torch.no_grad():
model(data)
else:
time.sleep(1)
# print(f"rank:{accelerator.process_index}->usage:{torch.cuda.utilization()}")

tools/sigmoid.py (new file, 20 lines)

@@ -0,0 +1,20 @@
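# Plots a sigmoid loss-weighting curve over timesteps, 1 / (1 + exp(-(logSNR + b))), for several
# bias values b.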
import numpy as np
import matplotlib.pyplot as plt
def lw(x, b=0):
x = np.clip(x, a_min=0.001, a_max=0.999)
snr = x/(1-x)
logsnr = np.log(snr)
# print(logsnr)
# return logsnr
weight = 1 / (1 + np.exp(-logsnr - b))#*(1-x)**2
return weight #/weight.max()
x = np.arange(0.2, 0.8, 0.001)
print(1/(x*(1-x)))
for b in [0, 1, 2, 3]:
y = lw(x, b)
plt.plot(x, y, label=f"b={b}")
plt.legend()
plt.show()

tools/vae2dino.py (new file, 173 lines)

@@ -0,0 +1,173 @@
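# Despite the file name, this currently mirrors dino_scale.py with the MAE (vit_base_patch16_224.mae)
# encoder enabled: it prints per-channel std/mean of the encoder's patch features for a tiny
# ImageNet batch (B = 4); the DINOv2 and CLIP variants are commented out.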
import torch
import torch.nn as nn
import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision.transforms import Normalize
import copy
class DINOv2(nn.Module):
def __init__(self, weight_path:str):
super(DINOv2, self).__init__()
self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.precomputed_pos_embed = dict()
def fetch_pos(self, h, w):
key = (h, w)
if key in self.precomputed_pos_embed:
return self.precomputed_pos_embed[key]
value = timm.layers.pos_embed.resample_abs_pos_embed(
self.pos_embed.data, [h, w],
)
self.precomputed_pos_embed[key] = value
return value
def forward(self, x):
b, c, h, w = x.shape
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
self.encoder.pos_embed.data = pos_embed_data
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
feature = feature.transpose(1, 2)
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
return feature
class MAE(nn.Module):
def __init__(self, model_id, weight_path:str):
super(MAE, self).__init__()
if os.path.isdir(weight_path):
weight_path = os.path.join(weight_path, "pytorch_model.bin")
self.encoder = timm.create_model(
model_id,
checkpoint_path=weight_path,
num_classes=0,
)
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
self.encoder.head = torch.nn.Identity()
self.patch_size = self.encoder.patch_embed.patch_size
self.shifts = nn.Parameter(torch.tensor([0.0
]), requires_grad=False)
self.scales = nn.Parameter(torch.tensor([1.0
]), requires_grad=False)
def forward(self, x):
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
b, c, h, w = x.shape
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
feature = feature.transpose(1, 2)
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
return feature
from diffusers import AutoencoderKL
from torchvision.datasets import ImageFolder, ImageNet
import cv2
import os
import numpy as np
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from PIL import Image
import torch
import torchvision.transforms as tvtf
IMG_EXTENSIONS = (
"*.png",
"*.JPEG",
"*.jpeg",
"*.jpg"
)
class NewImageFolder(ImageFolder):
def __getitem__(self, item):
path, target = self.samples[item]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target, path
import time
class CenterCrop:
def __init__(self, size):
self.size = size
def __call__(self, image):
def center_crop_arr(pil_image, image_size):
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
)
scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
)
arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
return center_crop_arr(image, self.size)
def save(obj, path):
dirname = os.path.dirname(path)
if not os.path.exists(dirname):
os.makedirs(dirname)
torch.save(obj, f'{path}')
if __name__ == "__main__":
torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
os.environ["TORCH_HOME"] = torch_hub_dir
torch.hub.set_dir(torch_hub_dir)
for split in ['train']:
train = split == 'train'
transforms = tvtf.Compose([
CenterCrop(256),
tvtf.ToTensor(),
tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
B = 4
dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0)
# dino = DINOv2("dinov2_vitb14")
dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae")
# dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai")
from accelerate import Accelerator
accelerator = Accelerator()
dino, dataloader = accelerator.prepare(dino, dataloader)
rank = accelerator.process_index
acc_mean = torch.zeros((768, ), device=accelerator.device)
acc_num = 0
with torch.no_grad():
for i, (images, labels, path_list) in enumerate(dataloader):
acc_num += len(images)
feature = dino(images)
stds = torch.std(feature, dim=[0, 2, 3]).tolist()
for std in stds:
print("{:.2f},".format(std), end='')
print()
means = torch.mean(feature, dim=[0, 2, 3]).tolist()
for mean in means:
print("{:.2f},".format(mean), end='')
break
accelerator.wait_for_everyone()

tools/vis_timeshift.py (new file, 23 lines)

@@ -0,0 +1,23 @@
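# Scratch visualization for timestep respacing: computes the shifted timesteps t / (t + (1 - t) * s)
# and their per-step deltas, then plots the log-SNR-style ratio rs2() over the sample index.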
import scipy
import numpy as np
import matplotlib.pyplot as plt
def timeshift(t, s=1.0):
return t/(t+(1-t)*s)
def gaussian(t):
gs = 1+scipy.special.erf((t-t.mean())/t.std())
def rs2(t, s=2.0):
factor1 = 1.0 #s/(s+(1-s)*t)**2
factor2 = np.log(t.clip(0.001, 0.999)/(1-t).clip(0.001, 0.999))
return factor1*factor2
t = np.linspace(0, 1, 100)
# plt.plot(t, timeshift(t, 1.0))
respaced_t = timeshift(t, s=5)
delats = (respaced_t[1:] - respaced_t[:-1])
# plt.plot(t, timeshift(t, 1.5))
plt.plot(rs2(t))
plt.show()