This commit is contained in:
wangshuai6
2025-04-11 19:10:11 +08:00
parent 0972ab65b4
commit c1c8043ed1
2 changed files with 1012 additions and 4 deletions

16
app.py
View File

@@ -62,12 +62,13 @@ def load_model(weight_dict, denosier):
class Pipeline:
def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution):
def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution, classlabels2ids):
self.vae = vae
self.denoiser = denoiser
self.conditioner = conditioner
self.diffusion_sampler = diffusion_sampler
self.resolution = resolution
self.classlabels2ids = classlabels2ids
@torch.no_grad()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
@@ -81,7 +82,7 @@ class Pipeline:
generator = torch.Generator(device="cuda").manual_seed(seed)
xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
with torch.no_grad():
condition, uncondition = conditioner([y,]*num_images)
condition, uncondition = conditioner([self.classlabels2ids[y],]*num_images)
# Sample images:
samples = diffusion_sampler(denoiser, xT, condition, uncondition)
samples = vae.decode(samples)
@@ -130,8 +131,15 @@ if __name__ == "__main__":
denoiser = denoiser.cuda()
vae = vae.cuda()
denoiser.eval()
# read imagenet classlabels
with open("imagenet_classlabels.txt", "r") as f:
classlabels = f.readlines()
classlabels = [label.strip() for label in classlabels]
pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution)
classlabels2id = {label: i for i, label in enumerate(classlabels)}
id2classlabels = {i: label for i, label in enumerate(classlabels)}
pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution, classlabels2id)
with gr.Blocks() as demo:
gr.Markdown("DDT")
@@ -140,7 +148,7 @@ if __name__ == "__main__":
num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=8)
label = gr.Slider(minimum=0, maximum=999, step=1, label="label", value=948)
label = gr.Dropdown(choices=classlabels, value=id2classlabels[948], label="label")
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0)

1000
imagenet_classlabels.txt Normal file

File diff suppressed because it is too large Load Diff