app demo
This commit is contained in:
11
app.py
11
app.py
@@ -95,8 +95,6 @@ class Pipeline:
|
|||||||
images.append(image)
|
images.append(image)
|
||||||
return images
|
return images
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_512.yaml")
|
parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_512.yaml")
|
||||||
@@ -131,6 +129,7 @@ if __name__ == "__main__":
|
|||||||
denoiser = denoiser.cuda()
|
denoiser = denoiser.cuda()
|
||||||
vae = vae.cuda()
|
vae = vae.cuda()
|
||||||
denoiser.eval()
|
denoiser.eval()
|
||||||
|
|
||||||
# read imagenet classlabels
|
# read imagenet classlabels
|
||||||
with open("imagenet_classlabels.txt", "r") as f:
|
with open("imagenet_classlabels.txt", "r") as f:
|
||||||
classlabels = f.readlines()
|
classlabels = f.readlines()
|
||||||
@@ -147,12 +146,14 @@ if __name__ == "__main__":
|
|||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
|
num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
|
||||||
guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
|
guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
|
||||||
num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=8)
|
num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=4)
|
||||||
label = gr.Dropdown(choices=classlabels, value=id2classlabels[948], label="label")
|
label = gr.Dropdown(choices=classlabels, value=id2classlabels[948], label="label")
|
||||||
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
|
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
|
||||||
state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
|
state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
|
||||||
guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0)
|
guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min",
|
||||||
guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max", value=1.0)
|
value=0.0)
|
||||||
|
guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max",
|
||||||
|
value=1.0)
|
||||||
timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0)
|
timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0)
|
||||||
with gr.Column(scale=2):
|
with gr.Column(scale=2):
|
||||||
btn = gr.Button("Generate")
|
btn = gr.Button("Generate")
|
||||||
|
|||||||
Reference in New Issue
Block a user