app demo
This commit is contained in:
16
app.py
16
app.py
@@ -62,12 +62,13 @@ def load_model(weight_dict, denosier):
|
||||
|
||||
|
||||
class Pipeline:
|
||||
def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution):
|
||||
def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution, classlabels2ids):
|
||||
self.vae = vae
|
||||
self.denoiser = denoiser
|
||||
self.conditioner = conditioner
|
||||
self.diffusion_sampler = diffusion_sampler
|
||||
self.resolution = resolution
|
||||
self.classlabels2ids = classlabels2ids
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
@@ -81,7 +82,7 @@ class Pipeline:
|
||||
generator = torch.Generator(device="cuda").manual_seed(seed)
|
||||
xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
|
||||
with torch.no_grad():
|
||||
condition, uncondition = conditioner([y,]*num_images)
|
||||
condition, uncondition = conditioner([self.classlabels2ids[y],]*num_images)
|
||||
# Sample images:
|
||||
samples = diffusion_sampler(denoiser, xT, condition, uncondition)
|
||||
samples = vae.decode(samples)
|
||||
@@ -130,8 +131,15 @@ if __name__ == "__main__":
|
||||
denoiser = denoiser.cuda()
|
||||
vae = vae.cuda()
|
||||
denoiser.eval()
|
||||
# read imagenet classlabels
|
||||
with open("imagenet_classlabels.txt", "r") as f:
|
||||
classlabels = f.readlines()
|
||||
classlabels = [label.strip() for label in classlabels]
|
||||
|
||||
pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution)
|
||||
classlabels2id = {label: i for i, label in enumerate(classlabels)}
|
||||
id2classlabels = {i: label for i, label in enumerate(classlabels)}
|
||||
|
||||
pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution, classlabels2id)
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown("DDT")
|
||||
@@ -140,7 +148,7 @@ if __name__ == "__main__":
|
||||
num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
|
||||
guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
|
||||
num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=8)
|
||||
label = gr.Slider(minimum=0, maximum=999, step=1, label="label", value=948)
|
||||
label = gr.Dropdown(choices=classlabels, value=id2classlabels[948], label="label")
|
||||
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
|
||||
state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
|
||||
guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0)
|
||||
|
||||
Reference in New Issue
Block a user