This commit is contained in:
wangshuai6
2025-04-11 19:40:27 +08:00
parent 4591692c30
commit 28d9af791d

6
app.py
View File

@@ -79,8 +79,10 @@ class Pipeline:
self.diffusion_sampler.guidance_interval_min = guidance_interval_min
self.diffusion_sampler.guidance_interval_max = guidance_interval_max
self.diffusion_sampler.timeshift = timeshift
generator = torch.Generator(device="cuda").manual_seed(seed)
xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
generator = torch.Generator(device="cpu").manual_seed(seed)
xT = torch.randn((num_images, 4, self.resolution // 8, self.resolution // 8), device="cpu", dtype=torch.float32,
generator=generator)
xT = xT.to("cuda")
with torch.no_grad():
condition, uncondition = conditioner([self.classlabels2ids[y],]*num_images)
# Sample images: