diff --git a/app.py b/app.py index 317da4c..fe25e9d 100644 --- a/app.py +++ b/app.py @@ -95,8 +95,6 @@ class Pipeline: images.append(image) return images -import os - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_512.yaml") @@ -131,6 +129,7 @@ if __name__ == "__main__": denoiser = denoiser.cuda() vae = vae.cuda() denoiser.eval() + # read imagenet classlabels with open("imagenet_classlabels.txt", "r") as f: classlabels = f.readlines() @@ -147,12 +146,14 @@ if __name__ == "__main__": with gr.Column(scale=1): num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50) guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0) - num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=8) + num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=4) label = gr.Dropdown(choices=classlabels, value=id2classlabels[948], label="label") seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0) state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1) - guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0) - guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max", value=1.0) + guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", + value=0.0) + guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max", + value=1.0) timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0) with gr.Column(scale=2): btn = gr.Button("Generate")