submit code

This commit is contained in:
wangshuai6
2025-04-09 11:01:16 +08:00
parent 4fbcf9bd87
commit 06499f1caa
145 changed files with 14400 additions and 0 deletions

64
tools/figures/base++.py Normal file
View File

@@ -0,0 +1,64 @@
import numpy as np
import matplotlib.pyplot as plt
is_data = {
"4encoder8decoder":[46.01, 61.47, 69.73, 74.26],
"6encoder6decoder":[53.11, 71.04, 79.83, 83.85],
"8encoder4decoder":[54.06, 72.96, 80.49, 85.94],
"10encoder2decoder": [49.25, 67.59, 76.00, 81.12],
}
fid_data = {
"4encoder8decoder":[31.40, 22.80, 20.13, 18.61],
"6encoder6decoder":[27.61, 20.42, 17.95, 16.86],
"8encoder4decoder":[27.12, 19.90, 17.78, 16.32],
"10encoder2decoder": [29.70, 21.75, 18.95, 17.65],
}
sfid_data = {
"4encoder8decoder":[6.88, 6.44, 6.56, 6.56],
"6encoder4decoder":[6.83, 6.50, 6.49, 6.63],
"8encoder4decoder":[6.76, 6.70, 6.83, 6.63],
"10encoder2decoder": [6.81, 6.61, 6.53, 6.60],
}
pr_data = {
"4encoder8decoder":[0.55006, 0.59538, 0.6063, 0.60922],
"6encoder6decoder":[0.56436, 0.60246, 0.61668, 0.61702],
"8encoder4decoder":[0.56636, 0.6038, 0.61832, 0.62132],
"10encoder2decoder": [0.55612, 0.59846, 0.61092, 0.61686],
}
recall_data = {
"4encoder8decoder":[0.6347, 0.6495, 0.6559, 0.662],
"6encoder6decoder":[0.6477, 0.6497, 0.6594, 0.6589],
"8encoder4decoder":[0.6403, 0.653, 0.6505, 0.6618],
"10encoder2decoder": [0.6342, 0.6492, 0.6536, 0.6569],
}
x = [100, 200, 300, 400]
# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
metric_data = {
"FID50K" : fid_data,
# "SFID" : sfid_data,
"InceptionScore" : is_data,
"Precision" : pr_data,
"Recall" : recall_data,
}
for key, data in metric_data.items():
# plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
for i, (name, v) in enumerate(data.items()):
name = name.replace("encoder", "En")
name = name.replace("decoder", "De")
plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=10)
plt.legend(fontsize="14")
plt.xticks([100, 150, 200, 250, 300, 350, 400])
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
plt.ylabel(key, weight="bold")
plt.xlabel("Training iterations(K steps)", weight="bold")
plt.savefig("output/base++_{}.pdf".format(key), bbox_inches='tight',)
plt.close()

57
tools/figures/base.py Normal file
View File

@@ -0,0 +1,57 @@
import numpy as np
import matplotlib.pyplot as plt
fid_data = {
"4encoder8decoder":[64.16, 48.04, 39.88, 35.41],
"6encoder4decoder":[67.71, 48.26, 39.30, 34.91],
"8encoder4decoder":[69.4, 49.7, 41.56, 36.76],
}
sfid_data = {
"4encoder8decoder":[7.86, 7.48, 7.15, 7.07],
"6encoder4decoder":[8.54, 8.11, 7.40, 7.40],
"8encoder4decoder":[8.42, 8.27, 8.10, 7.69],
}
is_data = {
"4encoder8decoder":[20.37, 29.41, 36.88, 41.32],
"6encoder4decoder":[20.04, 30.13, 38.17, 43.84],
"8encoder4decoder":[19.98, 29.54, 35.93, 42.025],
}
pr_data = {
"4encoder8decoder":[0.3935, 0.4687, 0.5047, 0.5271],
"6encoder4decoder":[0.3767, 0.4686, 0.50876, 0.5266],
"8encoder4decoder":[0.37, 0.45676, 0.49602, 0.5162],
}
recall_data = {
"4encoder8decoder":[0.5604, 0.5941, 0.6244, 0.6338],
"6encoder4decoder":[0.5295, 0.595, 0.6287, 0.6378],
"8encoder4decoder":[0.51, 0.596, 0.6242, 0.6333],
}
x = [100, 200, 300, 400]
colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
metric_data = {
"FID" : fid_data,
# "SFID" : sfid_data,
"InceptionScore" : is_data,
"Precision" : pr_data,
"Recall" : recall_data,
}
for key, data in metric_data.items():
for i, (name, v) in enumerate(data.items()):
name = name.replace("encoder", "En")
name = name.replace("decoder", "De")
plt.plot(x, v, label=name, color=colors[i], linewidth=3, marker="o")
plt.legend()
plt.xticks(x)
plt.ylabel(key, weight="bold")
plt.xlabel("Training iterations(K steps)", weight="bold")
plt.savefig("output/base_{}.pdf".format(key), bbox_inches='tight')
plt.close()

32
tools/figures/cfg.py Normal file
View File

@@ -0,0 +1,32 @@
import numpy as np
import matplotlib.pyplot as plt
cfg_data = {
"[0, 1]":{
"cfg":[1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
"FID":[9.23, 6.61, 5.08, 4.46, 4.32, 4.52, 4.86, 5.38, 5.97, 6.57, 7.13],
},
"[0.2, 1]":{
"cfg": [1.2, 1.4, 1.6, 1.8, 2.0],
"FID": [5.87, 4.44, 3.96, 4.01, 4.26]
},
"[0.3, 1]":{
"cfg": [1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4],
"FID": [4.31, 4.11, 3.98, 3.89, 3.87, 3.88, 3.91, 3.96, 4.03]
},
"[0.35, 1]":{
"cfg": [1.6, 1.8, 2.0, 2.1, 2.2, 2.3, 2.4, 2.6],
"FID": [4.68, 4.22, 3.98, 3.92, 3.90, 3.88, 3.88, 3.94]
}
}
colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
for i, (name, data) in enumerate(cfg_data.items()):
plt.plot(data["cfg"], data["FID"], label="Interval: " +name, color=colors[i], linewidth=3.5, marker="o")
plt.title("Classifer-free guidance with intervals", weight="bold")
plt.ylabel("FID10K", weight="bold")
plt.xlabel("CFG values", weight="bold")
plt.legend()
plt.savefig("./output/cfg.pdf", bbox_inches="tight")

42
tools/figures/feat_vis.py Normal file
View File

@@ -0,0 +1,42 @@
import torch
states = torch.load("./output/state.pt", map_location="cpu").to(dtype=torch.float32)
states = states.permute(1, 2, 0, 3)
print(states.shape)
states = states.view(-1, 49, 1152)
states = torch.nn.functional.normalize(states, dim=-1)
sim = torch.bmm(states, states.transpose(1, 2))
mean_sim = torch.mean(sim, dim=0, keepdim=False)
mean_sim = mean_sim.numpy()
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
timesteps = np.linspace(0, 1, 5)
# plt.rc('axes.spines', **{'bottom':False, 'left':False, 'right':False, 'top':False})
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["#7400b8","#5e60ce","#4ea8de", "#64dfdf", "#80ffdb"])
plt.imshow(mean_sim, cmap="inferno")
plt.xticks([])
plt.yticks([])
# plt.show()
plt.colorbar()
plt.savefig("./output/mean_sim.png", pad_inches=0, bbox_inches="tight")
# cos_sim = torch.nn.functional.cosine_similarity(states, states)
# for i in range(49):
# cos_sim = torch.nn.functional.cosine_similarity(states[i], states[i + 1])
# cos_sim = cos_sim.min()
# print(cos_sim)
# state = torch.max(states, dim=-1)[1]
# # state = torch.softmax(state, dim=-1)
# state = state.view(-1, 16, 16)
#
# state = state.numpy()
#
# import numpy as np
# import matplotlib.pyplot as plt
# for i in range(0, 49):
# print(i)
# plt.imshow(state[i])
# plt.savefig("./output2/{}.png".format(i))

63
tools/figures/large++.py Normal file
View File

@@ -0,0 +1,63 @@
import numpy as np
import matplotlib.pyplot as plt
is_data = {
"10encoder14decoder":[80.48, 104.48, 113.01, 117.29],
"12encoder12decoder":[85.52, 109.91, 118.18, 121.77],
"16encoder8decoder":[92.72, 116.30, 124.32, 126.37],
"20encoder4decoder":[94.95, 117.84, 125.66, 128.30],
}
fid_data = {
"10encoder14decoder":[15.17, 10.40, 9.32, 8.66],
"12encoder12decoder":[13.79, 9.67, 8.64, 8.21],
"16encoder8decoder":[12.41, 8.99, 8.18, 8.03],
"20encoder4decoder":[12.04, 8.94, 8.03, 7.98],
}
sfid_data = {
"10encoder14decoder":[5.49, 5.00, 5.09, 5.14],
"12encoder12decoder":[5.37, 5.01, 5.07, 5.09],
"16encoder8decoder":[5.43, 5.11, 5.20, 5.31],
"20encoder4decoder":[5.36, 5.23, 5.21, 5.50],
}
pr_data = {
"10encoder14decoder":[0.6517, 0.67914, 0.68274, 0.68104],
"12encoder12decoder":[0.66144, 0.68146, 0.68564, 0.6823],
"16encoder8decoder":[0.6659, 0.68342, 0.68338, 0.67912],
"20encoder4decoder":[0.6716, 0.68088, 0.68798, 0.68098],
}
recall_data = {
"10encoder14decoder":[0.6427, 0.6512, 0.6572, 0.6679],
"12encoder12decoder":[0.6429, 0.6561, 0.6622, 0.6693],
"16encoder8decoder":[0.6457, 0.6547, 0.6665, 0.6773],
"20encoder4decoder":[0.6483, 0.6612, 0.6684, 0.6711],
}
x = [100, 200, 300, 400]
# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
metric_data = {
"FID50K" : fid_data,
# "SFID" : sfid_data,
"InceptionScore" : is_data,
"Precision" : pr_data,
"Recall" : recall_data,
}
for key, data in metric_data.items():
# plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
for i, (name, v) in enumerate(data.items()):
name = name.replace("encoder", "En")
name = name.replace("decoder", "De")
plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=8)
plt.legend(fontsize="14")
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
plt.xticks([100, 150, 200, 250, 300, 350, 400])
plt.ylabel(key, weight="bold")
plt.xlabel("Training iterations(K steps)", weight="bold")
plt.savefig("output/large++_{}.pdf".format(key), bbox_inches='tight')
plt.close()

18
tools/figures/log_snr.py Normal file
View File

@@ -0,0 +1,18 @@
import numpy as np
import matplotlib.pyplot as plt
t = np.linspace(0.001, 0.999, 100)
def snr(t):
return np.log((1-t)/t)
def pds(t):
return np.clip(((1-t)/t)**2, a_max=0.5, a_min=0.0)
print(pds(t))
plt.figure(figsize=(16, 4))
plt.plot(t, snr(t), color="#ff70a6", linewidth=3, marker="o")
# plt.plot(t, pds(t), color="#ff9770", linewidth=3, marker="o")
plt.ylabel("log-SNR", weight="bold")
plt.xlabel("Timesteps", weight="bold")
plt.xticks([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])
plt.gca().invert_xaxis()
plt.show()
# plt.savefig("output/logsnr.pdf", bbox_inches='tight')

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

Binary file not shown.

Binary file not shown.

95
tools/figures/sota.py Normal file
View File

@@ -0,0 +1,95 @@
import numpy as np
import matplotlib.pyplot as plt
data = {
"SiT-XL/2" : {
"size": 675,
"epochs": 1400,
"FID": 2.06,
"color": "#ff99c8"
},
"DiT-XL/2" : {
"size": 675,
"epochs": 1400,
"FID": 2.27,
"color": "#fcf6bd"
},
"REPA-XL/2" : {
"size": 675,
"epochs": 800,
"FID": 1.42,
"color": "#d0f4de"
},
# "MAR-H" : {
# "size": 973,
# "epochs": 800,
# "FID": 1.55,
# },
"MDTv2" : {
"size": 675,
"epochs": 920,
"FID": 1.58,
"color": "#e4c1f9"
},
# "VAVAE+LightningDiT" : {
# "size": 675,
# "epochs": [64, 800],
# "FID": [2.11, 1.35],
# },
"DDT-XL/2": {
"size": 675,
"epochs": [80, 256],
"FID": [1.52, 1.31],
"color": "#38a3a5"
},
"DDT-L/2": {
"size": 400,
"epochs": 80,
"FID": 1.64,
"color": "#5bc0be"
},
}
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
for k, spec in data.items():
plt.scatter(
# spec["size"],
spec["epochs"],
spec["FID"],
label=k,
marker="o",
s=spec["size"],
color=spec["color"],
)
x = spec["epochs"]
y = spec["FID"]
if isinstance(spec["FID"], list):
x = spec["epochs"][-1]
y = spec["FID"][-1]
plt.plot(
spec["epochs"],
spec["FID"],
color=spec["color"],
linestyle="dotted",
linewidth=4
)
# plt.annotate("",
# xytext=(spec["epochs"][0], spec["FID"][0]),
# xy=(spec["epochs"][1], spec["FID"][1]), arrowprops=dict(arrowstyle="--"), weight="bold")
plt.text(x+80, y-0.05, k, fontsize=13)
plt.text(200, 1.45, "4x Training Acc", fontsize=12, color="#38a3a5", weight="bold")
# plt.arrow(200, 1.42, 520, 0, linewidth=2, fc='black', ec='black', hatch="x", head_width=0.05, head_length=0.05)
plt.annotate("",
xy=(700, 1.42), xytext=(200, 1.42),
arrowprops=dict(arrowstyle='<->', color='black', linewidth=2),
)
ax.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
plt.gca().set_xlim(0, 1800)
plt.gca().set_ylim(1.15, 2.5)
plt.xticks([80, 256, 800, 1000, 1200, 1400, 1600, ])
plt.xlabel("Training Epochs", weight="bold")
plt.ylabel("FID50K on ImageNet256x256", weight="bold")
plt.savefig("output/sota.pdf", bbox_inches="tight")

View File

@@ -0,0 +1,26 @@
import scipy
import numpy as np
import matplotlib.pyplot as plt
def timeshift(t, s=1.0):
return t/(t+(1-t)*s)
# colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
# plt.rc('axes.spines', **{'bottom':True, 'left':True, 'right':False, 'top':False})
t = np.linspace(0, 1, 100)
shifts = [1.0, 1.5, 2, 3]
for i , shift in enumerate(shifts):
plt.plot(t, timeshift(t, shift), color=colors[i], label=f"shift {shift}", linewidth=4)
# plt.annotate("", xytext=(0, 0), xy=(0.0, 1.05), arrowprops=dict(arrowstyle="->"), weight="bold")
# plt.annotate("", xytext=(0, 0), xy=(1.05, 0.0), arrowprops=dict(arrowstyle="->"), weight="bold")
# plt.title("Respaced timesteps with various shift value", weight="bold")
# plt.gca().set_xlim(0, 1.0)
# plt.gca().set_ylim(0, 1.0)
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
plt.ylabel("Respaced Timesteps", weight="bold")
plt.xlabel("Uniform Timesteps", weight="bold")
plt.legend(loc="upper left", fontsize="12")
plt.savefig("output/timeshift.pdf", bbox_inches="tight", pad_inches=0)

View File

@@ -0,0 +1,29 @@
import scipy
import numpy as np
import matplotlib.pyplot as plt
def timeshift(t, s=1.0):
return t/(t+(1-t)*s)
data = {
"shift 1.0": [8.99, 6.36, 5.03, 4.21, 3.6, 3.23, 2.80],
"shift 1.5": [6.08, 4.26, 3.43, 2.99, 2.73, 2.54, 2.33],
"shift 2.0": [5.57, 3.81, 3.11, 2.75, 2.54, 2.43, 2.26],
"shift 3.0": [7.26, 4.48, 3.43, 2.97, 2.72, 2.57, 2.38],
}
# plt.rc('axes.spines', **{'bottom':True, 'left':True, 'right':False, 'top':False})
# colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
steps = [5, 6, 7, 8, 9, 10, 12]
for i ,(k, v)in enumerate(data.items()):
plt.plot(steps, v, color=colors[i], label=k, linewidth=4, marker="o")
# plt.title("FID50K of different steps of different timeshift", weight="bold")
plt.ylabel("FID50K", weight="bold")
plt.xlabel("Num of inference steps", weight="bold")
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
# plt.legend()
# plt.legend()
plt.savefig("output/timeshift_fid.pdf", bbox_inches="tight", pad_inches=0)