commit 0972ab65b4 (parent d961ccc57a)
Author: wangshuai6
Date:   2025-04-11 18:30:34 +08:00

app.py (18 lines changed)

@@ -62,11 +62,12 @@ def load_model(weight_dict, denosier):
 class Pipeline:
-    def __init__(self, vae, denoiser, conditioner, diffusion_sampler):
+    def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution):
         self.vae = vae
         self.denoiser = denoiser
         self.conditioner = conditioner
         self.diffusion_sampler = diffusion_sampler
+        self.resolution = resolution
 
     @torch.no_grad()
     @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
@@ -78,7 +79,7 @@ class Pipeline:
         self.diffusion_sampler.guidance_interval_max = guidance_interval_max
         self.diffusion_sampler.timeshift = timeshift
         generator = torch.Generator(device="cuda").manual_seed(seed)
-        xT = torch.randn((num_images, 4, 32, 32), device="cuda", dtype=torch.float32, generator=generator)
+        xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
         with torch.no_grad():
             condition, uncondition = conditioner([y,]*num_images)
         # Sample images:
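
The //8 here appears to assume the VAE downsamples spatial dimensions by a factor of 8, which matches the previously hard-coded 32x32 latent at resolution 256. A minimal sketch (not part of the commit) of how the initial noise shape now follows the requested resolution:

import torch

# Sketch only: derive the latent grid from the target resolution. The factor
# of 8 is an assumption inferred from the old 32x32 latent at 256px.
num_images, resolution = 4, 512
generator = torch.Generator(device="cuda").manual_seed(0)
xT = torch.randn(
    (num_images, 4, resolution // 8, resolution // 8),
    device="cuda", dtype=torch.float32, generator=generator,
)
print(xT.shape)  # torch.Size([4, 4, 64, 64])
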
@@ -97,15 +98,11 @@ import os
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_256.yaml")
-    parser.add_argument("--resolution", type=int, default=256)
-    parser.add_argument("--model_id", type=str, default="MCG-NJU/DDT-XL-22en6de-R256")
+    parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_512.yaml")
+    parser.add_argument("--resolution", type=int, default=512)
     parser.add_argument("--ckpt_path", type=str, default="models")
     args = parser.parse_args()
-    if not os.path.exists(args.ckpt_path):
-        snapshot_download(repo_id=args.model_id, local_dir=args.ckpt_path)
     config = OmegaConf.load(args.config)
     vae_config = config.model.vae
     diffusion_sampler_config = config.model.diffusion_sampler
@@ -128,14 +125,13 @@ if __name__ == "__main__":
         guidance_interval_max=1.0,
         timeshift=1.0
     )
-    ckpt_path = os.path.join(args.ckpt_path, "model.ckpt")
-    ckpt = torch.load(ckpt_path, map_location="cpu")
+    ckpt = torch.load(args.ckpt_path, map_location="cpu")
     denoiser = load_model(ckpt, denoiser)
     denoiser = denoiser.cuda()
     vae = vae.cuda()
     denoiser.eval()
-    pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler)
+    pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution)
     with gr.Blocks() as demo:
         gr.Markdown("DDT")
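
With the --model_id flag and the automatic snapshot_download gone, the checkpoint has to be fetched manually, and --ckpt_path is now passed straight to torch.load, so it should point at the checkpoint file itself rather than a directory. A minimal sketch of the manual step, assuming huggingface_hub and reusing the 256px repo id that the removed default pointed at (the matching 512px repo id is not given in this diff):

from huggingface_hub import snapshot_download

# Sketch only: download the weights yourself, then point --ckpt_path at the
# checkpoint file (the old code expected models/model.ckpt). Swap in the 512px
# repo id you actually want.
snapshot_download(repo_id="MCG-NJU/DDT-XL-22en6de-R256", local_dir="models")

# Launch with, e.g.:
#   python app.py --config configs/repa_improved_ddt_xlen22de6_512.yaml \
#       --resolution 512 --ckpt_path models/model.ckpt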