From 12227c4f7f1a13474645e2e8abcbcb41ac78ec2d Mon Sep 17 00:00:00 2001 From: wangshuai6 Date: Fri, 11 Apr 2025 17:41:00 +0800 Subject: [PATCH] app demo --- app.py | 46 ++++++++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/app.py b/app.py index ad040f8..63013c6 100644 --- a/app.py +++ b/app.py @@ -41,6 +41,7 @@ from src.diffusion.stateful_flow_matching.sharing_sampling import EulerSampler from src.diffusion.stateful_flow_matching.scheduling import LinearScheduler from PIL import Image import gradio as gr +from huggingface_hub import snapshot_download def instantiate_class(config): @@ -79,29 +80,32 @@ class Pipeline: generator = torch.Generator(device="cuda").manual_seed(seed) xT = torch.randn((num_images, 4, 32, 32), device="cuda", dtype=torch.float32, generator=generator) with torch.no_grad(): - condition, uncondition = conditioner(y) + condition, uncondition = conditioner([y,]*num_images) # Sample images: samples = diffusion_sampler(denoiser, xT, condition, uncondition) samples = vae.decode(samples) # fp32 -1,1 -> uint8 0,255 samples = fp2uint8(samples) + samples = samples.permute(0, 2, 3, 1).cpu().numpy() images = [] for i in range(num_images): - image = Image.fromarray(samples[i].cpu().numpy()) + image = Image.fromarray(samples[i]) images.append(image) return images - - - +import os if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, default=None) + parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_256.yaml") parser.add_argument("--resolution", type=int, default=256) - parser.add_argument("--ckpt_path", type=str, default="") + parser.add_argument("--model_id", type=str, default="MCG-NJU/DDT-XL-22en6de-R256") + parser.add_argument("--ckpt_path", type=str, default="models") args = parser.parse_args() + if not os.path.exists(args.ckpt_path): + snapshot_download(repo_id=args.model_id, local_dir=args.ckpt_path) + config = OmegaConf.load(args.config) vae_config = config.model.vae diffusion_sampler_config = config.model.diffusion_sampler @@ -124,6 +128,7 @@ if __name__ == "__main__": guidance_interval_max=1.0, timeshift=1.0 ) + ckpt_path = os.path.join(args.ckpt_path, "model.ckpt") ckpt = torch.load(args.ckpt_path, map_location="cpu") denoiser = load_model(ckpt, denoiser) denoiser = denoiser.cuda() @@ -135,18 +140,19 @@ if __name__ == "__main__": with gr.Blocks() as demo: gr.Markdown("DDT") with gr.Row(): - num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50) - guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=3.0) - num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=1) - label = gr.Slider(minimum=0, maximum=999, step=1, label="label", value=1000) - seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0) - with gr.Row(): - state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1) - guidance_interval_min = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance min", value=0.3) - guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max", value=1.0) - timeshift = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="timeshift", value=1.0) - btn = gr.Button("Generate") - output = gr.Gallery(label="Images") + with gr.Column(scale=1): + num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50) + guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0) + num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=8) + label = gr.Slider(minimum=0, maximum=999, step=1, label="label", value=948) + seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0) + state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1) + guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0) + guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max", value=1.0) + timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0) + with gr.Column(scale=2): + btn = gr.Button("Generate") + output = gr.Gallery(label="Images") btn.click(fn=pipeline, inputs=[ @@ -160,4 +166,4 @@ if __name__ == "__main__": guidance_interval_max, timeshift ], outputs=[output]) - demo.launch() \ No newline at end of file + demo.launch(server_name="0.0.0.0", server_port=7861) \ No newline at end of file