app demo
This commit is contained in:
18
app.py
18
app.py
@@ -62,11 +62,12 @@ def load_model(weight_dict, denosier):
|
|||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
def __init__(self, vae, denoiser, conditioner, diffusion_sampler):
|
def __init__(self, vae, denoiser, conditioner, diffusion_sampler, resolution):
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
self.denoiser = denoiser
|
self.denoiser = denoiser
|
||||||
self.conditioner = conditioner
|
self.conditioner = conditioner
|
||||||
self.diffusion_sampler = diffusion_sampler
|
self.diffusion_sampler = diffusion_sampler
|
||||||
|
self.resolution = resolution
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||||
@@ -78,7 +79,7 @@ class Pipeline:
|
|||||||
self.diffusion_sampler.guidance_interval_max = guidance_interval_max
|
self.diffusion_sampler.guidance_interval_max = guidance_interval_max
|
||||||
self.diffusion_sampler.timeshift = timeshift
|
self.diffusion_sampler.timeshift = timeshift
|
||||||
generator = torch.Generator(device="cuda").manual_seed(seed)
|
generator = torch.Generator(device="cuda").manual_seed(seed)
|
||||||
xT = torch.randn((num_images, 4, 32, 32), device="cuda", dtype=torch.float32, generator=generator)
|
xT = torch.randn((num_images, 4, self.resolution//8, self.resolution//8), device="cuda", dtype=torch.float32, generator=generator)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
condition, uncondition = conditioner([y,]*num_images)
|
condition, uncondition = conditioner([y,]*num_images)
|
||||||
# Sample images:
|
# Sample images:
|
||||||
@@ -97,15 +98,11 @@ import os
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_256.yaml")
|
parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_512.yaml")
|
||||||
parser.add_argument("--resolution", type=int, default=256)
|
parser.add_argument("--resolution", type=int, default=512)
|
||||||
parser.add_argument("--model_id", type=str, default="MCG-NJU/DDT-XL-22en6de-R256")
|
|
||||||
parser.add_argument("--ckpt_path", type=str, default="models")
|
parser.add_argument("--ckpt_path", type=str, default="models")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not os.path.exists(args.ckpt_path):
|
|
||||||
snapshot_download(repo_id=args.model_id, local_dir=args.ckpt_path)
|
|
||||||
|
|
||||||
config = OmegaConf.load(args.config)
|
config = OmegaConf.load(args.config)
|
||||||
vae_config = config.model.vae
|
vae_config = config.model.vae
|
||||||
diffusion_sampler_config = config.model.diffusion_sampler
|
diffusion_sampler_config = config.model.diffusion_sampler
|
||||||
@@ -128,14 +125,13 @@ if __name__ == "__main__":
|
|||||||
guidance_interval_max=1.0,
|
guidance_interval_max=1.0,
|
||||||
timeshift=1.0
|
timeshift=1.0
|
||||||
)
|
)
|
||||||
ckpt_path = os.path.join(args.ckpt_path, "model.ckpt")
|
ckpt = torch.load(args.ckpt_path, map_location="cpu")
|
||||||
ckpt = torch.load(ckpt_path, map_location="cpu")
|
|
||||||
denoiser = load_model(ckpt, denoiser)
|
denoiser = load_model(ckpt, denoiser)
|
||||||
denoiser = denoiser.cuda()
|
denoiser = denoiser.cuda()
|
||||||
vae = vae.cuda()
|
vae = vae.cuda()
|
||||||
denoiser.eval()
|
denoiser.eval()
|
||||||
|
|
||||||
pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler)
|
pipeline = Pipeline(vae, denoiser, conditioner, diffusion_sampler, args.resolution)
|
||||||
|
|
||||||
with gr.Blocks() as demo:
|
with gr.Blocks() as demo:
|
||||||
gr.Markdown("DDT")
|
gr.Markdown("DDT")
|
||||||
|
|||||||
Reference in New Issue
Block a user