From 0972ab65b4318ace40b9345f6cc5d92c4c89d3e6 Mon Sep 17 00:00:00 2001 From: wangshuai6 Date: Fri, 11 Apr 2025 18:30:34 +0800 Subject: [PATCH] app demo --- app.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/app.py b/app.py index 41b1de8..8166b79 100644 --- a/app.py +++ b/app.py @@ -62,11 +62,12 @@ def load_model(weight_dict, denosier): class Pipeline: - def __init__(self, vae, denoiser, conditioner, diffusion_sampler): + def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution): self.vae = vae self.denoiser = denoiser self.conditioner = conditioner self.diffusion_sampler = diffusion_sampler + self.resolution = resolution @torch.no_grad() @torch.autocast(device_type="cuda", dtype=torch.bfloat16) @@ -78,7 +79,7 @@ class Pipeline: self.diffusion_sampler.guidance_interval_max = guidance_interval_max self.diffusion_sampler.timeshift = timeshift generator = torch.Generator(device="cuda").manual_seed(seed) - xT = torch.randn((num_images, 4, 32, 32), device="cuda", dtype=torch.float32, generator=generator) + xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator) with torch.no_grad(): condition, uncondition = conditioner([y,]*num_images) # Sample images: @@ -97,15 +98,11 @@ import os if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_256.yaml") - parser.add_argument("--resolution", type=int, default=256) - parser.add_argument("--model_id", type=str, default="MCG-NJU/DDT-XL-22en6de-R256") + parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_512.yaml") + parser.add_argument("--resolution", type=int, default=512) parser.add_argument("--ckpt_path", type=str, default="models") args = parser.parse_args() - if not os.path.exists(args.ckpt_path): - snapshot_download(repo_id=args.model_id, local_dir=args.ckpt_path) - config = OmegaConf.load(args.config) vae_config = config.model.vae diffusion_sampler_config = config.model.diffusion_sampler @@ -128,14 +125,13 @@ if __name__ == "__main__": guidance_interval_max=1.0, timeshift=1.0 ) - ckpt_path = os.path.join(args.ckpt_path, "model.ckpt") - ckpt = torch.load(ckpt_path, map_location="cpu") + ckpt = torch.load(args.ckpt_path, map_location="cpu") denoiser = load_model(ckpt, denoiser) denoiser = denoiser.cuda() vae = vae.cuda() denoiser.eval() - pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler) + pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution) with gr.Blocks() as demo: gr.Markdown("DDT")