app demo
This commit is contained in:
16
app.py
16
app.py
@@ -62,12 +62,13 @@ def load_model(weight_dict, denosier):
|
|||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution):
|
def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution, classlabels2ids):
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
self.denoiser = denoiser
|
self.denoiser = denoiser
|
||||||
self.conditioner = conditioner
|
self.conditioner = conditioner
|
||||||
self.diffusion_sampler = diffusion_sampler
|
self.diffusion_sampler = diffusion_sampler
|
||||||
self.resolution = resolution
|
self.resolution = resolution
|
||||||
|
self.classlabels2ids = classlabels2ids
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||||
@@ -81,7 +82,7 @@ class Pipeline:
|
|||||||
generator = torch.Generator(device="cuda").manual_seed(seed)
|
generator = torch.Generator(device="cuda").manual_seed(seed)
|
||||||
xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
|
xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
condition, uncondition = conditioner([y,]*num_images)
|
condition, uncondition = conditioner([self.classlabels2ids[y],]*num_images)
|
||||||
# Sample images:
|
# Sample images:
|
||||||
samples = diffusion_sampler(denoiser, xT, condition, uncondition)
|
samples = diffusion_sampler(denoiser, xT, condition, uncondition)
|
||||||
samples = vae.decode(samples)
|
samples = vae.decode(samples)
|
||||||
@@ -130,8 +131,15 @@ if __name__ == "__main__":
|
|||||||
denoiser = denoiser.cuda()
|
denoiser = denoiser.cuda()
|
||||||
vae = vae.cuda()
|
vae = vae.cuda()
|
||||||
denoiser.eval()
|
denoiser.eval()
|
||||||
|
# read imagenet classlabels
|
||||||
|
with open("imagenet_classlabels.txt", "r") as f:
|
||||||
|
classlabels = f.readlines()
|
||||||
|
classlabels = [label.strip() for label in classlabels]
|
||||||
|
|
||||||
pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution)
|
classlabels2id = {label: i for i, label in enumerate(classlabels)}
|
||||||
|
id2classlabels = {i: label for i, label in enumerate(classlabels)}
|
||||||
|
|
||||||
|
pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution, classlabels2id)
|
||||||
|
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
gr.Markdown("DDT")
|
gr.Markdown("DDT")
|
||||||
@@ -140,7 +148,7 @@ if __name__ == "__main__":
|
|||||||
num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
|
num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
|
||||||
guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
|
guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
|
||||||
num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=8)
|
num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=8)
|
||||||
label = gr.Slider(minimum=0, maximum=999, step=1, label="label", value=948)
|
label = gr.Dropdown(choices=classlabels, value=id2classlabels[948], label="label")
|
||||||
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
|
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
|
||||||
state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
|
state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
|
||||||
guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0)
|
guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0)
|
||||||
|
|||||||
1000
imagenet_classlabels.txt
Normal file
1000
imagenet_classlabels.txt
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user