app demo
This commit is contained in:
6
app.py
6
app.py
@@ -79,8 +79,10 @@ class Pipeline:
|
|||||||
self.diffusion_sampler.guidance_interval_min = guidance_interval_min
|
self.diffusion_sampler.guidance_interval_min = guidance_interval_min
|
||||||
self.diffusion_sampler.guidance_interval_max = guidance_interval_max
|
self.diffusion_sampler.guidance_interval_max = guidance_interval_max
|
||||||
self.diffusion_sampler.timeshift = timeshift
|
self.diffusion_sampler.timeshift = timeshift
|
||||||
generator = torch.Generator(device="cuda").manual_seed(seed)
|
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||||
xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
|
xT = torch.randn((num_images, 4, self.resolution // 8, self.resolution // 8), device="cpu", dtype=torch.float32,
|
||||||
|
generator=generator)
|
||||||
|
xT = xT.to("cuda")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
condition, uncondition = conditioner([self.classlabels2ids[y],]*num_images)
|
condition, uncondition = conditioner([self.classlabels2ids[y],]*num_images)
|
||||||
# Sample images:
|
# Sample images:
|
||||||
|
|||||||
Reference in New Issue
Block a user