submit code
This commit is contained in:
64
tools/figures/base++.py
Normal file
64
tools/figures/base++.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
is_data = {
|
||||
"4encoder8decoder":[46.01, 61.47, 69.73, 74.26],
|
||||
"6encoder6decoder":[53.11, 71.04, 79.83, 83.85],
|
||||
"8encoder4decoder":[54.06, 72.96, 80.49, 85.94],
|
||||
"10encoder2decoder": [49.25, 67.59, 76.00, 81.12],
|
||||
}
|
||||
|
||||
fid_data = {
|
||||
"4encoder8decoder":[31.40, 22.80, 20.13, 18.61],
|
||||
"6encoder6decoder":[27.61, 20.42, 17.95, 16.86],
|
||||
"8encoder4decoder":[27.12, 19.90, 17.78, 16.32],
|
||||
"10encoder2decoder": [29.70, 21.75, 18.95, 17.65],
|
||||
}
|
||||
|
||||
sfid_data = {
|
||||
"4encoder8decoder":[6.88, 6.44, 6.56, 6.56],
|
||||
"6encoder4decoder":[6.83, 6.50, 6.49, 6.63],
|
||||
"8encoder4decoder":[6.76, 6.70, 6.83, 6.63],
|
||||
"10encoder2decoder": [6.81, 6.61, 6.53, 6.60],
|
||||
}
|
||||
|
||||
pr_data = {
|
||||
"4encoder8decoder":[0.55006, 0.59538, 0.6063, 0.60922],
|
||||
"6encoder6decoder":[0.56436, 0.60246, 0.61668, 0.61702],
|
||||
"8encoder4decoder":[0.56636, 0.6038, 0.61832, 0.62132],
|
||||
"10encoder2decoder": [0.55612, 0.59846, 0.61092, 0.61686],
|
||||
}
|
||||
|
||||
recall_data = {
|
||||
"4encoder8decoder":[0.6347, 0.6495, 0.6559, 0.662],
|
||||
"6encoder6decoder":[0.6477, 0.6497, 0.6594, 0.6589],
|
||||
"8encoder4decoder":[0.6403, 0.653, 0.6505, 0.6618],
|
||||
"10encoder2decoder": [0.6342, 0.6492, 0.6536, 0.6569],
|
||||
}
|
||||
|
||||
x = [100, 200, 300, 400]
|
||||
# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
|
||||
|
||||
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
|
||||
|
||||
metric_data = {
|
||||
"FID50K" : fid_data,
|
||||
# "SFID" : sfid_data,
|
||||
"InceptionScore" : is_data,
|
||||
"Precision" : pr_data,
|
||||
"Recall" : recall_data,
|
||||
}
|
||||
|
||||
for key, data in metric_data.items():
|
||||
# plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
|
||||
for i, (name, v) in enumerate(data.items()):
|
||||
name = name.replace("encoder", "En")
|
||||
name = name.replace("decoder", "De")
|
||||
plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=10)
|
||||
plt.legend(fontsize="14")
|
||||
plt.xticks([100, 150, 200, 250, 300, 350, 400])
|
||||
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
|
||||
plt.ylabel(key, weight="bold")
|
||||
plt.xlabel("Training iterations(K steps)", weight="bold")
|
||||
plt.savefig("output/base++_{}.pdf".format(key), bbox_inches='tight',)
|
||||
plt.close()
|
||||
57
tools/figures/base.py
Normal file
57
tools/figures/base.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
|
||||
|
||||
fid_data = {
|
||||
"4encoder8decoder":[64.16, 48.04, 39.88, 35.41],
|
||||
"6encoder4decoder":[67.71, 48.26, 39.30, 34.91],
|
||||
"8encoder4decoder":[69.4, 49.7, 41.56, 36.76],
|
||||
}
|
||||
|
||||
sfid_data = {
|
||||
"4encoder8decoder":[7.86, 7.48, 7.15, 7.07],
|
||||
"6encoder4decoder":[8.54, 8.11, 7.40, 7.40],
|
||||
"8encoder4decoder":[8.42, 8.27, 8.10, 7.69],
|
||||
}
|
||||
|
||||
is_data = {
|
||||
"4encoder8decoder":[20.37, 29.41, 36.88, 41.32],
|
||||
"6encoder4decoder":[20.04, 30.13, 38.17, 43.84],
|
||||
"8encoder4decoder":[19.98, 29.54, 35.93, 42.025],
|
||||
}
|
||||
|
||||
pr_data = {
|
||||
"4encoder8decoder":[0.3935, 0.4687, 0.5047, 0.5271],
|
||||
"6encoder4decoder":[0.3767, 0.4686, 0.50876, 0.5266],
|
||||
"8encoder4decoder":[0.37, 0.45676, 0.49602, 0.5162],
|
||||
}
|
||||
|
||||
recall_data = {
|
||||
"4encoder8decoder":[0.5604, 0.5941, 0.6244, 0.6338],
|
||||
"6encoder4decoder":[0.5295, 0.595, 0.6287, 0.6378],
|
||||
"8encoder4decoder":[0.51, 0.596, 0.6242, 0.6333],
|
||||
}
|
||||
|
||||
x = [100, 200, 300, 400]
|
||||
colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
|
||||
metric_data = {
|
||||
"FID" : fid_data,
|
||||
# "SFID" : sfid_data,
|
||||
"InceptionScore" : is_data,
|
||||
"Precision" : pr_data,
|
||||
"Recall" : recall_data,
|
||||
}
|
||||
|
||||
for key, data in metric_data.items():
|
||||
for i, (name, v) in enumerate(data.items()):
|
||||
name = name.replace("encoder", "En")
|
||||
name = name.replace("decoder", "De")
|
||||
plt.plot(x, v, label=name, color=colors[i], linewidth=3, marker="o")
|
||||
plt.legend()
|
||||
plt.xticks(x)
|
||||
plt.ylabel(key, weight="bold")
|
||||
plt.xlabel("Training iterations(K steps)", weight="bold")
|
||||
plt.savefig("output/base_{}.pdf".format(key), bbox_inches='tight')
|
||||
plt.close()
|
||||
32
tools/figures/cfg.py
Normal file
32
tools/figures/cfg.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
cfg_data = {
|
||||
"[0, 1]":{
|
||||
"cfg":[1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
|
||||
"FID":[9.23, 6.61, 5.08, 4.46, 4.32, 4.52, 4.86, 5.38, 5.97, 6.57, 7.13],
|
||||
},
|
||||
"[0.2, 1]":{
|
||||
"cfg": [1.2, 1.4, 1.6, 1.8, 2.0],
|
||||
"FID": [5.87, 4.44, 3.96, 4.01, 4.26]
|
||||
},
|
||||
"[0.3, 1]":{
|
||||
"cfg": [1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4],
|
||||
"FID": [4.31, 4.11, 3.98, 3.89, 3.87, 3.88, 3.91, 3.96, 4.03]
|
||||
},
|
||||
"[0.35, 1]":{
|
||||
"cfg": [1.6, 1.8, 2.0, 2.1, 2.2, 2.3, 2.4, 2.6],
|
||||
"FID": [4.68, 4.22, 3.98, 3.92, 3.90, 3.88, 3.88, 3.94]
|
||||
}
|
||||
}
|
||||
|
||||
colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
|
||||
|
||||
for i, (name, data) in enumerate(cfg_data.items()):
|
||||
plt.plot(data["cfg"], data["FID"], label="Interval: " +name, color=colors[i], linewidth=3.5, marker="o")
|
||||
|
||||
plt.title("Classifer-free guidance with intervals", weight="bold")
|
||||
plt.ylabel("FID10K", weight="bold")
|
||||
plt.xlabel("CFG values", weight="bold")
|
||||
plt.legend()
|
||||
plt.savefig("./output/cfg.pdf", bbox_inches="tight")
|
||||
42
tools/figures/feat_vis.py
Normal file
42
tools/figures/feat_vis.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
|
||||
states = torch.load("./output/state.pt", map_location="cpu").to(dtype=torch.float32)
|
||||
states = states.permute(1, 2, 0, 3)
|
||||
print(states.shape)
|
||||
states = states.view(-1, 49, 1152)
|
||||
states = torch.nn.functional.normalize(states, dim=-1)
|
||||
sim = torch.bmm(states, states.transpose(1, 2))
|
||||
mean_sim = torch.mean(sim, dim=0, keepdim=False)
|
||||
|
||||
mean_sim = mean_sim.numpy()
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
timesteps = np.linspace(0, 1, 5)
|
||||
# plt.rc('axes.spines', **{'bottom':False, 'left':False, 'right':False, 'top':False})
|
||||
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["#7400b8","#5e60ce","#4ea8de", "#64dfdf", "#80ffdb"])
|
||||
plt.imshow(mean_sim, cmap="inferno")
|
||||
plt.xticks([])
|
||||
plt.yticks([])
|
||||
# plt.show()
|
||||
plt.colorbar()
|
||||
plt.savefig("./output/mean_sim.png", pad_inches=0, bbox_inches="tight")
|
||||
# cos_sim = torch.nn.functional.cosine_similarity(states, states)
|
||||
|
||||
|
||||
# for i in range(49):
|
||||
# cos_sim = torch.nn.functional.cosine_similarity(states[i], states[i + 1])
|
||||
# cos_sim = cos_sim.min()
|
||||
# print(cos_sim)
|
||||
# state = torch.max(states, dim=-1)[1]
|
||||
# # state = torch.softmax(state, dim=-1)
|
||||
# state = state.view(-1, 16, 16)
|
||||
#
|
||||
# state = state.numpy()
|
||||
#
|
||||
# import numpy as np
|
||||
# import matplotlib.pyplot as plt
|
||||
# for i in range(0, 49):
|
||||
# print(i)
|
||||
# plt.imshow(state[i])
|
||||
# plt.savefig("./output2/{}.png".format(i))
|
||||
63
tools/figures/large++.py
Normal file
63
tools/figures/large++.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
is_data = {
|
||||
"10encoder14decoder":[80.48, 104.48, 113.01, 117.29],
|
||||
"12encoder12decoder":[85.52, 109.91, 118.18, 121.77],
|
||||
"16encoder8decoder":[92.72, 116.30, 124.32, 126.37],
|
||||
"20encoder4decoder":[94.95, 117.84, 125.66, 128.30],
|
||||
}
|
||||
|
||||
fid_data = {
|
||||
"10encoder14decoder":[15.17, 10.40, 9.32, 8.66],
|
||||
"12encoder12decoder":[13.79, 9.67, 8.64, 8.21],
|
||||
"16encoder8decoder":[12.41, 8.99, 8.18, 8.03],
|
||||
"20encoder4decoder":[12.04, 8.94, 8.03, 7.98],
|
||||
}
|
||||
|
||||
sfid_data = {
|
||||
"10encoder14decoder":[5.49, 5.00, 5.09, 5.14],
|
||||
"12encoder12decoder":[5.37, 5.01, 5.07, 5.09],
|
||||
"16encoder8decoder":[5.43, 5.11, 5.20, 5.31],
|
||||
"20encoder4decoder":[5.36, 5.23, 5.21, 5.50],
|
||||
}
|
||||
|
||||
pr_data = {
|
||||
"10encoder14decoder":[0.6517, 0.67914, 0.68274, 0.68104],
|
||||
"12encoder12decoder":[0.66144, 0.68146, 0.68564, 0.6823],
|
||||
"16encoder8decoder":[0.6659, 0.68342, 0.68338, 0.67912],
|
||||
"20encoder4decoder":[0.6716, 0.68088, 0.68798, 0.68098],
|
||||
}
|
||||
|
||||
recall_data = {
|
||||
"10encoder14decoder":[0.6427, 0.6512, 0.6572, 0.6679],
|
||||
"12encoder12decoder":[0.6429, 0.6561, 0.6622, 0.6693],
|
||||
"16encoder8decoder":[0.6457, 0.6547, 0.6665, 0.6773],
|
||||
"20encoder4decoder":[0.6483, 0.6612, 0.6684, 0.6711],
|
||||
}
|
||||
|
||||
x = [100, 200, 300, 400]
|
||||
# colors = ["#70d6ff", "#ff70a6", "#ff9770", "#ffd670", "#e9ff70"]
|
||||
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
|
||||
|
||||
metric_data = {
|
||||
"FID50K" : fid_data,
|
||||
# "SFID" : sfid_data,
|
||||
"InceptionScore" : is_data,
|
||||
"Precision" : pr_data,
|
||||
"Recall" : recall_data,
|
||||
}
|
||||
|
||||
for key, data in metric_data.items():
|
||||
# plt.rc('axes.spines', **{'bottom': True, 'left': True, 'right': False, 'top': False})
|
||||
for i, (name, v) in enumerate(data.items()):
|
||||
name = name.replace("encoder", "En")
|
||||
name = name.replace("decoder", "De")
|
||||
plt.plot(x, v, label=name, color=colors[i], linewidth=5.0, marker="o", markersize=8)
|
||||
plt.legend(fontsize="14")
|
||||
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
|
||||
plt.xticks([100, 150, 200, 250, 300, 350, 400])
|
||||
plt.ylabel(key, weight="bold")
|
||||
plt.xlabel("Training iterations(K steps)", weight="bold")
|
||||
plt.savefig("output/large++_{}.pdf".format(key), bbox_inches='tight')
|
||||
plt.close()
|
||||
18
tools/figures/log_snr.py
Normal file
18
tools/figures/log_snr.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
t = np.linspace(0.001, 0.999, 100)
|
||||
def snr(t):
|
||||
return np.log((1-t)/t)
|
||||
def pds(t):
|
||||
return np.clip(((1-t)/t)**2, a_max=0.5, a_min=0.0)
|
||||
print(pds(t))
|
||||
plt.figure(figsize=(16, 4))
|
||||
plt.plot(t, snr(t), color="#ff70a6", linewidth=3, marker="o")
|
||||
# plt.plot(t, pds(t), color="#ff9770", linewidth=3, marker="o")
|
||||
plt.ylabel("log-SNR", weight="bold")
|
||||
plt.xlabel("Timesteps", weight="bold")
|
||||
plt.xticks([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])
|
||||
plt.gca().invert_xaxis()
|
||||
plt.show()
|
||||
# plt.savefig("output/logsnr.pdf", bbox_inches='tight')
|
||||
BIN
tools/figures/output/base++_FID.pdf
Normal file
BIN
tools/figures/output/base++_FID.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base++_FID50K.pdf
Normal file
BIN
tools/figures/output/base++_FID50K.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base++_InceptionScore.pdf
Normal file
BIN
tools/figures/output/base++_InceptionScore.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base++_Precision.pdf
Normal file
BIN
tools/figures/output/base++_Precision.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base++_Recall.pdf
Normal file
BIN
tools/figures/output/base++_Recall.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base_FID.pdf
Normal file
BIN
tools/figures/output/base_FID.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base_InceptionScore.pdf
Normal file
BIN
tools/figures/output/base_InceptionScore.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base_Precision.pdf
Normal file
BIN
tools/figures/output/base_Precision.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/base_Recall.pdf
Normal file
BIN
tools/figures/output/base_Recall.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/cfg.pdf
Normal file
BIN
tools/figures/output/cfg.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/large++_FID.pdf
Normal file
BIN
tools/figures/output/large++_FID.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/large++_FID50K.pdf
Normal file
BIN
tools/figures/output/large++_FID50K.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/large++_InceptionScore.pdf
Normal file
BIN
tools/figures/output/large++_InceptionScore.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/large++_Precision.pdf
Normal file
BIN
tools/figures/output/large++_Precision.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/large++_Recall.pdf
Normal file
BIN
tools/figures/output/large++_Recall.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/logsnr.pdf
Normal file
BIN
tools/figures/output/logsnr.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/mean_sim.png
Normal file
BIN
tools/figures/output/mean_sim.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 20 KiB |
BIN
tools/figures/output/sota.pdf
Normal file
BIN
tools/figures/output/sota.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/timeshift.pdf
Normal file
BIN
tools/figures/output/timeshift.pdf
Normal file
Binary file not shown.
BIN
tools/figures/output/timeshift_fid.pdf
Normal file
BIN
tools/figures/output/timeshift_fid.pdf
Normal file
Binary file not shown.
95
tools/figures/sota.py
Normal file
95
tools/figures/sota.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
data = {
|
||||
"SiT-XL/2" : {
|
||||
"size": 675,
|
||||
"epochs": 1400,
|
||||
"FID": 2.06,
|
||||
"color": "#ff99c8"
|
||||
},
|
||||
"DiT-XL/2" : {
|
||||
"size": 675,
|
||||
"epochs": 1400,
|
||||
"FID": 2.27,
|
||||
"color": "#fcf6bd"
|
||||
},
|
||||
"REPA-XL/2" : {
|
||||
"size": 675,
|
||||
"epochs": 800,
|
||||
"FID": 1.42,
|
||||
"color": "#d0f4de"
|
||||
},
|
||||
# "MAR-H" : {
|
||||
# "size": 973,
|
||||
# "epochs": 800,
|
||||
# "FID": 1.55,
|
||||
# },
|
||||
"MDTv2" : {
|
||||
"size": 675,
|
||||
"epochs": 920,
|
||||
"FID": 1.58,
|
||||
"color": "#e4c1f9"
|
||||
},
|
||||
# "VAVAE+LightningDiT" : {
|
||||
# "size": 675,
|
||||
# "epochs": [64, 800],
|
||||
# "FID": [2.11, 1.35],
|
||||
# },
|
||||
"DDT-XL/2": {
|
||||
"size": 675,
|
||||
"epochs": [80, 256],
|
||||
"FID": [1.52, 1.31],
|
||||
"color": "#38a3a5"
|
||||
},
|
||||
"DDT-L/2": {
|
||||
"size": 400,
|
||||
"epochs": 80,
|
||||
"FID": 1.64,
|
||||
"color": "#5bc0be"
|
||||
},
|
||||
}
|
||||
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(1, 1, 1)
|
||||
for k, spec in data.items():
|
||||
plt.scatter(
|
||||
# spec["size"],
|
||||
spec["epochs"],
|
||||
spec["FID"],
|
||||
label=k,
|
||||
marker="o",
|
||||
s=spec["size"],
|
||||
color=spec["color"],
|
||||
)
|
||||
x = spec["epochs"]
|
||||
y = spec["FID"]
|
||||
if isinstance(spec["FID"], list):
|
||||
x = spec["epochs"][-1]
|
||||
y = spec["FID"][-1]
|
||||
plt.plot(
|
||||
spec["epochs"],
|
||||
spec["FID"],
|
||||
color=spec["color"],
|
||||
linestyle="dotted",
|
||||
linewidth=4
|
||||
)
|
||||
# plt.annotate("",
|
||||
# xytext=(spec["epochs"][0], spec["FID"][0]),
|
||||
# xy=(spec["epochs"][1], spec["FID"][1]), arrowprops=dict(arrowstyle="--"), weight="bold")
|
||||
plt.text(x+80, y-0.05, k, fontsize=13)
|
||||
|
||||
plt.text(200, 1.45, "4x Training Acc", fontsize=12, color="#38a3a5", weight="bold")
|
||||
# plt.arrow(200, 1.42, 520, 0, linewidth=2, fc='black', ec='black', hatch="x", head_width=0.05, head_length=0.05)
|
||||
|
||||
plt.annotate("",
|
||||
xy=(700, 1.42), xytext=(200, 1.42),
|
||||
arrowprops=dict(arrowstyle='<->', color='black', linewidth=2),
|
||||
)
|
||||
ax.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
|
||||
plt.gca().set_xlim(0, 1800)
|
||||
plt.gca().set_ylim(1.15, 2.5)
|
||||
plt.xticks([80, 256, 800, 1000, 1200, 1400, 1600, ])
|
||||
plt.xlabel("Training Epochs", weight="bold")
|
||||
plt.ylabel("FID50K on ImageNet256x256", weight="bold")
|
||||
plt.savefig("output/sota.pdf", bbox_inches="tight")
|
||||
26
tools/figures/timeshift.py
Normal file
26
tools/figures/timeshift.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import scipy
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def timeshift(t, s=1.0):
|
||||
return t/(t+(1-t)*s)
|
||||
|
||||
# colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
|
||||
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
|
||||
# plt.rc('axes.spines', **{'bottom':True, 'left':True, 'right':False, 'top':False})
|
||||
t = np.linspace(0, 1, 100)
|
||||
shifts = [1.0, 1.5, 2, 3]
|
||||
for i , shift in enumerate(shifts):
|
||||
plt.plot(t, timeshift(t, shift), color=colors[i], label=f"shift {shift}", linewidth=4)
|
||||
|
||||
# plt.annotate("", xytext=(0, 0), xy=(0.0, 1.05), arrowprops=dict(arrowstyle="->"), weight="bold")
|
||||
# plt.annotate("", xytext=(0, 0), xy=(1.05, 0.0), arrowprops=dict(arrowstyle="->"), weight="bold")
|
||||
# plt.title("Respaced timesteps with various shift value", weight="bold")
|
||||
# plt.gca().set_xlim(0, 1.0)
|
||||
# plt.gca().set_ylim(0, 1.0)
|
||||
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
|
||||
|
||||
plt.ylabel("Respaced Timesteps", weight="bold")
|
||||
plt.xlabel("Uniform Timesteps", weight="bold")
|
||||
plt.legend(loc="upper left", fontsize="12")
|
||||
plt.savefig("output/timeshift.pdf", bbox_inches="tight", pad_inches=0)
|
||||
29
tools/figures/timeshift_fid.py
Normal file
29
tools/figures/timeshift_fid.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import scipy
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def timeshift(t, s=1.0):
|
||||
return t/(t+(1-t)*s)
|
||||
|
||||
data = {
|
||||
"shift 1.0": [8.99, 6.36, 5.03, 4.21, 3.6, 3.23, 2.80],
|
||||
"shift 1.5": [6.08, 4.26, 3.43, 2.99, 2.73, 2.54, 2.33],
|
||||
"shift 2.0": [5.57, 3.81, 3.11, 2.75, 2.54, 2.43, 2.26],
|
||||
"shift 3.0": [7.26, 4.48, 3.43, 2.97, 2.72, 2.57, 2.38],
|
||||
}
|
||||
# plt.rc('axes.spines', **{'bottom':True, 'left':True, 'right':False, 'top':False})
|
||||
|
||||
# colors = ["#ff99c8", "#fcf6bd", "#d0f4de", "#a9def9"]
|
||||
|
||||
colors = ["#52b69a", "#34a0a4", "#168aad", "#1a759f"]
|
||||
steps = [5, 6, 7, 8, 9, 10, 12]
|
||||
for i ,(k, v)in enumerate(data.items()):
|
||||
plt.plot(steps, v, color=colors[i], label=k, linewidth=4, marker="o")
|
||||
|
||||
# plt.title("FID50K of different steps of different timeshift", weight="bold")
|
||||
plt.ylabel("FID50K", weight="bold")
|
||||
plt.xlabel("Num of inference steps", weight="bold")
|
||||
plt.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
|
||||
# plt.legend()
|
||||
# plt.legend()
|
||||
plt.savefig("output/timeshift_fid.pdf", bbox_inches="tight", pad_inches=0)
|
||||
Reference in New Issue
Block a user