18 lines
571 B
Python
18 lines
571 B
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
t = np.linspace(0.001, 0.999, 100)
|
|
def snr(t):
|
|
return np.log((1-t)/t)
|
|
def pds(t):
|
|
return np.clip(((1-t)/t)**2, a_max=0.5, a_min=0.0)
|
|
print(pds(t))
|
|
plt.figure(figsize=(16, 4))
|
|
plt.plot(t, snr(t), color="#ff70a6", linewidth=3, marker="o")
|
|
# plt.plot(t, pds(t), color="#ff9770", linewidth=3, marker="o")
|
|
plt.ylabel("log-SNR", weight="bold")
|
|
plt.xlabel("Timesteps", weight="bold")
|
|
plt.xticks([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])
|
|
plt.gca().invert_xaxis()
|
|
plt.show()
|
|
# plt.savefig("output/logsnr.pdf", bbox_inches='tight') |