"""Plot FID-50K against training epochs for the models listed in ``data``.

The resulting figure is written to ``output/sota.pdf``.
"""

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np


# Per-model results, keyed by the model name used as the plot label.
# Each entry provides:
#   "size":   number passed as the scatter marker area (`s=`)
#   "epochs": training epochs -- a single number, or a list of checkpoints
#   "FID":    FID score(s), same shape as "epochs"
#   "color":  hex colour used for the marker and any connecting line
data = {
    "SiT-XL/2": {
        "size": 675,
        "epochs": 1400,
        "FID": 2.06,
        "color": "#ff99c8",
    },
    "DiT-XL/2": {
        "size": 675,
        "epochs": 1400,
        "FID": 2.27,
        "color": "#fcf6bd",
    },
    "REPA-XL/2": {
        "size": 675,
        "epochs": 800,
        "FID": 1.42,
        "color": "#d0f4de",
    },
    # NOTE: the disabled entries below have no "color" key, which the
    # plotting loop reads; add one before re-enabling them.
    # "MAR-H": {
    #     "size": 973,
    #     "epochs": 800,
    #     "FID": 1.55,
    # },
    "MDTv2": {
        "size": 675,
        "epochs": 920,
        "FID": 1.58,
        "color": "#e4c1f9",
    },
    # "VAVAE+LightningDiT": {
    #     "size": 675,
    #     "epochs": [64, 800],
    #     "FID": [2.11, 1.35],
    # },
    "DDT-XL/2": {
        "size": 675,
        "epochs": [80, 256],
        "FID": [1.52, 1.31],
        "color": "#38a3a5",
    },
    "DDT-L/2": {
        "size": 400,
        "epochs": 80,
        "FID": 1.64,
        "color": "#5bc0be",
    },
}

# One figure with a single axes; every artist below is drawn on `ax`
# explicitly rather than through pyplot's implicit "current axes".
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

for k, spec in data.items():
    # An entry holds either one (epochs, FID) pair or a sequence of
    # checkpoints. Promote both forms to 1-D arrays so scalars, lists and
    # tuples all go through the same code path.
    epochs = np.atleast_1d(spec["epochs"])
    fids = np.atleast_1d(spec["FID"])

    ax.scatter(
        epochs,
        fids,
        label=k,
        marker="o",
        s=spec["size"],
        color=spec["color"],
    )

    # Join successive checkpoints of the same model with a dotted line.
    if len(epochs) > 1:
        ax.plot(
            epochs,
            fids,
            color=spec["color"],
            linestyle="dotted",
            linewidth=4,
        )

    # Name label, offset right of and slightly below the last checkpoint.
    ax.text(epochs[-1] + 80, fids[-1] - 0.05, k, fontsize=13)

# Caption plus a double-headed arrow at FID 1.42 spanning epochs 200-700.
# The caption uses the same colour as the "DDT-XL/2" entry in `data`;
# presumably it highlights that model's training speed-up -- confirm intent.
ax.text(200, 1.45, "4x Training Acc", fontsize=12, color="#38a3a5", weight="bold")
ax.annotate(
    "",
    xy=(700, 1.42),
    xytext=(200, 1.42),
    arrowprops=dict(arrowstyle="<->", color="black", linewidth=2),
)

# Axes styling: light dash-dot grid, fixed view limits, hand-picked x ticks.
ax.grid(linestyle="-.", alpha=0.6, linewidth=0.5)
ax.set_xlim(0, 1800)
ax.set_ylim(1.15, 2.5)
ax.set_xticks([80, 256, 800, 1000, 1200, 1400, 1600])
ax.set_xlabel("Training Epochs", weight="bold")
ax.set_ylabel("FID50K on ImageNet256x256", weight="bold")

# savefig does not create missing directories, so make sure the target
# folder exists before writing the PDF.
output_path = Path("output/sota.pdf")
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, bbox_inches="tight")