From 28d9af791df31106167ec76d5c51bbaf9782766f Mon Sep 17 00:00:00 2001 From: wangshuai6 Date: Fri, 11 Apr 2025 19:40:27 +0800 Subject: [PATCH] app demo --- app.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/app.py b/app.py index fe25e9d..5cd2ce0 100644 --- a/app.py +++ b/app.py @@ -79,8 +79,10 @@ class Pipeline: self.diffusion_sampler.guidance_interval_min = guidance_interval_min self.diffusion_sampler.guidance_interval_max = guidance_interval_max self.diffusion_sampler.timeshift = timeshift - generator = torch.Generator(device="cuda").manual_seed(seed) - xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator) + generator = torch.Generator(device="cpu").manual_seed(seed) + xT = torch.randn((num_images, 4, self.resolution // 8, self.resolution // 8), device="cpu", dtype=torch.float32, + generator=generator) + xT = xT.to("cuda") with torch.no_grad(): condition, uncondition = conditioner([self.classlabels2ids[y],]*num_images) # Sample images: