This commit is contained in:
wangshuai6
2025-04-11 17:41:00 +08:00
parent ed3c466d08
commit 12227c4f7f

46
app.py
View File

@@ -41,6 +41,7 @@ from src.diffusion.stateful_flow_matching.sharing_sampling import EulerSampler
from src.diffusion.stateful_flow_matching.scheduling import LinearScheduler from src.diffusion.stateful_flow_matching.scheduling import LinearScheduler
from PIL import Image from PIL import Image
import gradio as gr import gradio as gr
from huggingface_hub import snapshot_download
def instantiate_class(config): def instantiate_class(config):
@@ -79,29 +80,32 @@ class Pipeline:
generator = torch.Generator(device="cuda").manual_seed(seed) generator = torch.Generator(device="cuda").manual_seed(seed)
xT = torch.randn((num_images, 4, 32, 32), device="cuda", dtype=torch.float32, generator=generator) xT = torch.randn((num_images, 4, 32, 32), device="cuda", dtype=torch.float32, generator=generator)
with torch.no_grad(): with torch.no_grad():
condition, uncondition = conditioner(y) condition, uncondition = conditioner([y,]*num_images)
# Sample images: # Sample images:
samples = diffusion_sampler(denoiser, xT, condition, uncondition) samples = diffusion_sampler(denoiser, xT, condition, uncondition)
samples = vae.decode(samples) samples = vae.decode(samples)
# fp32 -1,1 -> uint8 0,255 # fp32 -1,1 -> uint8 0,255
samples = fp2uint8(samples) samples = fp2uint8(samples)
samples = samples.permute(0, 2, 3, 1).cpu().numpy()
images = [] images = []
for i in range(num_images): for i in range(num_images):
image = Image.fromarray(samples[i].cpu().numpy()) image = Image.fromarray(samples[i])
images.append(image) images.append(image)
return images return images
import os
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=None) parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_256.yaml")
parser.add_argument("--resolution", type=int, default=256) parser.add_argument("--resolution", type=int, default=256)
parser.add_argument("--ckpt_path", type=str, default="") parser.add_argument("--model_id", type=str, default="MCG-NJU/DDT-XL-22en6de-R256")
parser.add_argument("--ckpt_path", type=str, default="models")
args = parser.parse_args() args = parser.parse_args()
if not os.path.exists(args.ckpt_path):
snapshot_download(repo_id=args.model_id, local_dir=args.ckpt_path)
config = OmegaConf.load(args.config) config = OmegaConf.load(args.config)
vae_config = config.model.vae vae_config = config.model.vae
diffusion_sampler_config = config.model.diffusion_sampler diffusion_sampler_config = config.model.diffusion_sampler
@@ -124,6 +128,7 @@ if __name__ == "__main__":
guidance_interval_max=1.0, guidance_interval_max=1.0,
timeshift=1.0 timeshift=1.0
) )
ckpt_path = os.path.join(args.ckpt_path, "model.ckpt")
ckpt = torch.load(args.ckpt_path, map_location="cpu") ckpt = torch.load(args.ckpt_path, map_location="cpu")
denoiser = load_model(ckpt, denoiser) denoiser = load_model(ckpt, denoiser)
denoiser = denoiser.cuda() denoiser = denoiser.cuda()
@@ -135,18 +140,19 @@ if __name__ == "__main__":
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown("DDT") gr.Markdown("DDT")
with gr.Row(): with gr.Row():
num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50) with gr.Column(scale=1):
guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=3.0) num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=1) guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
label = gr.Slider(minimum=0, maximum=999, step=1, label="label", value=1000) num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=8)
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0) label = gr.Slider(minimum=0, maximum=999, step=1, label="label", value=948)
with gr.Row(): seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1) state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
guidance_interval_min = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance min", value=0.3) guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0)
guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max", value=1.0) guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max", value=1.0)
timeshift = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="timeshift", value=1.0) timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0)
btn = gr.Button("Generate") with gr.Column(scale=2):
output = gr.Gallery(label="Images") btn = gr.Button("Generate")
output = gr.Gallery(label="Images")
btn.click(fn=pipeline, btn.click(fn=pipeline,
inputs=[ inputs=[
@@ -160,4 +166,4 @@ if __name__ == "__main__":
guidance_interval_max, guidance_interval_max,
timeshift timeshift
], outputs=[output]) ], outputs=[output])
demo.launch() demo.launch(server_name="0.0.0.0", server_port=7861)