submit code
This commit is contained in:
@@ -0,0 +1,117 @@
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
import torch
|
||||
from typing import Callable
|
||||
from torchvision.datasets import ImageFolder, ImageNet
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||
|
||||
from PIL import Image
|
||||
import pathlib
|
||||
|
||||
import torch
|
||||
import random
|
||||
from torchvision.io.image import read_image
|
||||
import torchvision.transforms as tvtf
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets import ImageNet
|
||||
|
||||
# Glob-style patterns for image file suffixes (PNG plus the JPEG spellings).
IMG_EXTENSIONS = ("*.png", "*.JPEG", "*.jpeg", "*.jpg")
|
||||
|
||||
class NewImageFolder(ImageFolder):
    """``ImageFolder`` whose items additionally carry the source file path."""

    def __getitem__(self, item):
        """Return ``(sample, target, path)`` for the entry at index ``item``.

        The sample is loaded with ``self.loader`` and passed through
        ``self.transform`` / ``self.target_transform`` when those are set.
        """
        image_path, class_index = self.samples[item]
        loaded = self.loader(image_path)
        sample = loaded if self.transform is None else self.transform(loaded)
        if self.target_transform is None:
            target = class_index
        else:
            target = self.target_transform(class_index)
        return sample, target, image_path
|
||||
|
||||
|
||||
import time
|
||||
class CenterCrop:
    """Callable transform: resize a PIL image and cut out its central square."""

    def __init__(self, size):
        # Side length, in pixels, of the square crop returned by __call__.
        self.size = size

    def __call__(self, image):
        """Return the central ``size`` x ``size`` crop of ``image``.

        Center cropping implementation from ADM.
        https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
        """
        target = self.size
        picture = image

        # Halve with a box filter for as long as the short side is at least
        # twice the target size.
        while min(*picture.size) >= 2 * target:
            halved = tuple(side // 2 for side in picture.size)
            picture = picture.resize(halved, resample=Image.BOX)

        # Bicubic resize that scales the short side to the target size.
        ratio = target / min(*picture.size)
        rescaled = tuple(round(side * ratio) for side in picture.size)
        picture = picture.resize(rescaled, resample=Image.BICUBIC)

        # Crop the centered target x target window from the pixel array.
        pixels = np.array(picture)
        top = (pixels.shape[0] - target) // 2
        left = (pixels.shape[1] - target) // 2
        return Image.fromarray(pixels[top: top + target, left: left + target])
|
||||
|
||||
def save(obj, path):
    """Serialize ``obj`` to ``path`` with ``torch.save``, creating parent folders.

    Args:
        obj: Object to serialize.
        path: Destination file path.
    """
    dirname = os.path.dirname(path)
    # A bare file name has no directory part, and os.makedirs("") would raise.
    if dirname:
        # exist_ok avoids the check-then-create race when several writer
        # threads/processes hit the same not-yet-existing directory at once.
        os.makedirs(dirname, exist_ok=True)
    torch.save(obj, f'{path}')
|
||||
|
||||
if __name__ == "__main__":
    # Encode every image under the ImageNet train folder with the VAE and store
    # the latent posterior (mean / logvar) of each image in its own ``.pt`` file
    # under a mirrored directory tree.
    from accelerate import Accelerator

    src_root = "/mnt/bn/wangshuai6/data/ImageNet/train"
    dst_root = "/mnt/bn/wangshuai6/data/ImageNet/train_256latent"

    writer_pool = ThreadPoolExecutor(8)
    pending = []  # write futures that have not been confirmed finished yet
    for split in ['train']:
        transforms = tvtf.Compose([
            CenterCrop(256),
            # tvtf.RandomHorizontalFlip(p=1),
            tvtf.ToTensor(),
            tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
        # dataset = ImageNet(root='/tmp', split="train", transform=transforms, )
        B = 256
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=False, prefetch_factor=32, num_workers=16)
        vae = AutoencoderKL.from_pretrained("/mnt/bn/wangshuai6/models/sd-vae-ft-ema")#.to('cuda')

        accelerator = Accelerator()

        vae, dataloader = accelerator.prepare(vae, dataloader)
        # prepare() only wraps the model in some launch configurations, so
        # ``vae.module`` is not always available; unwrap_model() returns the
        # underlying AutoencoderKL either way.
        encoder = accelerator.unwrap_model(vae)
        with torch.no_grad():
            for image, label, path_list in dataloader:
                # if i >= 128: break
                new_path_list = [(p + ".pt").replace(src_root, dst_root) for p in path_list]

                image = image.to(accelerator.device)
                distribution = encoder.encode(image).latent_dist
                mean = distribution.mean
                logvar = distribution.logvar
                # Iterate over what the batch actually contains: the final
                # batch can hold fewer than B samples.
                for sample_mean, sample_logvar, out_path in zip(mean, logvar, new_path_list):
                    # .cpu() on the per-sample slice yields a standalone tensor,
                    # so each file stores one sample rather than the batch.
                    out = dict(
                        mean=sample_mean.cpu(),
                        logvar=sample_logvar.cpu(),
                    )
                    pending.append(writer_pool.submit(save, out, out_path))

                # Surface write errors promptly (a failed future is otherwise
                # silent) and keep the list of tracked futures bounded.
                unfinished = []
                for future in pending:
                    if future.done():
                        future.result()
                    else:
                        unfinished.append(future)
                pending = unfinished
    writer_pool.shutdown(wait=True)
    for future in pending:
        future.result()
    accelerator.wait_for_everyone()
|
||||
@@ -0,0 +1,123 @@
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
import torch
|
||||
from typing import Callable
|
||||
from torchvision.datasets import ImageFolder, ImageNet
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||
|
||||
from PIL import Image
|
||||
import pathlib
|
||||
|
||||
import torch
|
||||
import random
|
||||
from torchvision.io.image import read_image
|
||||
import torchvision.transforms as tvtf
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets import ImageNet
|
||||
|
||||
# Glob-style patterns for image file suffixes (PNG plus the JPEG spellings).
IMG_EXTENSIONS = ("*.png", "*.JPEG", "*.jpeg", "*.jpg")
|
||||
|
||||
class NewImageFolder(ImageFolder):
    """``ImageFolder`` whose items additionally carry the source file path."""

    def __getitem__(self, item):
        """Return ``(sample, target, path)`` for the entry at index ``item``.

        The sample is loaded with ``self.loader`` and passed through
        ``self.transform`` / ``self.target_transform`` when those are set.
        """
        image_path, class_index = self.samples[item]
        loaded = self.loader(image_path)
        sample = loaded if self.transform is None else self.transform(loaded)
        if self.target_transform is None:
            target = class_index
        else:
            target = self.target_transform(class_index)
        return sample, target, image_path
|
||||
|
||||
|
||||
import time
|
||||
class CenterCrop:
    """Callable transform: resize a PIL image and cut out its central square."""

    def __init__(self, size):
        # Side length, in pixels, of the square crop returned by __call__.
        self.size = size

    def __call__(self, image):
        """Return the central ``size`` x ``size`` crop of ``image``.

        Center cropping implementation from ADM.
        https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
        """
        target = self.size
        picture = image

        # Halve with a box filter for as long as the short side is at least
        # twice the target size.
        while min(*picture.size) >= 2 * target:
            halved = tuple(side // 2 for side in picture.size)
            picture = picture.resize(halved, resample=Image.BOX)

        # Bicubic resize that scales the short side to the target size.
        ratio = target / min(*picture.size)
        rescaled = tuple(round(side * ratio) for side in picture.size)
        picture = picture.resize(rescaled, resample=Image.BICUBIC)

        # Crop the centered target x target window from the pixel array.
        pixels = np.array(picture)
        top = (pixels.shape[0] - target) // 2
        left = (pixels.shape[1] - target) // 2
        return Image.fromarray(pixels[top: top + target, left: left + target])
|
||||
|
||||
def save(obj, path):
    """Serialize ``obj`` to ``path`` with ``torch.save``, creating parent folders.

    Args:
        obj: Object to serialize.
        path: Destination file path.
    """
    dirname = os.path.dirname(path)
    # A bare file name has no directory part, and os.makedirs("") would raise.
    if dirname:
        # exist_ok avoids the check-then-create race when several writer
        # threads/processes hit the same not-yet-existing directory at once.
        os.makedirs(dirname, exist_ok=True)
    torch.save(obj, f'{path}')
|
||||
|
||||
if __name__ == "__main__":
    # Encode images under the ImageNet train folder with the fp16 VAE at 512px
    # and store each image's latent posterior (mean / logvar) as a ``.pt`` file.
    # Batches whose outputs all exist already are skipped, so the script can be
    # re-run to fill in missing files.
    from accelerate import Accelerator

    src_root = "/mnt/bn/wangshuai6/data/ImageNet/train"
    dst_root = "/mnt/bn/wangshuai6/data/ImageNet/train_512_latent"

    writer_pool = ThreadPoolExecutor(8)
    pending = []  # write futures that have not been confirmed finished yet
    for split in ['train']:
        transforms = tvtf.Compose([
            CenterCrop(512),
            # tvtf.RandomHorizontalFlip(p=1),
            tvtf.ToTensor(),
            tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
        B = 8
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=False, prefetch_factor=16, num_workers=16)
        vae = AutoencoderKL.from_pretrained("/mnt/bn/wangshuai6/models/sd-vae-ft-ema")#.to('cuda')
        vae = vae.to(torch.float16)

        accelerator = Accelerator()

        vae, dataloader = accelerator.prepare(vae, dataloader)
        # prepare() only wraps the model in some launch configurations, so
        # ``vae.module`` is not always available; unwrap_model() returns the
        # underlying AutoencoderKL either way.
        encoder = accelerator.unwrap_model(vae)
        with torch.no_grad():
            for i, (image, label, path_list) in enumerate(dataloader):
                print(i/len(dataloader))
                new_path_list = [(p + ".pt").replace(src_root, dst_root) for p in path_list]

                # Report every output that is missing; one missing file is
                # enough to re-encode (and re-save) the whole batch.
                missing = [p for p in new_path_list if not os.path.exists(p)]
                for p in missing:
                    print(p)
                if not missing:
                    continue

                image = image.to(accelerator.device)
                image = image.to(torch.float16)
                distribution = encoder.encode(image).latent_dist
                mean = distribution.mean
                logvar = distribution.logvar

                for sample_mean, sample_logvar, out_path in zip(mean, logvar, new_path_list):
                    # .cpu() on the per-sample slice yields a standalone tensor,
                    # so each file stores one sample rather than the batch.
                    out = dict(
                        mean=sample_mean.cpu(),
                        logvar=sample_logvar.cpu(),
                    )
                    pending.append(writer_pool.submit(save, out, out_path))

                # Surface write errors promptly (a failed future is otherwise
                # silent) and keep the list of tracked futures bounded.
                unfinished = []
                for future in pending:
                    if future.done():
                        future.result()
                    else:
                        unfinished.append(future)
                pending = unfinished
    writer_pool.shutdown(wait=True)
    for future in pending:
        future.result()
    accelerator.wait_for_everyone()
|
||||
@@ -0,0 +1,43 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import os
|
||||
import pathlib
|
||||
import argparse
|
||||
|
||||
def group_images(path_list):
    """Group image paths by the class id encoded in their file names.

    The class id is the part of ``path.name`` before the first underscore.

    Args:
        path_list: Iterable of path objects exposing a ``name`` attribute
            (e.g. ``pathlib.Path``). The input is not modified.

    Returns:
        dict mapping each class id (str) to the list of its paths. Paths are
        visited in sorted order, so both the key order and the order inside
        each list are deterministic.
    """
    class_id_dict = {}
    # sorted() returns a new list, so its result has to be the thing iterated.
    for path in sorted(path_list):
        class_id = str(path.name).split('_')[0]
        class_id_dict.setdefault(class_id, []).append(path)
    return class_id_dict
|
||||
|
||||
def cat_images(path_list):
    """Tile the images at ``path_list`` into one grid image and delete the sources.

    Images are placed row by row, ``int(sqrt(n))`` images per row. Images left
    over after the last complete row are not part of the grid (they are still
    deleted, as before). ``np.concatenate`` requires the images in a row to
    share a height and the rows to share a width.

    Args:
        path_list: Iterable of image file paths.

    Returns:
        The concatenated grid as a NumPy array.

    Raises:
        ValueError: If ``path_list`` is empty or an image cannot be read.
    """
    path_list = list(path_list)
    if not path_list:
        raise ValueError("cat_images() needs at least one image path")

    imgs = []
    for path in path_list:
        img = cv2.imread(str(path))
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise ValueError(f"could not read image: {path}")
        imgs.append(img)

    row_length = int(len(imgs) ** 0.5)
    row_cat_images = [
        np.concatenate(imgs[i * row_length:(i + 1) * row_length], axis=1)
        for i in range(len(imgs) // row_length)
    ]
    cat_image = np.concatenate(row_cat_images, axis=0)

    # Delete the sources only after the grid exists, so a read or shape error
    # above does not cost the input files.
    for path in path_list:
        os.remove(path)
    return cat_image
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: merge the ``*.png`` files of ``--src_dir`` into
    # one ``cat_<class_id>.jpg`` grid per class id.
    parser = argparse.ArgumentParser()
    # There is nothing to process without a directory, so let argparse reject
    # a missing value with a usage message.
    parser.add_argument('--src_dir', type=str, required=True)

    args = parser.parse_args()
    src_dir = args.src_dir
    path_list = list(pathlib.Path(src_dir).glob('*.png'))
    class_id_dict = group_images(path_list)
    for class_id, class_paths in class_id_dict.items():
        cat_image = cat_images(class_paths)
        cat_path = os.path.join(src_dir, f'cat_{class_id}.jpg')
        # cat_path = "cat_{}.png".format(class_id)
        # cv2.imwrite reports failure through its return value; the source
        # images are already gone at this point, so do not continue silently.
        if not cv2.imwrite(cat_path, cat_image):
            raise RuntimeError(f"failed to write {cat_path}")
|
||||
|
||||
@@ -0,0 +1,353 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import timm
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from torchvision.transforms import Normalize
|
||||
import copy
|
||||
|
||||
# NORMALIZE_DATA = dict(
|
||||
# dinov2_vits14a = dict(
|
||||
# mean=[-0.28,0.72,-0.64,-2.31,0.54,-0.29,-1.09,0.83,0.86,1.11,0.34,0.29,-0.32,1.02,0.58,-1.27,-1.19,-0.89,0.79,
|
||||
# -0.58,0.23,-0.19,1.31,-0.34,0.02,-0.18,-0.64,0.04,1.63,-0.58,-0.89,0.09,1.09,1.12,0.32,-0.41,0.04,0.49,
|
||||
# 0.11,1.97,1.06,0.05,-1.15,0.30,0.58,-0.14,-0.26,1.32,2.04,0.50,-0.64,1.18,0.39,0.39,-1.80,0.39,-0.67,
|
||||
# 0.55,-0.35,-0.41,2.23,-1.16,-0.57,0.58,-1.29,2.07,0.18,0.62,5.72,-0.55,-0.54,0.17,-0.64,-0.78,-0.25,0.12,
|
||||
# -0.58,0.36,-2.03,-2.45,-0.22,0.36,-1.02,-0.19,-0.92,-0.26,-0.27,-0.77,-1.47,-0.64,1.76,-0.03,-0.44,1.43,
|
||||
# 1.14,0.67,1.27,1.54,0.88,-1.42,-0.44,3.32,0.21,1.22,1.17,1.15,-0.53,0.04,0.87,-0.76,0.94,-0.11,0.69,-0.61,
|
||||
# 0.64,-1.21,-0.82,0.22,-1.12,-0.03,0.68,1.05,0.57,1.13,0.03,0.05,0.42,-0.12,-0.37,-0.76,-0.56,-0.76,-0.23,
|
||||
# 1.59,0.54,0.63,-0.43,0.38,1.07,0.04,-1.87,-1.92,-0.06,0.87,-0.69,-1.09,-0.30,0.33,-0.28,0.14,2.65,-0.57,
|
||||
# -0.04,0.12,-0.49,-1.60,0.39,0.05,0.12,0.66,-0.70,-0.69,0.47,-0.67,-0.59,-1.30,-0.28,-0.52,-0.98,0.67,1.65,
|
||||
# 0.72,0.55,0.05,-0.27,1.67,0.17,-0.31,-1.73,2.04,0.49,1.08,-0.37,1.75,1.31,1.03,0.65,0.43,-0.19,0.00,-1.13,
|
||||
# -0.29,-0.38,0.09,-0.24,1.49,1.01,-0.25,-0.94,0.74,0.24,-1.06,1.58,1.08,0.76,0.64,1.34,-1.09,1.54,-0.27,
|
||||
# 0.77,0.19,-0.97,0.46,0.20,-0.60,1.48,-2.33,0.43,2.32,1.89,-0.31,-0.48,-0.54,1.52,1.33,0.95,0.12,-0.33,-0.94,
|
||||
# -0.67,0.16,1.49,-0.17,-0.42,-0.02,-0.32,0.49,-1.19,0.06,0.19,-0.79,-0.21,-0.38,-0.69,0.52,0.74,0.41,-2.07,
|
||||
# -1.01,0.85,-1.41,-0.17,1.11,0.53,1.47,0.66,-0.22,0.93,-0.69,-0.42,0.06,0.11,-0.87,1.58,-0.27,-1.57,-0.56,
|
||||
# 0.98,-0.50,0.27,0.38,-1.06,-1.77,0.20,-0.33,-0.95,-0.62,-3.44,-0.67,-0.62,-1.20,0.04,-0.02,-1.15,0.56,-0.50,
|
||||
# 0.83,-1.69,0.01,-0.42,1.15,0.22,1.55,-3.02,1.24,0.28,0.40,0.69,-0.35,2.04,0.33,0.10,-1.09,0.50,0.59,1.29,
|
||||
# 0.79,0.02,-0.02,-0.49,0.07,0.84,0.55,-0.79,-0.26,-0.06,-0.91,-1.28,0.65,0.30,1.00,-0.09,0.66,-2.51,-0.78,
|
||||
# 2.94,0.18,0.24,-0.08,0.76,0.06,0.26,-0.74,0.16,-0.72,0.17,0.21,0.98,0.67,0.14,0.05,0.48,0.54,2.05,1.21,-0.03,
|
||||
# -0.85,0.38,-0.11,-0.38,-0.86,0.49,-0.87,-0.29,0.23,0.79,0.05,-0.05,-0.07,0.22,0.03,0.85,-0.63,-0.44,0.02,
|
||||
# -0.10,-0.01,0.51,-1.84,0.11,1.06,0.00,1.10,-0.56,0.21,0.44,-0.65,-0.97,-1.03,-0.50,-0.67,-0.27,-1.25,
|
||||
# ],
|
||||
# std=[1.78,3.78,1.92,2.28,1.97,2.82,1.92,2.55,1.87,1.95,1.90,1.83,1.89,2.00,1.85,1.88,1.78,1.81,3.02,1.94,1.92,
|
||||
# 2.26,2.17,6.16,1.84,2.00,1.85,1.88,3.30,2.14,1.85,2.87,3.01,2.05,1.80,1.84,2.20,2.00,1.97,2.02,1.94,1.90,
|
||||
# 1.98,2.25,1.97,2.01,2.01,1.95,2.26,2.47,1.95,1.75,1.84,3.02,2.65,2.15,2.01,1.80,2.65,2.37,2.04,2.09,2.03,
|
||||
# 1.94,1.84,2.19,1.98,1.97,4.52,2.76,2.18,2.59,1.94,2.07,1.96,1.91,2.13,3.16,1.95,2.43,1.84,2.16,2.33,2.21,
|
||||
# 2.10,1.98,1.90,1.90,1.88,1.89,2.15,1.75,1.83,2.36,2.40,2.42,1.89,2.03,1.89,2.00,1.91,2.88,2.10,2.63,2.04,
|
||||
# 1.88,1.93,1.74,2.02,1.84,1.96,1.98,1.90,1.80,1.86,2.05,2.21,1.97,1.99,1.77,2.04,2.59,1.85,2.14,1.91,1.68,
|
||||
# 1.95,1.86,1.99,2.18,2.76,2.03,1.88,2.47,1.92,3.04,2.02,1.74,2.94,1.92,2.12,1.92,2.17,2.15,1.74,2.26,1.71,
|
||||
# 2.03,2.05,1.85,3.43,1.77,1.96,1.88,1.99,2.14,2.30,2.00,1.90,2.01,1.78,1.72,2.42,1.66,1.86,2.08,2.04,1.88,
|
||||
# 2.55,2.02,1.83,1.86,1.69,2.06,1.92,2.25,1.74,1.69,2.02,3.88,1.86,2.94,1.82,2.27,2.73,2.05,1.91,1.94,1.86,
|
||||
# 1.77,2.16,2.16,1.86,1.88,2.08,2.19,1.94,1.90,2.09,2.57,1.75,1.90,2.05,2.13,1.74,1.99,1.83,2.35,4.48,2.44,
|
||||
# 1.88,2.18,2.46,1.84,1.81,2.37,2.45,2.07,1.79,3.65,2.29,2.09,2.09,2.29,1.92,2.34,1.85,2.03,1.72,2.20,2.15,
|
||||
# 2.04,2.13,2.07,1.82,1.72,2.06,1.87,2.43,1.94,1.93,1.97,1.83,1.96,2.01,1.89,1.73,2.04,2.63,2.10,2.05,2.49,
|
||||
# 2.10,2.27,1.87,2.16,2.22,2.08,1.87,2.26,1.88,2.28,3.87,1.74,3.71,2.03,2.70,2.11,1.92,2.00,2.04,2.02,1.90,
|
||||
# 2.61,2.10,2.37,1.96,2.50,1.17,1.95,1.88,2.06,2.22,1.87,1.93,1.88,3.59,1.89,3.66,1.87,1.95,3.13,1.84,2.87,
|
||||
# 3.96,2.14,2.01,1.89,1.73,1.98,2.42,2.12,2.28,1.92,1.93,2.54,2.06,1.97,2.02,2.19,2.00,2.04,1.75,1.97,1.81,
|
||||
# 1.93,1.83,2.22,2.52,1.83,1.86,2.16,2.08,2.87,3.21,2.78,2.84,2.85,1.88,1.79,1.95,1.98,1.78,1.78,2.21,1.89,
|
||||
# 2.57,2.00,2.82,1.90,2.24,2.28,1.91,2.02,2.23,2.62,1.88,2.40,2.40,2.00,1.70,1.82,1.92,1.95,1.99,2.08,1.97,
|
||||
# 2.12,1.87,3.65,2.26,1.83,1.96,1.83,1.64,2.07,2.04,2.57,1.85,2.21,1.83,1.90,1.97,2.16,2.12,1.80,1.73,1.96,
|
||||
# 2.62,3.23,2.13,2.29,2.24,2.72
|
||||
# ]
|
||||
# ),
|
||||
# dinov2_vitb14a = dict(
|
||||
# mean=[
|
||||
# 0.23, 0.44, 0.18, -0.26, -0.08, -0.80, -0.22, -0.09, -0.85, 0.44, 0.07, -0.49, 0.39, -0.12, -0.58, -0.82,
|
||||
# -0.21, -0.28, -0.40, 0.36, -0.34, 0.08, 0.31, 0.39, -0.22, -1.23, 0.50, 0.81, -0.96, 0.60, -0.45, -0.17,
|
||||
# -0.53, 0.08, 0.10, -0.32, -0.22, -0.86, 0.01, 0.19, -0.73, -0.44, -0.57, -0.45, -0.20, -0.34, -0.63, -0.31,
|
||||
# -0.80, 0.43, -0.13, 0.18, -0.11, -0.28, -0.15, 0.11, -0.74, -0.01, -0.34, 0.18, 0.37, 0.07, -0.09, -0.42, 0.15,
|
||||
# -0.24, 0.68, -0.31, -0.09, -0.62, -0.54, 0.41, -0.42, -0.08, 0.36, -0.14, 0.44, 0.12, 0.49, 0.69, 0.03,
|
||||
# -0.24, -0.41, -0.36, -0.60, 0.86, -0.76, 0.54, -0.24, 0.57, -0.40, -0.82, 0.07, 0.05, -0.24, 0.07, 0.54,
|
||||
# 1.04, -0.29, 0.67, -0.36, -0.79, 0.11, -0.12, -0.22, -0.20, -0.46, 0.17, -0.15, -0.38, -0.11, 0.24, -0.43,
|
||||
# -0.91, 0.04, 0.32, 0.27, -0.58, -0.05, 0.50, -0.47, 0.31, -1.30, 0.07, -0.16, 0.77, 1.07, -0.44, -0.48, 0.26
|
||||
# , 0.06, -0.76, -0.27, -0.37, -1.43, -0.50, -0.38, -0.03, -0.43, 0.75, -0.01, -0.16, 0.67, 0.40, 0.33, -0.05,
|
||||
# -0.94, -0.40, 0.78, 0.29, -0.60, -0.76, 0.08, -0.08, 0.58, -0.91, -1.09, -0.42, -0.42, 0.29, 0.06, -0.19,
|
||||
# -0.75, -0.07, 0.48, -0.30, -0.44, 0.02, 0.11, 0.23, -0.76, -0.76, -0.51, 0.78, -0.58, 0.02, 0.17, -0.36,
|
||||
# -0.63, 0.48, 0.09, -0.32, -0.48, -0.09, 0.09, -0.36, 0.11, -0.17, 0.11, -0.80, -0.34, -0.52, 0.10, -0.00, 0.00,
|
||||
# -0.15, 0.91, -0.48, 0.64, -0.38, 0.28, 0.56, 0.04, -0.30, 0.14, -0.30, -0.82, 0.47, 0.57, -1.00, -0.14,
|
||||
# 0.00, 0.10, 0.01, 0.57, -0.09, -3.56, -0.22, -0.24, -0.13, 0.36, 0.30, 0.20, 0.09, 0.08, 0.66, 0.62, 0.44,
|
||||
# 0.38, 0.46, -0.27, 0.21, 0.07, -0.57, 0.93, 0.39, 0.06, -0.47, 0.34, 0.44, -0.00, -0.52, -0.35, 0.23, -0.24,
|
||||
# -0.01, -0.15, 0.11, 0.53, -0.23, 0.28, -0.22, 0.57, -0.07, 0.49, 0.74, 0.85, -0.31, -0.44, 0.22, -0.02, 0.25,
|
||||
# -0.01, -0.47, -0.23, 0.03, 0.48, -0.19, 1.55, -0.05, 0.24, 0.26, -0.25, 0.38, -0.44, -0.51, 0.34, -0.12,
|
||||
# -0.76, -0.13, 0.57, 0.01, 0.63, 0.40, 0.20, -0.33, -0.31, -0.89, 0.65, -0.46, -0.88, -0.22, 0.34, 0.36,
|
||||
# 0.95, 0.33, 0.62, -0.49, 0.40, -0.12, -0.07, -0.65, -0.05, -0.58, 0.65, 0.18, -0.81, -0.64, 0.26, -0.10,
|
||||
# -0.71, 0.47, -0.05, 0.12, -0.18, 0.77, 0.47, 0.50, 0.48, -0.45, 0.03, 0.16, 0.66, -0.42, -0.05, 0.23, -0.22,
|
||||
# -0.46, 0.25, 0.28, 0.18, -0.20, -0.14, -0.93, -0.27, -0.23, 0.15, -0.10, -0.39, -0.20, -0.05, -0.09, 0.28,
|
||||
# -0.58, -0.54, 0.09, -0.89, -0.09, 0.03, -0.86, -0.46, -0.70, 0.48, -0.59, -0.56, -0.55, -0.27, -0.50, 0.23,
|
||||
# 0.63, -1.45, -0.27, -0.04, -0.17, 0.38, -0.02, 0.28, 0.53, -0.81, -0.60, -0.07, 0.22, 0.23, 0.33, -0.62,
|
||||
# 0.09, -0.19, -0.09, -0.28, -0.13, 0.66, 0.37, -0.17, -0.52, -0.15, -0.60, 0.15, -0.25, 0.42, -0.06, 0.26,
|
||||
# 0.55, 0.72, 0.48, 0.39, -0.41, -0.76, -0.62, 0.53, 0.18, 0.35, -0.27, -0.20, -0.71, -0.55, 0.16, -0.24, -0.12,
|
||||
# 0.38, -0.53, -0.43, 0.21, -0.60, -0.24, -0.11, 1.29, 0.02, -0.05, 0.13, 0.48, 0.39, -0.43, -0.05, 0.07,
|
||||
# -0.92, 0.89, -0.21, 0.30, -0.44, 0.04, -0.30, 0.11, -0.36, -0.46, -0.20, 0.10, 0.88, -0.15, 0.28, 0.57,
|
||||
# -0.10, 0.48, 0.77, -0.12, 0.17, -0.43, -0.20, 0.22, 0.36, -0.49, -0.54, -0.07, 0.67, 0.40, -0.94, -0.62,
|
||||
# 0.46, 0.75, -0.16, -0.32, 0.30, 0.41, 0.03, -0.31, -0.17, -0.47, 0.53, 0.24, -0.77, 0.32, 0.58, -0.08, -0.71, 0.10,
|
||||
# -0.14, 0.39, 0.64, -0.08, -0.38, 0.60, 0.02, 0.61, 0.47, 0.32, 0.35, -0.01, -0.03, -0.15, -0.01, 0.51,
|
||||
# -0.52, 0.51, -0.82, 0.58, -0.13, 0.07, 0.46, -2.86, 0.36, -0.27, 0.70, 0.54, 0.31, 0.08, -0.67, 0.58, 0.22,
|
||||
# -0.40, 1.05, 0.02, 0.41, -0.66, -0.29, 0.68, 0.40, 0.53, 0.09, -0.31, -0.28, 0.20, 0.01, -0.07, -0.25, 0.36,
|
||||
# 0.10, -0.79, 0.27, -0.18, 0.18, -1.13, 0.40, -1.07, 0.84, -0.26, -0.09, -0.99, -0.55, 0.20, -0.11, -0.10,
|
||||
# 0.49, 0.49, -0.08, -0.13, 1.00, 0.48, -0.17, -0.37, -0.31, -0.24, 0.27, -0.11, 0.21, 0.01, -0.17, -0.02,
|
||||
# -0.48, 0.25, -0.44, 0.64, 0.53, -1.02, -0.20, -0.13, -0.19, 0.07, -0.17, 0.66, 1.34, -0.40, -1.09, 0.42,
|
||||
# 0.07, -0.02, 0.50, 0.32, -0.03, 0.30, -0.53, 0.19, 0.01, -0.26, -0.54, -0.04, -0.64, -0.31, 0.85, -0.12,
|
||||
# -0.07, -0.08, -0.22, 0.27, -0.50, 0.25, 0.40, -0.60, -0.18, 0.36, 0.66, -0.16, 0.91, -0.61, 0.43, 0.31, 0.23, -0.60,
|
||||
# -0.13, -0.07, -0.44, -0.03, 0.25, 0.41, 0.08, 0.89, -1.09, -0.12, -0.12, -0.09, 0.13, 0.01, -0.55, -0.35,
|
||||
# -0.44, 0.07, -0.19, 0.35, 0.99, 0.01, 0.11, -0.04, 0.50, -0.10, 0.49, 0.61, 0.23, -0.41, 0.11, -0.36, 0.64,
|
||||
# -0.97, 0.68, -0.27, 0.30, 0.85, 0.03, 1.84, -0.15, -0.05, 0.46, -0.41, -0.01, 0.03, -0.32, 0.33, 0.14, 0.31
|
||||
# , -0.18, -0.30, 0.07, 0.70, -0.64, -0.59, 0.36, 0.39, -0.33, 0.79, 0.47, 0.44, -0.05, -0.03, -0.29, -1.00,
|
||||
# -0.04, 1.25, 0.74, 0.08, -0.53, -0.65, 0.17, -0.57, -0.39, 0.34, -0.12, -0.04, -0.63, 0.27, -0.25, -0.73,
|
||||
# -4.08, -0.09, -0.64, 0.38, -0.47, -0.36, -0.34, 0.05, 0.12, 0.37, -0.43, -0.39, 0.11, -0.32, -0.81, -0.05,
|
||||
# -0.40, -0.31, 2.64, 0.14, -2.08, 0.70, -0.52, -0.55, -0.40, -0.75, -0.20, 0.42, 0.99, -0.27, 0.35, -0.35,
|
||||
# -0.46, 0.48, 0.03, 0.64, 0.56, -0.77, -0.37, 0.02, 0.02, -0.60, -0.47, -0.49, -0.19, 0.29, 0.05, 0.17, 0.05,
|
||||
# 1.01, 0.05, 0.06, -0.00, -0.64, 0.72, 1.39, -0.45, -0.46, 0.49, -0.58, 0.36, 0.01, -0.14, -0.01, -0.54,
|
||||
# -0.46, -1.21, 0.94, -1.31, 0.61, 0.63, -0.53, 0.05, 0.37, -0.18, 1.08, -0.10, -0.80, -0.38, -0.03,
|
||||
# ],
|
||||
# std=[
|
||||
# 1.48, 1.58, 1.56, 1.49, 1.57, 1.96, 1.50, 1.34, 1.46, 1.66, 1.63, 1.44, 1.48, 1.53, 1.49, 1.39, 1.45, 1.40,
|
||||
# 1.47, 1.43, 1.65, 1.69, 1.72, 1.56, 1.50, 3.06, 1.48, 1.58, 1.63, 1.41, 1.78, 1.48, 1.64, 1.41, 1.46, 1.39,
|
||||
# 1.57, 3.80, 0.16, 1.46, 1.49, 1.51, 1.55, 1.57, 1.43, 1.69, 1.50, 1.53, 1.51, 1.49, 1.42, 1.48, 1.62, 1.56,
|
||||
# 1.52, 1.39, 1.95, 1.47, 1.33, 1.42, 1.96, 1.46, 1.54, 1.47, 1.41, 1.41, 1.50, 1.53, 1.55, 2.24, 1.52, 1.73,
|
||||
# 1.54, 1.46, 1.47, 1.55, 1.56, 1.46, 1.40, 1.49, 1.42, 1.54, 1.43, 1.48, 1.41, 1.49, 1.56, 1.59, 1.40, 1.49,
|
||||
# 1.58, 2.29, 1.58, 1.35, 1.41, 1.45, 1.43, 1.51, 1.48, 1.52, 1.51, 1.52, 1.56, 1.42, 1.44, 1.45, 1.47, 1.42,
|
||||
# 1.43, 1.49, 1.54, 1.45, 1.66, 1.48, 1.35, 1.53, 1.45, 2.38, 1.38, 1.32, 1.37, 1.49, 2.00, 1.47, 1.45, 1.47,
|
||||
# 1.63, 1.49, 1.59, 2.58, 1.70, 1.52, 1.40, 1.41, 2.57, 1.61, 1.54, 1.47, 1.62, 1.54, 1.41, 1.45, 1.57, 1.49,
|
||||
# 1.42, 1.50, 1.67, 1.45, 1.47, 1.43, 1.55, 1.47, 1.53, 1.49, 1.56, 1.58, 2.03, 2.03, 1.57, 1.44, 1.46, 1.05,
|
||||
# 1.61, 1.39, 1.47, 1.41, 1.43, 1.38, 1.34, 1.42, 1.41, 1.47, 1.79, 1.44, 1.43, 1.38, 1.39, 1.44, 1.38, 1.46,
|
||||
# 1.45, 1.51, 1.52, 1.49, 5.31, 1.41, 1.45, 1.49, 1.43, 1.94, 1.38, 1.35, 1.56, 1.45, 1.37, 1.47, 1.48, 1.67,
|
||||
# 1.46, 1.50, 1.40, 1.50, 1.62, 1.48, 1.53, 1.45, 1.51, 1.50, 1.51, 1.52, 1.55, 1.42, 1.84, 1.39, 1.54, 1.42, 4.91,
|
||||
# 1.42, 1.47, 1.51, 1.57, 1.37, 1.50, 1.39, 2.40, 1.51, 1.59, 1.44, 1.42, 1.59, 1.73, 1.44, 1.53, 1.61, 1.48,
|
||||
# 1.29, 1.47, 1.39, 1.54, 1.44, 1.43, 1.55, 1.45, 1.31, 1.43, 1.44, 1.41, 1.35, 1.62, 1.49, 1.45, 1.50, 1.76,
|
||||
# 1.44, 1.80, 1.60, 1.49, 1.43, 1.47, 1.40, 1.40, 1.50, 1.42, 1.51, 1.61, 1.47, 1.45, 1.70, 2.90, 1.51, 1.37,
|
||||
# 1.50, 1.55, 1.32, 1.42, 1.76, 1.36, 1.41, 1.61, 1.44, 1.44, 1.44, 1.47, 1.48, 1.45, 1.48, 1.56, 1.58, 1.52,
|
||||
# 1.33, 1.37, 1.64, 1.47, 2.49, 1.51, 1.60, 1.58, 1.45, 1.48, 1.81, 1.38, 1.37, 1.53, 1.72, 1.49, 1.47, 1.49, 1.42,
|
||||
# 1.44, 1.43, 1.54, 1.59, 1.40, 1.57, 1.45, 1.45, 1.45, 1.55, 1.38, 1.41, 1.46, 2.13, 1.58, 1.46, 1.35, 1.56,
|
||||
# 1.47, 1.33, 1.53, 1.62, 1.47, 1.44, 1.45, 1.49, 1.82, 1.51, 1.38, 1.54, 1.38, 1.38, 1.40, 1.40, 1.46, 1.43,
|
||||
# 1.45, 1.42, 1.67, 1.37, 1.50, 1.60, 1.42, 1.46, 1.45, 3.29, 1.45, 1.50, 1.49, 1.38, 1.48, 1.52, 2.45, 1.47,
|
||||
# 1.50, 1.47, 1.48, 1.44, 1.62, 1.48, 1.52, 1.52, 1.45, 1.51, 1.71, 1.54, 1.59, 1.40, 3.29, 1.45, 1.65, 1.37, 1.54,
|
||||
# 1.49, 2.38, 1.62, 1.39, 1.38, 1.41, 1.46, 1.57, 1.38, 2.07, 1.54, 1.40, 1.64, 1.46, 1.45, 1.40, 1.57, 1.49,
|
||||
# 1.39, 1.55, 1.67, 1.54, 1.57, 1.55, 1.41, 1.37, 1.44, 1.40, 1.46, 1.59, 1.56, 1.61, 1.44, 1.35, 1.62, 1.59,
|
||||
# 1.52, 1.41, 1.44, 1.74, 1.40, 1.40, 1.89, 1.44, 1.46, 1.62, 1.43, 1.42, 1.39, 1.37, 1.43, 1.44, 1.60, 1.52,
|
||||
# 1.44, 1.41, 1.43, 1.34, 1.54, 1.46, 1.57, 1.53, 1.40, 1.41, 1.36, 1.45, 1.42, 1.37, 1.47, 1.37, 1.40, 1.55,
|
||||
# 1.48, 1.91, 1.44, 1.54, 1.49, 1.42, 1.48, 1.54, 1.49, 1.39, 1.47, 1.50, 1.43, 1.59, 1.58, 1.78, 1.49, 1.55,
|
||||
# 1.56, 1.52, 1.56, 1.49, 1.61, 1.51, 1.35, 1.46, 1.69, 1.35, 1.38, 1.48, 1.39, 1.40, 1.35, 1.45, 1.34, 1.38,
|
||||
# 1.44, 1.46, 1.45, 1.63, 1.52, 1.44, 1.39, 1.46, 1.70, 1.41, 1.49, 1.64, 1.54, 1.33, 1.45, 1.54, 1.49, 1.38,
|
||||
# 1.42, 1.75, 1.28, 1.52, 1.62, 1.47, 1.66, 1.51, 1.50, 1.51, 1.42, 1.42, 1.60, 1.24, 1.54, 1.42, 1.44, 1.34, 1.53,
|
||||
# 1.46, 1.46, 1.65, 1.56, 1.52, 2.12, 1.58, 1.44, 1.60, 1.48, 1.51, 1.41, 1.51, 1.68, 2.10, 1.50, 1.39, 1.49,
|
||||
# 1.43, 1.53, 1.46, 1.53, 1.43, 1.78, 1.32, 1.54, 1.47, 1.55, 1.58, 1.41, 1.57, 1.39, 1.36, 1.74, 1.50, 4.41,
|
||||
# 1.50, 1.45, 1.34, 1.44, 1.50, 1.50, 1.82, 1.28, 1.76, 1.38, 1.58, 1.56, 3.73, 1.48, 1.53, 1.48, 1.63, 1.43,
|
||||
# 1.57, 3.43, 1.75, 1.45, 1.45, 1.48, 1.93, 1.47, 1.47, 1.38, 1.42, 1.56, 1.66, 1.39, 1.74, 4.76, 1.53, 1.68,
|
||||
# 1.55, 1.47, 1.57, 1.53, 1.50, 1.40, 1.57, 1.48, 1.44, 1.36, 1.32, 1.71, 1.44, 1.46, 1.47, 1.54, 1.51, 1.47,
|
||||
# 1.36, 1.29, 1.44, 1.43, 1.46, 1.40, 1.64, 1.48, 1.42, 1.32, 1.52, 1.49, 3.04, 1.52, 1.38, 1.43, 1.42, 1.43,
|
||||
# 1.48, 1.49, 1.59, 1.55, 1.62, 2.04, 1.53, 1.42, 1.89, 1.43, 1.41, 3.84, 1.48, 1.51, 1.48, 1.58, 1.54, 1.54,
|
||||
# 1.54, 1.55, 1.45, 1.49, 1.46, 2.25, 1.43, 1.62, 1.66, 1.80, 1.37, 1.64, 1.49, 1.50, 1.39, 1.41, 1.41, 1.46, 1.44,
|
||||
# 1.69, 1.47, 1.56, 1.65, 1.51, 1.52, 1.43, 1.53, 1.51, 1.46, 1.62, 1.46, 1.53, 1.68, 1.61, 1.56, 1.42, 4.69,
|
||||
# 1.31, 1.48, 1.50, 1.82, 1.45, 1.54, 1.56, 1.53, 1.58, 1.59, 1.82, 1.45, 1.54, 1.58, 1.45, 1.40, 1.49, 2.50,
|
||||
# 1.52, 2.54, 1.51, 1.41, 1.48, 1.46, 1.55, 1.63, 1.42, 1.53, 1.47, 1.47, 1.62, 1.49, 2.09, 1.42, 1.48, 1.33,
|
||||
# 1.62, 1.41, 1.41, 1.45, 1.50, 1.78, 1.53, 1.56, 1.49, 1.51, 2.31, 1.40, 1.58, 1.39, 1.49, 1.51, 1.55, 1.58,
|
||||
# 1.93, 1.47, 1.41, 1.47, 1.52, 1.52, 1.39, 1.48, 1.64, 1.49, 1.47, 1.53, 1.50, 3.58, 1.54, 1.70, 1.50, 1.47,
|
||||
# 1.35, 1.51, 1.70, 1.59, 1.60, 1.56, 1.29
|
||||
# ]
|
||||
# )
|
||||
# )
|
||||
|
||||
class DINOv2a(nn.Module):
    """DINOv2 encoder wrapper that returns patch tokens folded into a 2-D map."""

    def __init__(self, weight_path:str):
        """Load a DINOv2 model through ``torch.hub``.

        Args:
            weight_path: Entry-point name handed to ``torch.hub.load`` for the
                ``facebookresearch/dinov2`` repository.
        """
        super(DINOv2a, self).__init__()
        self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
        # Untouched copy of the loaded positional embedding; forward()
        # overwrites ``encoder.pos_embed.data`` with resampled versions.
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        # Resampled positional embeddings keyed by (patch rows, patch cols).
        self.precomputed_pos_embed = dict()
        # Built once here instead of on every forward pass.
        self.normalize = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        # self.shifts = nn.Parameter(torch.tensor(NORMALIZE_DATA[weight_path+'a']["mean"]), requires_grad=False)
        # self.scales = nn.Parameter(torch.tensor(NORMALIZE_DATA[weight_path+'a']["std"]), requires_grad=False)

    def fetch_pos(self, h, w):
        """Return the positional embedding resampled to an ``h`` x ``w`` grid.

        Results are memoized per ``(h, w)`` in ``self.precomputed_pos_embed``.
        """
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        """Encode a batch of images into a folded patch-feature map.

        Args:
            x: Tensor unpacked as ``(batch, channels, height, width)``.
                Presumably values in [0, 1], since ImageNet mean/std
                normalization is applied here -- confirm against the caller.

        Returns:
            The ``x_norm_patchtokens`` output rearranged by
            ``torch.nn.functional.fold`` (kernel 2, stride 2): each token's
            channels are spread over a 2x2 cell, giving a map with a quarter
            of the token channels and twice the patch-grid resolution.
        """
        _, _, h, w = x.shape
        x = self.normalize(x)
        # Rescale by 224/256, then snap down to a whole number of patches so
        # the resized image always matches the patch grid used below.
        patch_num_h = max(1, int(224*h/256) // self.patch_size[0])
        patch_num_w = max(1, int(224*w/256) // self.patch_size[1])
        x = torch.nn.functional.interpolate(
            x, (patch_num_h*self.patch_size[0], patch_num_w*self.patch_size[1]), mode='bicubic'
        )
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        # feature = (feature - self.shifts.view(1, 1, -1)) / self.scales.view(1, 1, -1)
        # fold() expects (batch, channels, blocks), hence the transpose.
        feature = feature.transpose(1, 2)
        feature = torch.nn.functional.fold(feature, (patch_num_h*2, patch_num_w*2), kernel_size=2, stride=2)
        return feature
|
||||
|
||||
|
||||
|
||||
from torchvision.datasets import ImageFolder, ImageNet
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
import torch
|
||||
import torchvision.transforms as tvtf
|
||||
|
||||
# Glob-style patterns for image file suffixes (PNG plus the JPEG spellings).
IMG_EXTENSIONS = ("*.png", "*.JPEG", "*.jpeg", "*.jpg")
|
||||
|
||||
class NewImageFolder(ImageFolder):
    """``ImageFolder`` whose items additionally carry the source file path."""

    def __getitem__(self, item):
        """Return ``(sample, target, path)`` for the entry at index ``item``.

        The sample is loaded with ``self.loader`` and passed through
        ``self.transform`` / ``self.target_transform`` when those are set.
        """
        image_path, class_index = self.samples[item]
        loaded = self.loader(image_path)
        sample = loaded if self.transform is None else self.transform(loaded)
        if self.target_transform is None:
            target = class_index
        else:
            target = self.target_transform(class_index)
        return sample, target, image_path
|
||||
|
||||
|
||||
import time
|
||||
class CenterCrop:
    """Callable transform: resize a PIL image and cut out its central square."""

    def __init__(self, size):
        # Side length, in pixels, of the square crop returned by __call__.
        self.size = size

    def __call__(self, image):
        """Return the central ``size`` x ``size`` crop of ``image``.

        Center cropping implementation from ADM.
        https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
        """
        target = self.size
        picture = image

        # Halve with a box filter for as long as the short side is at least
        # twice the target size.
        while min(*picture.size) >= 2 * target:
            halved = tuple(side // 2 for side in picture.size)
            picture = picture.resize(halved, resample=Image.BOX)

        # Bicubic resize that scales the short side to the target size.
        ratio = target / min(*picture.size)
        rescaled = tuple(round(side * ratio) for side in picture.size)
        picture = picture.resize(rescaled, resample=Image.BICUBIC)

        # Crop the centered target x target window from the pixel array.
        pixels = np.array(picture)
        top = (pixels.shape[0] - target) // 2
        left = (pixels.shape[1] - target) // 2
        return Image.fromarray(pixels[top: top + target, left: left + target])
|
||||
|
||||
def save(obj, path):
    """Serialize ``obj`` to ``path`` with ``torch.save``, creating parent folders.

    Args:
        obj: Object to serialize.
        path: Destination file path.
    """
    dirname = os.path.dirname(path)
    # A bare file name has no directory part, and os.makedirs("") would raise.
    if dirname:
        # exist_ok avoids the check-then-create race when several writers hit
        # the same not-yet-existing directory at once.
        os.makedirs(dirname, exist_ok=True)
    torch.save(obj, f'{path}')
|
||||
|
||||
import math
|
||||
class TimestepEmbedder(nn.Module):
    """Embed timesteps: sinusoidal features followed by a two-layer MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        """Build the MLP.

        Args:
            hidden_size: Width of the hidden layer and of the output embedding.
            frequency_embedding_size: Size of the sinusoidal feature vector
                fed into the MLP.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10):
        """Return sinusoidal features for ``t``.

        Args:
            t: Tensor of timesteps with any number of dimensions.
            dim: Size of the feature axis appended to ``t``'s shape.
            max_period: Sets the lowest frequency; the ``dim // 2`` frequencies
                are ``max_period ** (-k / half)`` for ``k = 0 .. half - 1``.

        Returns:
            Float tensor of shape ``t.shape + (dim,)``: cosines first, then
            sines, then one zero column when ``dim`` is odd.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
        )
        args = t[..., None].float() * freqs[None, ...]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad along the feature axis with an explicitly shaped zero column,
            # so this works for any number of leading dimensions and also when
            # ``half`` is 0.
            pad = embedding.new_zeros((*embedding.shape[:-1], 1))
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        """Map timesteps ``t`` to embeddings of size ``hidden_size``."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
|
||||
|
||||
class Classifer(nn.Module):
    """Classifier head: one strided 2x2 convolution, then global average pooling."""

    def __init__(self, in_channels=192, hidden_size=256, num_classes=1000):
        """Create the head.

        Args:
            in_channels: Number of leading input channels that are used.
            hidden_size: Accepted but not used by any layer.
            num_classes: Number of output classes.
        """
        super().__init__()
        self.in_channels = in_channels
        conv = nn.Conv2d(in_channels, num_classes, kernel_size=2, stride=2, padding=0)
        pool = nn.AdaptiveAvgPool2d(1)
        self.feature_x = nn.Sequential(conv, pool)

    def forward(self, xt):
        """Return class probabilities (softmax over dim 1) for the feature map ``xt``."""
        # Channels beyond the first ``in_channels`` are ignored.
        used = xt[:, :self.in_channels]
        pooled = self.feature_x(used)
        # The pooled map is 1x1 spatially; drop those two trailing axes.
        logits = pooled[..., 0, 0]
        # score = (feature_xt).clamp(-5, 5)
        return logits.softmax(dim=1)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Trains `Classifer` on frozen DINOv2a features of ImageNet-train images.
    # NOTE(review): the source indentation was not recoverable; the per-epoch
    # checkpoint and the final barrier placement below are assumed - confirm.

    # Point torch.hub and TORCH_HOME at the shared cache directory.
    torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
    os.environ["TORCH_HOME"] = torch_hub_dir
    torch.hub.set_dir(torch_hub_dir)

    transforms = tvtf.Compose([
        CenterCrop(256),
        tvtf.ToTensor(),
    ])
    dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
    B = 64
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=4, num_workers=4, drop_last=True)
    dino = DINOv2a("dinov2_vitb14")
    from accelerate import Accelerator

    accelerator = Accelerator()

    classifer = Classifer(in_channels=32)
    classifer.train()
    optimizer = torch.optim.Adam(classifer.parameters(), lr=0.0001)

    dino, dataloader, classifer, optimizer = accelerator.prepare(dino, dataloader, classifer, optimizer)

    # fake_file_dir = "/mnt/bn/wangshuai6/data/gan_guidance"
    # fake_file_names = os.listdir(fake_file_dir)

    for epoch in range(100):
        for i, (true_images, true_labels, _paths) in enumerate(dataloader):
            batch_size = true_images.shape[0]
            true_labels = true_labels.to(accelerator.device)
            true_labels = torch.nn.functional.one_hot(true_labels, num_classes=1000)
            # The feature extractor is frozen: no gradients flow through it.
            with torch.no_grad():
                true_dino_feature = dino(true_images)
            # t = torch.rand((batch_size, 1, 1, 1), device=accelerator.device)
            # true_x_t = t * true_dino_feature + (1-t) * noise

            true_x_t = true_dino_feature
            true_score = classifer(true_x_t)

            # ind = i % len(fake_file_names)
            # fake_file = torch.load(os.path.join(fake_file_dir, fake_file_names[ind]))
            # ind = torch.randint(0, 50, size=(4,))
            # fake_x_t = fake_file['trajs'][ind].view(-1, 196, 32, 32)[:, 4:, :, :]
            # fake_labels = fake_file['condition'].repeat(4)
            # fake_score = classifer(fake_x_t)

            # Cross-entropy on softmax outputs. Clamp before the log: a
            # probability that underflows to 0 gives log -> -inf, and
            # -inf * 0 (non-target classes) would turn the loss into NaN.
            log_probs = torch.log(true_score.clamp_min(1e-12))
            loss_true = -log_probs * true_labels
            loss = loss_true.sum() / batch_size
            # Let Accelerate run the backward pass so mixed-precision loss
            # scaling and gradient accumulation settings are honoured.
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

            acc = torch.sum(torch.argmax(true_score, dim=1) == torch.argmax(true_labels, dim=1)) / batch_size
            if accelerator.is_main_process:
                print("epoch:{}".format(epoch), "iter:{}".format(i), "loss:{}".format(loss.item()), "acc:{}".format(acc.item()))
        if accelerator.is_main_process:
            # Unwrap first so checkpoint keys do not carry the wrapper prefix
            # that `prepare` may add in distributed runs.
            torch.save(accelerator.unwrap_model(classifer).state_dict(), f'{epoch}.pth')

    accelerator.wait_for_everyone()
|
||||
@@ -0,0 +1,4 @@
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/compat
|
||||
pip3 install -r requirements.txt
|
||||
git branch --set-upstream-to=origin/master master
|
||||
git pull
|
||||
@@ -0,0 +1,173 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import timm
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from torchvision.transforms import Normalize
|
||||
import copy
|
||||
|
||||
class DINOv2(nn.Module):
    """Wrap a torch.hub DINOv2 encoder and expose its patch tokens as a 2-D feature map.

    Args:
        weight_path: Entry-point name passed to ``torch.hub.load`` for the
            ``facebookresearch/dinov2`` repository.
    """

    def __init__(self, weight_path:str):
        super().__init__()
        self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
        # Keep a pristine copy of the positional embedding: forward() overwrites
        # the encoder's own copy with a resampled version.
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        # Cache of resampled positional embeddings keyed by patch grid (h, w).
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        """Return the positional embedding resampled to an ``h`` x ``w`` patch grid (cached)."""
        grid = (h, w)
        cached = self.precomputed_pos_embed.get(grid)
        if cached is None:
            cached = timm.layers.pos_embed.resample_abs_pos_embed(
                self.pos_embed.data, [h, w],
            )
            self.precomputed_pos_embed[grid] = cached
        return cached

    def forward(self, x):
        """Return patch features of shape ``(b, channels, patch_num_h, patch_num_w)``.

        The input is normalized with the ImageNet default mean/std and resized
        by a factor of 224/256 (bicubic) before being encoded.
        """
        _, _, in_h, in_w = x.shape
        normalize = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        x = normalize(x)
        resized_hw = (int(224*in_h/256), int(224*in_w/256))
        x = torch.nn.functional.interpolate(x, resized_hw, mode='bicubic')
        batch, _, height, width = x.shape
        grid_h = height // self.patch_size[0]
        grid_w = width // self.patch_size[1]
        # Swap in a positional embedding matching the current patch grid.
        self.encoder.pos_embed.data = self.fetch_pos(grid_h, grid_w)
        tokens = self.encoder.forward_features(x)['x_norm_patchtokens']
        channels_first = tokens.transpose(1, 2)
        return channels_first.view(batch, -1, grid_h, grid_w).contiguous()
|
||||
|
||||
class MAE(nn.Module):
    """Wrap a timm encoder and expose its non-prefix tokens as a 2-D feature map.

    Args:
        model_id: Model name passed to ``timm.create_model``.
        weight_path: Checkpoint file, or a directory containing
            ``pytorch_model.bin``.
    """

    def __init__(self, model_id, weight_path:str):
        super().__init__()
        checkpoint = weight_path
        if os.path.isdir(checkpoint):
            checkpoint = os.path.join(checkpoint, "pytorch_model.bin")
        self.encoder = timm.create_model(model_id, checkpoint_path=checkpoint, num_classes=0)
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        # Frozen affine normalisation of the output features (identity by default).
        self.shifts = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
        self.scales = nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def forward(self, x):
        """Return features of shape ``(b, channels, patch_num_h, patch_num_w)``.

        The input is normalized with the ImageNet default mean/std and resized
        to 224x224 (bicubic) before being encoded; the result is shifted and
        scaled by ``self.shifts`` / ``self.scales``.
        """
        normalize = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        x = torch.nn.functional.interpolate(normalize(x), (224, 224), mode='bicubic')
        batch, _, height, width = x.shape
        grid_h = height // self.patch_size[0]
        grid_w = width // self.patch_size[1]
        # Drop the encoder's prefix tokens, keeping only patch tokens.
        tokens = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
        feature = tokens.transpose(1, 2).view(batch, -1, grid_h, grid_w).contiguous()
        shift = self.shifts.view(1, -1, 1, 1)
        scale = self.scales.view(1, -1, 1, 1)
        return (feature - shift) / scale
|
||||
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
from torchvision.datasets import ImageFolder, ImageNet
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||
|
||||
from PIL import Image
|
||||
import torch
|
||||
import torchvision.transforms as tvtf
|
||||
|
||||
IMG_EXTENSIONS = (
|
||||
"*.png",
|
||||
"*.JPEG",
|
||||
"*.jpeg",
|
||||
"*.jpg"
|
||||
)
|
||||
|
||||
class NewImageFolder(ImageFolder):
    """ImageFolder variant whose items also carry the source file path."""

    def __getitem__(self, item):
        """Return ``(sample, target, path)`` for index ``item``.

        The optional ``transform`` and ``target_transform`` are applied to the
        loaded sample and the target respectively.
        """
        path, target = self.samples[item]
        sample = self.loader(path)
        sample_fn = self.transform
        if sample_fn is not None:
            sample = sample_fn(sample)
        target_fn = self.target_transform
        if target_fn is not None:
            target = target_fn(target)
        return sample, target, path
|
||||
|
||||
|
||||
import time
|
||||
class CenterCrop:
    """Callable transform that center-crops a PIL image to a ``size`` x ``size`` square."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        """
        Center cropping implementation from ADM.
        https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
        """
        target = self.size

        # Halve with a box filter while the short side is at least 2x the target.
        while min(*image.size) >= 2 * target:
            halved = tuple(side // 2 for side in image.size)
            image = image.resize(halved, resample=Image.BOX)

        # Bicubic resize so the short side matches the target.
        scale = target / min(*image.size)
        scaled = tuple(round(side * scale) for side in image.size)
        image = image.resize(scaled, resample=Image.BICUBIC)

        # Cut the centered square out of the pixel array.
        pixels = np.array(image)
        top = (pixels.shape[0] - target) // 2
        left = (pixels.shape[1] - target) // 2
        return Image.fromarray(pixels[top: top + target, left: left + target])
|
||||
|
||||
def save(obj, path):
    """Serialize ``obj`` to ``path`` with ``torch.save``, creating the parent directory if needed.

    Args:
        obj: Object to serialize.
        path: Destination file path; may be a bare filename.
    """
    dirname = os.path.dirname(path)
    # A bare filename has an empty dirname, and os.makedirs('') raises.
    # exist_ok=True also removes the check-then-create race when several
    # processes save into the same new directory.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    torch.save(obj, f'{path}')
|
||||
|
||||
if __name__ == "__main__":
    # Prints per-channel std and mean of DINOv2 features for one batch of
    # ImageNet-train images.
    # NOTE(review): the source indentation was not recoverable; the nesting
    # below is reconstructed from statement order - confirm.

    # Point torch.hub and TORCH_HOME at the shared cache directory.
    torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
    os.environ["TORCH_HOME"] = torch_hub_dir
    torch.hub.set_dir(torch_hub_dir)

    for split in ['train']:
        transforms = tvtf.Compose([
            CenterCrop(256),
            tvtf.ToTensor(),
            # tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
        B = 4096
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0)
        dino = DINOv2("dinov2_vitb14")
        # dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae")
        # dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai")
        from accelerate import Accelerator

        accelerator = Accelerator()
        dino, dataloader = accelerator.prepare(dino, dataloader)

        with torch.no_grad():
            for images, _labels, _paths in dataloader:
                feature = dino(images)
                # Statistics are reduced over batch and both spatial axes,
                # leaving one value per channel.
                stds = torch.std(feature, dim=[0, 2, 3]).tolist()
                print("".join("{:.2f},".format(std) for std in stds))
                means = torch.mean(feature, dim=[0, 2, 3]).tolist()
                # Terminate the line so later output does not run into it.
                print("".join("{:.2f},".format(mean) for mean in means))
                # Only the first batch is inspected.
                break

        accelerator.wait_for_everyone()
|
||||
@@ -0,0 +1,168 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import timm
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from torchvision.transforms import Normalize
|
||||
import copy
|
||||
|
||||
class DINOv2(nn.Module):
|
||||
def __init__(self, weight_path:str):
|
||||
super(DINOv2, self).__init__()
|
||||
self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
|
||||
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
|
||||
self.encoder.head = torch.nn.Identity()
|
||||
self.patch_size = self.encoder.patch_embed.patch_size
|
||||
self.precomputed_pos_embed = dict()
|
||||
|
||||
def fetch_pos(self, h, w):
|
||||
key = (h, w)
|
||||
if key in self.precomputed_pos_embed:
|
||||
return self.precomputed_pos_embed[key]
|
||||
value = timm.layers.pos_embed.resample_abs_pos_embed(
|
||||
self.pos_embed.data, [h, w],
|
||||
)
|
||||
self.precomputed_pos_embed[key] = value
|
||||
return value
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
||||
x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
|
||||
b, c, h, w = x.shape
|
||||
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
|
||||
pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
|
||||
self.encoder.pos_embed.data = pos_embed_data
|
||||
feature = self.encoder.forward_features(x)['x_norm_patchtokens']
|
||||
feature = feature.transpose(1, 2)
|
||||
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
|
||||
return feature
|
||||
|
||||
class MAE(nn.Module):
|
||||
def __init__(self, model_id, weight_path:str):
|
||||
super(MAE, self).__init__()
|
||||
if os.path.isdir(weight_path):
|
||||
weight_path = os.path.join(weight_path, "pytorch_model.bin")
|
||||
self.encoder = timm.create_model(
|
||||
model_id,
|
||||
checkpoint_path=weight_path,
|
||||
num_classes=0,
|
||||
)
|
||||
self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
|
||||
self.encoder.head = torch.nn.Identity()
|
||||
self.patch_size = self.encoder.patch_embed.patch_size
|
||||
self.shifts = nn.Parameter(torch.tensor([0.0
|
||||
]), requires_grad=False)
|
||||
self.scales = nn.Parameter(torch.tensor([1.0
|
||||
]), requires_grad=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
||||
x = torch.nn.functional.interpolate(x, (224, 224), mode='bicubic')
|
||||
b, c, h, w = x.shape
|
||||
patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
|
||||
feature = self.encoder.forward_features(x)[:, self.encoder.num_prefix_tokens:]
|
||||
feature = feature.transpose(1, 2)
|
||||
feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
|
||||
feature = (feature - self.shifts.view(1, -1, 1, 1)) / self.scales.view(1, -1, 1, 1)
|
||||
return feature
|
||||
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
from torchvision.datasets import ImageFolder, ImageNet
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||
|
||||
from PIL import Image
|
||||
import torch
|
||||
import torchvision.transforms as tvtf
|
||||
|
||||
IMG_EXTENSIONS = (
|
||||
"*.png",
|
||||
"*.JPEG",
|
||||
"*.jpeg",
|
||||
"*.jpg"
|
||||
)
|
||||
|
||||
class NewImageFolder(ImageFolder):
|
||||
def __getitem__(self, item):
|
||||
path, target = self.samples[item]
|
||||
sample = self.loader(path)
|
||||
if self.transform is not None:
|
||||
sample = self.transform(sample)
|
||||
if self.target_transform is not None:
|
||||
target = self.target_transform(target)
|
||||
return sample, target, path
|
||||
|
||||
|
||||
import time
|
||||
class CenterCrop:
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
def __call__(self, image):
|
||||
def center_crop_arr(pil_image, image_size):
|
||||
"""
|
||||
Center cropping implementation from ADM.
|
||||
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
||||
"""
|
||||
while min(*pil_image.size) >= 2 * image_size:
|
||||
pil_image = pil_image.resize(
|
||||
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
||||
)
|
||||
|
||||
scale = image_size / min(*pil_image.size)
|
||||
pil_image = pil_image.resize(
|
||||
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
||||
)
|
||||
|
||||
arr = np.array(pil_image)
|
||||
crop_y = (arr.shape[0] - image_size) // 2
|
||||
crop_x = (arr.shape[1] - image_size) // 2
|
||||
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
|
||||
|
||||
return center_crop_arr(image, self.size)
|
||||
|
||||
def save(obj, path):
    """Serialize ``obj`` to ``path`` with ``torch.save``, creating the parent directory if needed.

    Args:
        obj: Object to serialize.
        path: Destination file path; may be a bare filename.
    """
    dirname = os.path.dirname(path)
    # A bare filename has an empty dirname, and os.makedirs('') raises.
    # exist_ok=True also removes the check-then-create race when several
    # processes save into the same new directory.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    torch.save(obj, f'{path}')
|
||||
|
||||
if __name__ == "__main__":
    # Runs a rank-64 PCA over DINOv2 patch features of one ImageNet-train batch.
    # NOTE(review): the source indentation was not recoverable; the nesting
    # below is reconstructed from statement order - confirm.

    # Point torch.hub and TORCH_HOME at the shared cache directory.
    torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
    os.environ["TORCH_HOME"] = torch_hub_dir
    torch.hub.set_dir(torch_hub_dir)

    for split in ['train']:
        # NOTE(review): DINOv2.forward applies ImageNet mean/std normalization
        # on top of this [0.5, 0.5] normalization - confirm this is intended.
        transforms = tvtf.Compose([
            CenterCrop(256),
            tvtf.ToTensor(),
            tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)
        B = 2048
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=B, shuffle=True, prefetch_factor=None, num_workers=0)
        dino = DINOv2("dinov2_vitb14")
        # dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae")
        # dino = CLIP("/mnt/bn/wangshuai6/models/vit_base_patch16_clip_224.openai")
        from accelerate import Accelerator

        accelerator = Accelerator()
        dino, dataloader = accelerator.prepare(dino, dataloader)

        with torch.no_grad():
            for images, _labels, _paths in dataloader:
                feature = dino(images)
                b, c, h, w = feature.shape
                # Flatten to one row per patch token: (b * h * w, c).
                feature = feature.view(b, c, h*w).transpose(1, 2)
                feature = feature.reshape(-1, c)
                U, S, V = torch.pca_lowrank(feature, 64, )
                # The leftover pdb breakpoint was removed here: it blocks
                # every non-interactive or multi-process launch.
                # Project the tokens onto the 64 principal directions.
                feature = torch.matmul(feature, V)
                # Only the first batch is processed.
                break
        accelerator.wait_for_everyone()
|
||||
+64
@@ -0,0 +1,64 @@
|
||||
import matplotlib.pyplot as plt
|
||||
print(len([0, 3, 6, 9, 12, 16, 20, 24, 28, 33, 38, 43, 48, 53, 57, 62, 67, 72, 78, 83, 87, 91, 95, 98, 102, 106, 110, 115, 120, 125, 130, 135, 141, 146, 152, 158, 164, 171, 179, 185, 191, 197, 203, 209, 216, 223, 229, 234, 240, 245, 250]))
|
||||
print(len(list(range(0, 251, 5))))
|
||||
exit()
|
||||
plt.plot()
|
||||
plt.plot()
|
||||
plt.show()
|
||||
exit()
|
||||
|
||||
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
num_steps = 10
|
||||
num_recompute_timesteps = 4
|
||||
sim = torch.randint(0, 100, (num_steps, num_steps))
|
||||
sim[:5, :5] = 100
|
||||
for i in range(num_steps):
|
||||
sim[i, i] = 100
|
||||
|
||||
error_map = (100-sim).tolist()
|
||||
|
||||
|
||||
# init
|
||||
for i in range(1, num_steps):
|
||||
for j in range(0, i):
|
||||
error_map[i][j] = error_map[i-1][j] + error_map[i][j]
|
||||
|
||||
C = [[0, ] * (num_steps + 1) for _ in range(num_recompute_timesteps+1)]
|
||||
P = [[-1, ] * (num_steps + 1) for _ in range(num_recompute_timesteps+1)]
|
||||
|
||||
for i in range(1, num_steps+1):
|
||||
C[1][i] = error_map[i-1][0]
|
||||
P[1][i] = 0
|
||||
|
||||
|
||||
# dp
|
||||
for step in range(2, num_recompute_timesteps+1):
|
||||
for i in range(step, num_steps+1):
|
||||
min_value = 99999
|
||||
min_index = -1
|
||||
for j in range(step-1, i):
|
||||
value = C[step-1][j] + error_map[i-1][j]
|
||||
if value < min_value:
|
||||
min_value = value
|
||||
min_index = j
|
||||
C[step][i] = min_value
|
||||
P[step][i] = min_index
|
||||
|
||||
# trace back
|
||||
tracback_end_index = num_steps
|
||||
# min_value = 99999
|
||||
# for i in range(num_recompute_timesteps-1, num_steps):
|
||||
# if C[-1][i] < min_value:
|
||||
# min_value = C[-1][i]
|
||||
# tracback_end_index = i
|
||||
|
||||
timesteps = [tracback_end_index, ]
|
||||
for i in range(num_recompute_timesteps, 0, -1):
|
||||
idx = timesteps[-1]
|
||||
timesteps.append(P[i][idx])
|
||||
timesteps.reverse()
|
||||
print(timesteps)
|
||||
@@ -0,0 +1,64 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
is_data = {
|
||||
"4encoder8decoder":[46.01, 61.47, 69.73, 74.26],
|
||||
"6encoder6decoder":[53.11, 71.04, 79.83, 83.85],
|
||||
"8encoder4decoder":[54.06, 72.96, 80.49, 85.94],
|
||||
"10encoder2decoder": [49.25, 67.59, 76.00, 81.12],
|
||||
}
|
||||
|
||||
fid_data = {
|
||||
"4encoder8decoder":[31.40, 22.80, 20.13, 18.61],
|
||||
"6encoder6decoder":[27.61, 20.42, 17.95, 16.86],
|
||||
"8encoder4decoder":[27.12, 19.90, 17.78, 16.32],
|
||||
"10encoder2decoder": [29.70, 21.75, 18.95, 17.65],
|
||||
}
|
||||
|
||||
sfid_data = {
|
||||
"4encoder8decoder":[6.88, 6.44, 6.56, 6.56],
|
||||
"6encoder4decoder":[6.83, 6.50, 6.49, 6.63],
|
||||
"8encoder4decoder":[6.76, 6.70, 6.83, 6.63],
|
||||
"10encoder2decoder": [6.81, 6.61, 6.53, 6.60],
|
||||
}
|
||||
|
||||
pr_data = {
|
||||
"4encoder8decoder":[0.55006, 0.59538, 0.6063, 0.60922],
|
||||
"6encoder6decoder":[0.56436, 0.60246, 0.61668, 0.61702],
|
||||
"8encoder4decoder":[0.56636, 0.6038, 0.61832, 0.62132],
|
||||
"10encoder2decoder": [0.55612, 0.59846, 0.61092, 0.61686],
|
||||
}
|
||||
|
||||
recall_data = {
|
||||
"4encoder8decoder":[0.6347, 0.6495, 0.6559, 0.662],
|
||||
"6encoder6decoder":[0.6477, 0.6497, 0.6594, 0.6589],
|
||||
"8encoder4decoder":[0.6403, 0.653, 0.6505, 0.6618],
|
||||
"10encoder2decoder": [0.6342, 0.6492, 0.6536, 0.6569],
|
||||
}
|
||||
|
||||
x = [100, 200, 300, 400]
|
||||
# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
|
||||
|
||||
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
|
||||
|
||||
metric_data = {
|
||||
"FID50K" : fid_data,
|
||||
# "SFID" : sfid_data,
|
||||
"InceptionScore" : is_data,
|
||||
"Precision" : pr_data,
|
||||
"Recall" : recall_data,
|
||||
}
|
||||
|
||||
for key, data in metric_data.items():
|
||||
# plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
|
||||
for i, (name, v) in enumerate(data.items()):
|
||||
name = name.replace("encoder", "En")
|
||||
name = name.replace("decoder", "De")
|
||||
plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=10)
|
||||
plt.legend(fontsize="14")
|
||||
plt.xticks([100, 150, 200, 250, 300, 350, 400])
|
||||
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
|
||||
plt.ylabel(key, weight="bold")
|
||||
plt.xlabel("Training iterations(K steps)", weight="bold")
|
||||
plt.savefig("output/base++_{}.pdf".format(key), bbox_inches='tight',)
|
||||
plt.close()
|
||||
@@ -0,0 +1,57 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
|
||||
|
||||
fid_data = {
|
||||
"4encoder8decoder":[64.16, 48.04, 39.88, 35.41],
|
||||
"6encoder4decoder":[67.71, 48.26, 39.30, 34.91],
|
||||
"8encoder4decoder":[69.4, 49.7, 41.56, 36.76],
|
||||
}
|
||||
|
||||
sfid_data = {
|
||||
"4encoder8decoder":[7.86, 7.48, 7.15, 7.07],
|
||||
"6encoder4decoder":[8.54, 8.11, 7.40, 7.40],
|
||||
"8encoder4decoder":[8.42, 8.27, 8.10, 7.69],
|
||||
}
|
||||
|
||||
is_data = {
|
||||
"4encoder8decoder":[20.37, 29.41, 36.88, 41.32],
|
||||
"6encoder4decoder":[20.04, 30.13, 38.17, 43.84],
|
||||
"8encoder4decoder":[19.98, 29.54, 35.93, 42.025],
|
||||
}
|
||||
|
||||
pr_data = {
|
||||
"4encoder8decoder":[0.3935, 0.4687, 0.5047, 0.5271],
|
||||
"6encoder4decoder":[0.3767, 0.4686, 0.50876, 0.5266],
|
||||
"8encoder4decoder":[0.37, 0.45676, 0.49602, 0.5162],
|
||||
}
|
||||
|
||||
recall_data = {
|
||||
"4encoder8decoder":[0.5604, 0.5941, 0.6244, 0.6338],
|
||||
"6encoder4decoder":[0.5295, 0.595, 0.6287, 0.6378],
|
||||
"8encoder4decoder":[0.51, 0.596, 0.6242, 0.6333],
|
||||
}
|
||||
|
||||
x = [100, 200, 300, 400]
|
||||
colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
|
||||
metric_data = {
|
||||
"FID" : fid_data,
|
||||
# "SFID" : sfid_data,
|
||||
"InceptionScore" : is_data,
|
||||
"Precision" : pr_data,
|
||||
"Recall" : recall_data,
|
||||
}
|
||||
|
||||
for key, data in metric_data.items():
|
||||
for i, (name, v) in enumerate(data.items()):
|
||||
name = name.replace("encoder", "En")
|
||||
name = name.replace("decoder", "De")
|
||||
plt.plot(x, v, label=name, color=colors[i], linewidth=3, marker="o")
|
||||
plt.legend()
|
||||
plt.xticks(x)
|
||||
plt.ylabel(key, weight="bold")
|
||||
plt.xlabel("Training iterations(K steps)", weight="bold")
|
||||
plt.savefig("output/base_{}.pdf".format(key), bbox_inches='tight')
|
||||
plt.close()
|
||||
@@ -0,0 +1,32 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
cfg_data = {
|
||||
"[0, 1]":{
|
||||
"cfg":[1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
|
||||
"FID":[9.23, 6.61, 5.08, 4.46, 4.32, 4.52, 4.86, 5.38, 5.97, 6.57, 7.13],
|
||||
},
|
||||
"[0.2, 1]":{
|
||||
"cfg": [1.2, 1.4, 1.6, 1.8, 2.0],
|
||||
"FID": [5.87, 4.44, 3.96, 4.01, 4.26]
|
||||
},
|
||||
"[0.3, 1]":{
|
||||
"cfg": [1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4],
|
||||
"FID": [4.31, 4.11, 3.98, 3.89, 3.87, 3.88, 3.91, 3.96, 4.03]
|
||||
},
|
||||
"[0.35, 1]":{
|
||||
"cfg": [1.6, 1.8, 2.0, 2.1, 2.2, 2.3, 2.4, 2.6],
|
||||
"FID": [4.68, 4.22, 3.98, 3.92, 3.90, 3.88, 3.88, 3.94]
|
||||
}
|
||||
}
|
||||
|
||||
colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
|
||||
|
||||
# Draw one FID-vs-CFG curve per guidance interval, coloured by position.
for i, (name, data) in enumerate(cfg_data.items()):
    plt.plot(data["cfg"], data["FID"], label="Interval: " + name, color=colors[i], linewidth=3.5, marker="o")

# Title spelling corrected ("Classifer" -> "Classifier").
plt.title("Classifier-free guidance with intervals", weight="bold")
plt.ylabel("FID10K", weight="bold")
plt.xlabel("CFG values", weight="bold")
plt.legend()
plt.savefig("./output/cfg.pdf", bbox_inches="tight")
|
||||
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
|
||||
states = torch.load("./output/state.pt", map_location="cpu").to(dtype=torch.float32)
|
||||
states = states.permute(1, 2, 0, 3)
|
||||
print(states.shape)
|
||||
states = states.view(-1, 49, 1152)
|
||||
states = torch.nn.functional.normalize(states, dim=-1)
|
||||
sim = torch.bmm(states, states.transpose(1, 2))
|
||||
mean_sim = torch.mean(sim, dim=0, keepdim=False)
|
||||
|
||||
mean_sim = mean_sim.numpy()
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
timesteps = np.linspace(0, 1, 5)
|
||||
# plt.rc('axes.spines', **{'bottom':False, 'left':False, 'right':False, 'top':False})
|
||||
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["#7400b8","#5e60ce","#4ea8de", "#64dfdf", "#80ffdb"])
|
||||
plt.imshow(mean_sim, cmap="inferno")
|
||||
plt.xticks([])
|
||||
plt.yticks([])
|
||||
# plt.show()
|
||||
plt.colorbar()
|
||||
plt.savefig("./output/mean_sim.png", pad_inches=0, bbox_inches="tight")
|
||||
# cos_sim = torch.nn.functional.cosine_similarity(states, states)
|
||||
|
||||
|
||||
# for i in range(49):
|
||||
# cos_sim = torch.nn.functional.cosine_similarity(states[i], states[i + 1])
|
||||
# cos_sim = cos_sim.min()
|
||||
# print(cos_sim)
|
||||
# state = torch.max(states, dim=-1)[1]
|
||||
# # state = torch.softmax(state, dim=-1)
|
||||
# state = state.view(-1, 16, 16)
|
||||
#
|
||||
# state = state.numpy()
|
||||
#
|
||||
# import numpy as np
|
||||
# import matplotlib.pyplot as plt
|
||||
# for i in range(0, 49):
|
||||
# print(i)
|
||||
# plt.imshow(state[i])
|
||||
# plt.savefig("./output2/{}.png".format(i))
|
||||
@@ -0,0 +1,63 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
is_data = {
|
||||
"10encoder14decoder":[80.48, 104.48, 113.01, 117.29],
|
||||
"12encoder12decoder":[85.52, 109.91, 118.18, 121.77],
|
||||
"16encoder8decoder":[92.72, 116.30, 124.32, 126.37],
|
||||
"20encoder4decoder":[94.95, 117.84, 125.66, 128.30],
|
||||
}
|
||||
|
||||
fid_data = {
|
||||
"10encoder14decoder":[15.17, 10.40, 9.32, 8.66],
|
||||
"12encoder12decoder":[13.79, 9.67, 8.64, 8.21],
|
||||
"16encoder8decoder":[12.41, 8.99, 8.18, 8.03],
|
||||
"20encoder4decoder":[12.04, 8.94, 8.03, 7.98],
|
||||
}
|
||||
|
||||
sfid_data = {
|
||||
"10encoder14decoder":[5.49, 5.00, 5.09, 5.14],
|
||||
"12encoder12decoder":[5.37, 5.01, 5.07, 5.09],
|
||||
"16encoder8decoder":[5.43, 5.11, 5.20, 5.31],
|
||||
"20encoder4decoder":[5.36, 5.23, 5.21, 5.50],
|
||||
}
|
||||
|
||||
pr_data = {
|
||||
"10encoder14decoder":[0.6517, 0.67914, 0.68274, 0.68104],
|
||||
"12encoder12decoder":[0.66144, 0.68146, 0.68564, 0.6823],
|
||||
"16encoder8decoder":[0.6659, 0.68342, 0.68338, 0.67912],
|
||||
"20encoder4decoder":[0.6716, 0.68088, 0.68798, 0.68098],
|
||||
}
|
||||
|
||||
recall_data = {
|
||||
"10encoder14decoder":[0.6427, 0.6512, 0.6572, 0.6679],
|
||||
"12encoder12decoder":[0.6429, 0.6561, 0.6622, 0.6693],
|
||||
"16encoder8decoder":[0.6457, 0.6547, 0.6665, 0.6773],
|
||||
"20encoder4decoder":[0.6483, 0.6612, 0.6684, 0.6711],
|
||||
}
|
||||
|
||||
x = [100, 200, 300, 400]
|
||||
# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
|
||||
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
|
||||
|
||||
metric_data = {
|
||||
"FID50K" : fid_data,
|
||||
# "SFID" : sfid_data,
|
||||
"InceptionScore" : is_data,
|
||||
"Precision" : pr_data,
|
||||
"Recall" : recall_data,
|
||||
}
|
||||
|
||||
for key, data in metric_data.items():
|
||||
# plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
|
||||
for i, (name, v) in enumerate(data.items()):
|
||||
name = name.replace("encoder", "En")
|
||||
name = name.replace("decoder", "De")
|
||||
plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=8)
|
||||
plt.legend(fontsize="14")
|
||||
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
|
||||
plt.xticks([100, 150, 200, 250, 300, 350, 400])
|
||||
plt.ylabel(key, weight="bold")
|
||||
plt.xlabel("Training iterations(K steps)", weight="bold")
|
||||
plt.savefig("output/large++_{}.pdf".format(key), bbox_inches='tight')
|
||||
plt.close()
|
||||
@@ -0,0 +1,18 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
t = np.linspace(0.001, 0.999, 100)
|
||||
def snr(t):
    """Return ``log((1 - t) / t)`` elementwise (plotted as log-SNR)."""
    ratio = (1 - t) / t
    return np.log(ratio)
|
||||
def pds(t):
    """Return ``((1 - t) / t) ** 2`` elementwise, clipped to the range [0.0, 0.5]."""
    squared_ratio = ((1 - t) / t) ** 2
    return np.clip(squared_ratio, 0.0, 0.5)
|
||||
print(pds(t))
|
||||
plt.figure(figsize=(16, 4))
|
||||
plt.plot(t, snr(t), color="#ff70a6", linewidth=3, marker="o")
|
||||
# plt.plot(t, pds(t), color="#ff9770", linewidth=3, marker="o")
|
||||
plt.ylabel("log-SNR", weight="bold")
|
||||
plt.xlabel("Timesteps", weight="bold")
|
||||
plt.xticks([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])
|
||||
plt.gca().invert_xaxis()
|
||||
plt.show()
|
||||
# plt.savefig("output/logsnr.pdf", bbox_inches='tight')
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
After Width: | Height: | Size: 20 KiB |
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,95 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
data = {
|
||||
"SiT-XL/2" : {
|
||||
"size": 675,
|
||||
"epochs": 1400,
|
||||
"FID": 2.06,
|
||||
"color": "#ff99c8"
|
||||
},
|
||||
"DiT-XL/2" : {
|
||||
"size": 675,
|
||||
"epochs": 1400,
|
||||
"FID": 2.27,
|
||||
"color": "#fcf6bd"
|
||||
},
|
||||
"REPA-XL/2" : {
|
||||
"size": 675,
|
||||
"epochs": 800,
|
||||
"FID": 1.42,
|
||||
"color": "#d0f4de"
|
||||
},
|
||||
# "MAR-H" : {
|
||||
# "size": 973,
|
||||
# "epochs": 800,
|
||||
# "FID": 1.55,
|
||||
# },
|
||||
"MDTv2" : {
|
||||
"size": 675,
|
||||
"epochs": 920,
|
||||
"FID": 1.58,
|
||||
"color": "#e4c1f9"
|
||||
},
|
||||
# "VAVAE+LightningDiT" : {
|
||||
# "size": 675,
|
||||
# "epochs": [64, 800],
|
||||
# "FID": [2.11, 1.35],
|
||||
# },
|
||||
"DDT-XL/2": {
|
||||
"size": 675,
|
||||
"epochs": [80, 256],
|
||||
"FID": [1.52, 1.31],
|
||||
"color": "#38a3a5"
|
||||
},
|
||||
"DDT-L/2": {
|
||||
"size": 400,
|
||||
"epochs": 80,
|
||||
"FID": 1.64,
|
||||
"color": "#5bc0be"
|
||||
},
|
||||
}
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(1, 1, 1)
|
||||
for k, spec in data.items():
|
||||
plt.scatter(
|
||||
# spec["size"],
|
||||
spec["epochs"],
|
||||
spec["FID"],
|
||||
label=k,
|
||||
marker="o",
|
||||
s=spec["size"],
|
||||
color=spec["color"],
|
||||
)
|
||||
x = spec["epochs"]
|
||||
y = spec["FID"]
|
||||
if isinstance(spec["FID"], list):
|
||||
x = spec["epochs"][-1]
|
||||
y = spec["FID"][-1]
|
||||
plt.plot(
|
||||
spec["epochs"],
|
||||
spec["FID"],
|
||||
color=spec["color"],
|
||||
linestyle="dotted",
|
||||
linewidth=4
|
||||
)
|
||||
# plt.annotate("",
|
||||
# xytext=(spec["epochs"][0], spec["FID"][0]),
|
||||
# xy=(spec["epochs"][1], spec["FID"][1]), arrowprops=dict(arrowstyle="--"), weight="bold")
|
||||
plt.text(x+80, y-0.05, k, fontsize=13)
|
||||
|
||||
plt.text(200, 1.45, "4x Training Acc", fontsize=12, color="#38a3a5", weight="bold")
|
||||
# plt.arrow(200, 1.42, 520, 0, linewidth=2, fc='black', ec='black', hatch="x", head_width=0.05, head_length=0.05)
|
||||
|
||||
plt.annotate("",
|
||||
xy=(700, 1.42), xytext=(200, 1.42),
|
||||
arrowprops=dict(arrowstyle='<->', color='black', linewidth=2),
|
||||
)
|
||||
ax.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
|
||||
plt.gca().set_xlim(0, 1800)
|
||||
plt.gca().set_ylim(1.15, 2.5)
|
||||
plt.xticks([80, 256, 800, 1000, 1200, 1400, 1600, ])
|
||||
plt.xlabel("Training Epochs", weight="bold")
|
||||
plt.ylabel("FID50K on ImageNet256x256", weight="bold")
|
||||
plt.savefig("output/sota.pdf", bbox_inches="tight")
|
||||
@@ -0,0 +1,26 @@
|
||||
import scipy
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def timeshift(t, s=1.0):
    """Respace timestep(s) ``t`` with shift factor ``s``: ``t / (t + (1 - t) * s)``."""
    denominator = t + (1 - t) * s
    return t / denominator
|
||||
|
||||
# colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
|
||||
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
|
||||
# plt.rc('axes.spines', **{'bottom':True, 'left':True, 'right':False, 'top':False})
|
||||
t = np.linspace(0, 1, 100)
|
||||
shifts = [1.0, 1.5, 2, 3]
|
||||
for i , shift in enumerate(shifts):
|
||||
plt.plot(t, timeshift(t, shift), color=colors[i], label=f"shift {shift}", linewidth=4)
|
||||
|
||||
# plt.annotate("", xytext=(0, 0), xy=(0.0, 1.05), arrowprops=dict(arrowstyle="->"), weight="bold")
|
||||
# plt.annotate("", xytext=(0, 0), xy=(1.05, 0.0), arrowprops=dict(arrowstyle="->"), weight="bold")
|
||||
# plt.title("Respaced timesteps with various shift value", weight="bold")
|
||||
# plt.gca().set_xlim(0, 1.0)
|
||||
# plt.gca().set_ylim(0, 1.0)
|
||||
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
|
||||
|
||||
plt.ylabel("Respaced Timesteps", weight="bold")
|
||||
plt.xlabel("Uniform Timesteps", weight="bold")
|
||||
plt.legend(loc="upper left", fontsize="12")
|
||||
plt.savefig("output/timeshift.pdf", bbox_inches="tight", pad_inches=0)
|
||||
@@ -0,0 +1,29 @@
|
||||
import scipy
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def timeshift(t, s=1.0):
    """Warp timesteps ``t`` with shift factor ``s``.

    Evaluates ``t / (t + (1 - t) * s)``. Mathematically ``s == 1`` is the
    identity, and for ``t`` strictly between 0 and 1 a larger ``s`` maps ``t``
    to a smaller value. Works on plain floats and, elementwise, on NumPy arrays.
    """
    remaining = 1 - t
    denominator = t + remaining * s
    return t / denominator
|
||||
|
||||
# FID50K values per shift setting; each list lines up index-by-index with
# ``steps`` below.
data = {
    "shift 1.0": [8.99, 6.36, 5.03, 4.21, 3.6, 3.23, 2.80],
    "shift 1.5": [6.08, 4.26, 3.43, 2.99, 2.73, 2.54, 2.33],
    "shift 2.0": [5.57, 3.81, 3.11, 2.75, 2.54, 2.43, 2.26],
    "shift 3.0": [7.26, 4.48, 3.43, 2.97, 2.72, 2.57, 2.38],
}

colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
steps = [5, 6, 7, 8, 9, 10, 12]

# One marked line per shift setting, coloured by insertion order of ``data``.
for color, (label, fid_values) in zip(colors, data.items()):
    plt.plot(steps, fid_values, color=color, label=label, linewidth=4, marker="o")

plt.xlabel("Num of inference steps", weight="bold")
plt.ylabel("FID50K", weight="bold")
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)

# No legend is drawn for this figure.
# The "output" directory is not created here; it has to exist already.
plt.savefig("output/timeshift_fid.pdf", bbox_inches="tight", pad_inches=0)
|
||||
@@ -0,0 +1,21 @@
|
||||
import os
import pathlib
import sys

# Glob patterns for the image files that should be deleted.
IMG_EXTENSIONS = (
    "*.png",
    "*.JPEG",
    "*.jpeg",
    "*.jpg",
)

# Root of the work directory that is searched recursively.
PATH = pathlib.Path("/mnt/bn/wangshuai6/neural_sampling_workdirs/expbaseline_adam2_timeshift1.5")

# find images
images = []
for ext in IMG_EXTENSIONS:
    images.extend(PATH.rglob(ext))

# Delete through the OS API, not by interpolating the path into a shell
# command: file names containing spaces or shell metacharacters would
# otherwise be split or executed by the shell.
for image in images:
    try:
        os.remove(image)
    except FileNotFoundError:
        # Same best-effort semantics as ``rm -f``: a file that is already gone
        # (e.g. matched by two patterns on a case-insensitive filesystem) is
        # not an error.
        pass
    except OSError as err:
        # Keep going on other failures (directories, permissions), but report
        # them, as ``rm -f`` would have on stderr.
        print(f"failed to remove {image}: {err}", file=sys.stderr)
|
||||
|
||||
+23
@@ -0,0 +1,23 @@
|
||||
import torch
|
||||
import time
|
||||
import torch.nn as nn
|
||||
import accelerate
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Keep-alive script: every process repeatedly runs a small frozen linear
    # layer whenever the utilization reported for its GPU drops below 1.5,
    # and backs off for a second otherwise. Runs until it is killed.
    model = nn.Linear(512, 512)
    for param in model.parameters():
        param.requires_grad = False

    accelerator = accelerate.Accelerator()
    device = accelerator.device

    model = accelerator.prepare_model(model)
    model.to(device)
    data = torch.randn(1024, 512).to(device)

    while True:
        time.sleep(0.01)
        # All ranks meet here before each utilization check.
        accelerator.wait_for_everyone()

        if not torch.cuda.utilization() < 1.5:
            # Something else is already using the GPU: back off.
            time.sleep(1)
            continue

        with torch.no_grad():
            model(data)
|
||||
@@ -0,0 +1,20 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def lw(x, b=0):
    """Return the sigmoid of ``log(x / (1 - x)) + b``.

    ``x`` is first clipped to ``[0.001, 0.999]`` so the log-ratio stays
    finite. ``b`` is an additive bias on that log-ratio; with ``b == 0`` the
    result equals the clipped ``x`` up to rounding.
    """
    clipped = np.clip(x, a_min=0.001, a_max=0.999)
    logsnr = np.log(clipped / (1 - clipped))
    return 1 / (1 + np.exp(-logsnr - b))
|
||||
|
||||
# Sample points for the curves.
x = np.arange(0.2, 0.8, 0.001)

# Reference values 1 / (x * (1 - x)), dumped to stdout for comparison.
print(1/(x*(1-x)))

# One weighting curve per bias value b = 0, 1, 2, 3.
for b in range(4):
    plt.plot(x, lw(x, b), label=f"b={b}")

plt.legend()
plt.show()
|
||||
@@ -0,0 +1,173 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import timm
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from torchvision.transforms import Normalize
|
||||
import copy
|
||||
|
||||
class DINOv2(nn.Module):
    """Wrap a DINOv2 encoder from torch.hub and return its patch tokens as a feature map."""

    def __init__(self, weight_path:str):
        """Load the encoder and set up positional-embedding caching.

        Args:
            weight_path: passed to ``torch.hub.load`` as the model name of the
                ``facebookresearch/dinov2`` hub repo (e.g. ``"dinov2_vitb14"``),
                despite the parameter's name.
        """
        super(DINOv2, self).__init__()
        self.encoder = torch.hub.load('facebookresearch/dinov2', weight_path)
        # Keep an untouched copy of the original positional embedding; forward()
        # overwrites the encoder's own copy with resampled versions.
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        # Cache of resampled positional embeddings keyed by (grid_h, grid_w).
        self.precomputed_pos_embed = dict()

    def fetch_pos(self, h, w):
        """Return the positional embedding resampled to an ``h`` x ``w`` patch grid.

        Results are cached per ``(h, w)``.
        NOTE(review): the cache is never invalidated, so entries computed before
        the module is moved to another device/dtype would be reused as-is — confirm
        this cannot happen in practice.
        """
        key = (h, w)
        if key in self.precomputed_pos_embed:
            return self.precomputed_pos_embed[key]
        value = timm.layers.pos_embed.resample_abs_pos_embed(
            self.pos_embed.data, [h, w],
        )
        self.precomputed_pos_embed[key] = value
        return value

    def forward(self, x):
        """Encode an image batch into a ``(b, channels, grid_h, grid_w)`` feature map.

        NOTE(review): ImageNet mean/std normalization is applied here, which
        presumably expects inputs scaled to [0, 1] — verify against the caller.
        """
        b, c, h, w = x.shape
        x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
        # Rescale both sides by 224/256 with bicubic interpolation.
        x = torch.nn.functional.interpolate(x, (int(224*h/256), int(224*w/256)), mode='bicubic')
        b, c, h, w = x.shape
        patch_num_h, patch_num_w = h//self.patch_size[0], w//self.patch_size[1]
        pos_embed_data = self.fetch_pos(patch_num_h, patch_num_w)
        # Swap the encoder's positional embedding for the one matching this grid.
        self.encoder.pos_embed.data = pos_embed_data
        feature = self.encoder.forward_features(x)['x_norm_patchtokens']
        # Move the token axis last, then unfold it into the 2-D patch grid.
        feature = feature.transpose(1, 2)
        feature = feature.view(b, -1, patch_num_h, patch_num_w).contiguous()
        return feature
|
||||
|
||||
class MAE(nn.Module):
    """Wrap a timm vision transformer and return its patch tokens as a feature map.

    The features are standardised with the frozen ``shifts`` / ``scales``
    parameters, which are initialised to 0 and 1 (i.e. a no-op).
    """

    def __init__(self, model_id, weight_path:str):
        """Build the encoder.

        Args:
            model_id: model name handed to ``timm.create_model``.
            weight_path: checkpoint file, or a directory that contains
                ``pytorch_model.bin``.
        """
        super().__init__()
        checkpoint = weight_path
        if os.path.isdir(checkpoint):
            checkpoint = os.path.join(checkpoint, "pytorch_model.bin")
        self.encoder = timm.create_model(model_id, checkpoint_path=checkpoint, num_classes=0)
        self.pos_embed = copy.deepcopy(self.encoder.pos_embed)
        self.encoder.head = torch.nn.Identity()
        self.patch_size = self.encoder.patch_embed.patch_size
        self.shifts = nn.Parameter(torch.tensor([0.0]), requires_grad=False)
        self.scales = nn.Parameter(torch.tensor([1.0]), requires_grad=False)

    def forward(self, x):
        """Encode an image batch into a ``(b, channels, grid_h, grid_w)`` feature map."""
        normalize = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        x = torch.nn.functional.interpolate(normalize(x), (224, 224), mode='bicubic')

        batch, _, height, width = x.shape
        grid_h = height // self.patch_size[0]
        grid_w = width // self.patch_size[1]

        # Drop the prefix tokens so that only patch tokens remain.
        tokens = self.encoder.forward_features(x)
        patch_tokens = tokens[:, self.encoder.num_prefix_tokens:]

        # Move the token axis last, then unfold it into the 2-D patch grid.
        feature = patch_tokens.transpose(1, 2).view(batch, -1, grid_h, grid_w).contiguous()

        shift = self.shifts.view(1, -1, 1, 1)
        scale = self.scales.view(1, -1, 1, 1)
        return (feature - shift) / scale
|
||||
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
from torchvision.datasets import ImageFolder, ImageNet
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
||||
|
||||
from PIL import Image
|
||||
import torch
|
||||
import torchvision.transforms as tvtf
|
||||
|
||||
IMG_EXTENSIONS = (
|
||||
"*.png",
|
||||
"*.JPEG",
|
||||
"*.jpeg",
|
||||
"*.jpg"
|
||||
)
|
||||
|
||||
class NewImageFolder(ImageFolder):
    """``ImageFolder`` whose items also carry the path of the source file."""

    def __getitem__(self, item):
        """Return ``(sample, target, path)`` for index ``item``."""
        path, target = self.samples[item]
        sample = self.loader(path)

        transform = self.transform
        if transform is not None:
            sample = transform(sample)

        target_transform = self.target_transform
        if target_transform is not None:
            target = target_transform(target)

        return sample, target, path
|
||||
|
||||
|
||||
import time
|
||||
class CenterCrop:
    """Resize a PIL image so its short side equals ``size``, then crop the centre square.

    Center cropping implementation from ADM:
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        """Return the ``size`` x ``size`` centre crop of ``image`` as a PIL image."""
        target = self.size

        # Halve with a box filter for as long as the short side is at least
        # twice the target size.
        while min(*image.size) >= 2 * target:
            halved = tuple(side // 2 for side in image.size)
            image = image.resize(halved, resample=Image.BOX)

        # Final bicubic resize that brings the short side to the target size.
        scale = target / min(*image.size)
        scaled = tuple(round(side * scale) for side in image.size)
        image = image.resize(scaled, resample=Image.BICUBIC)

        pixels = np.array(image)
        top = (pixels.shape[0] - target) // 2
        left = (pixels.shape[1] - target) // 2
        return Image.fromarray(pixels[top: top + target, left: left + target])
|
||||
|
||||
def save(obj, path):
    """Serialize ``obj`` to ``path`` with ``torch.save``, creating parent directories.

    Args:
        obj: any object accepted by ``torch.save``.
        path: destination file; may be a bare file name, in which case the
            file is written to the current working directory.
    """
    dirname = os.path.dirname(path)
    # A bare file name has an empty dirname, and os.makedirs("") raises.
    if dirname:
        # exist_ok avoids the check-then-create race when several workers
        # write into the same directory at once.
        os.makedirs(dirname, exist_ok=True)
    torch.save(obj, f'{path}')
|
||||
|
||||
if __name__ == "__main__":
    # Debug script: run one ImageNet batch through the encoder and print the
    # per-channel std and mean of the resulting features.
    torch_hub_dir = '/mnt/bn/wangshuai6/torch_hub'
    os.environ["TORCH_HOME"] = torch_hub_dir
    torch.hub.set_dir(torch_hub_dir)

    for split in ['train']:
        train = split == 'train'

        # Images are centre-cropped to 256 and mapped to [-1, 1].
        # NOTE(review): the encoder applies ImageNet mean/std normalization on
        # top of this — confirm that the [-1, 1] input range is intended.
        transforms = tvtf.Compose([
            CenterCrop(256),
            tvtf.ToTensor(),
            tvtf.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        dataset = NewImageFolder(root='/mnt/bn/wangshuai6/data/ImageNet/train', transform=transforms,)

        B = 4
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=B,
            shuffle=True,
            prefetch_factor=None,
            num_workers=0,
        )

        # The DINOv2 wrapper can be swapped in here instead of MAE.
        dino = MAE("vit_base_patch16_224.mae", "/mnt/bn/wangshuai6/models/vit_base_patch16_224.mae")

        from accelerate import Accelerator

        accelerator = Accelerator()
        dino, dataloader = accelerator.prepare(dino, dataloader)
        rank = accelerator.process_index

        acc_mean = torch.zeros((768, ), device=accelerator.device)
        acc_num = 0
        with torch.no_grad():
            for images, labels, path_list in dataloader:
                acc_num += len(images)
                feature = dino(images)

                # Per-channel statistics over the batch and both spatial axes.
                stds = torch.std(feature, dim=[0, 2, 3]).tolist()
                print("".join("{:.2f},".format(value) for value in stds))

                means = torch.mean(feature, dim=[0, 2, 3]).tolist()
                print("".join("{:.2f},".format(value) for value in means), end='')

                # Only the first batch is inspected.
                break

        accelerator.wait_for_everyone()
|
||||
@@ -0,0 +1,23 @@
|
||||
import scipy
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def timeshift(t, s=1.0):
    """Warp timesteps ``t`` with shift factor ``s``.

    Evaluates ``t / (t + (1 - t) * s)``. Mathematically ``s == 1`` is the
    identity, and for ``t`` strictly between 0 and 1 a larger ``s`` maps ``t``
    to a smaller value. Works on plain floats and, elementwise, on NumPy arrays.
    """
    remaining = 1 - t
    denominator = t + remaining * s
    return t / denominator
|
||||
|
||||
def gaussian(t):
    """Return ``1 + erf((t - mean(t)) / std(t))`` for a NumPy array ``t``.

    The result has the same shape as ``t`` and lies in ``[0, 2]``; the value
    at the mean of ``t`` is 1.

    Args:
        t: NumPy array of sample points. If all entries are equal the
            standard deviation is 0 and the result is NaN.
    """
    # Import the submodule explicitly: a bare ``import scipy`` is not
    # guaranteed to expose ``scipy.special`` on older SciPy releases.
    import scipy.special

    gs = 1+scipy.special.erf((t-t.mean())/t.std())
    # The original computed ``gs`` but fell off the end, returning None.
    return gs
|
||||
|
||||
def rs2(t, s=2.0):
    """Return ``log(t / (1 - t))`` with both terms clipped to ``[0.001, 0.999]``.

    ``t`` is a NumPy array. ``s`` is accepted but currently unused: the
    scaling factor applied to the log-ratio is fixed at 1.0.
    """
    numerator = t.clip(0.001, 0.999)
    denominator = (1 - t).clip(0.001, 0.999)
    log_ratio = np.log(numerator / denominator)
    return 1.0 * log_ratio
|
||||
|
||||
|
||||
# Uniform grid of timesteps on [0, 1].
t = np.linspace(0, 1, 100)

# Step sizes between consecutive respaced timesteps for shift 5. They are
# computed but not plotted; the variable name is kept as in the original.
respaced_t = timeshift(t, s=5)
delats = np.diff(respaced_t)

# Only the clipped log-ratio curve is drawn, against the sample index.
plt.plot(rs2(t))
plt.show()
|
||||
Reference in New Issue
Block a user