This commit is contained in:
wangshuai6
2025-04-11 19:31:52 +08:00
parent 687d650175
commit 4591692c30

11
app.py
View File

@@ -95,8 +95,6 @@ class Pipeline:
images.append(image) images.append(image)
return images return images
import os
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_512.yaml") parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_512.yaml")
@@ -131,6 +129,7 @@ if __name__ == "__main__":
denoiser = denoiser.cuda() denoiser = denoiser.cuda()
vae = vae.cuda() vae = vae.cuda()
denoiser.eval() denoiser.eval()
# read imagenet classlabels # read imagenet classlabels
with open("imagenet_classlabels.txt", "r") as f: with open("imagenet_classlabels.txt", "r") as f:
classlabels = f.readlines() classlabels = f.readlines()
@@ -147,12 +146,14 @@ if __name__ == "__main__":
with gr.Column(scale=1): with gr.Column(scale=1):
num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50) num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0) guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=8) num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=4)
label = gr.Dropdown(choices=classlabels, value=id2classlabels[948], label="label") label = gr.Dropdown(choices=classlabels, value=id2classlabels[948], label="label")
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0) seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1) state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0) guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min",
guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max", value=1.0) value=0.0)
guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max",
value=1.0)
timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0) timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0)
with gr.Column(scale=2): with gr.Column(scale=2):
btn = gr.Button("Generate") btn = gr.Button("Generate")