app demo
app.py (34 lines changed)
@@ -41,6 +41,7 @@ from src.diffusion.stateful_flow_matching.sharing_sampling import EulerSampler
 from src.diffusion.stateful_flow_matching.scheduling import LinearScheduler
 from PIL import Image
 import gradio as gr
+from huggingface_hub import snapshot_download


 def instantiate_class(config):
@@ -79,29 +80,32 @@ class Pipeline:
         generator = torch.Generator(device="cuda").manual_seed(seed)
         xT = torch.randn((num_images, 4, 32, 32), device="cuda", dtype=torch.float32, generator=generator)
         with torch.no_grad():
-            condition, uncondition = conditioner(y)
+            condition, uncondition = conditioner([y,]*num_images)
             # Sample images:
             samples = diffusion_sampler(denoiser, xT, condition, uncondition)
             samples = vae.decode(samples)
             # fp32 -1,1 -> uint8 0,255
             samples = fp2uint8(samples)
+            samples = samples.permute(0, 2, 3, 1).cpu().numpy()
             images = []
             for i in range(num_images):
-                image = Image.fromarray(samples[i].cpu().numpy())
+                image = Image.fromarray(samples[i])
                 images.append(image)
             return images

+import os

-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--config", type=str, default=None)
+    parser.add_argument("--config", type=str, default="configs/repa_improved_ddt_xlen22de6_256.yaml")
     parser.add_argument("--resolution", type=int, default=256)
-    parser.add_argument("--ckpt_path", type=str, default="")
+    parser.add_argument("--model_id", type=str, default="MCG-NJU/DDT-XL-22en6de-R256")
+    parser.add_argument("--ckpt_path", type=str, default="models")
     args = parser.parse_args()

+    if not os.path.exists(args.ckpt_path):
+        snapshot_download(repo_id=args.model_id, local_dir=args.ckpt_path)

     config = OmegaConf.load(args.config)
     vae_config = config.model.vae
     diffusion_sampler_config = config.model.diffusion_sampler
@@ -124,6 +128,7 @@ if __name__ == "__main__":
         guidance_interval_max=1.0,
         timeshift=1.0
     )
+    ckpt_path = os.path.join(args.ckpt_path, "model.ckpt")
     ckpt = torch.load(args.ckpt_path, map_location="cpu")
     denoiser = load_model(ckpt, denoiser)
     denoiser = denoiser.cuda()
@@ -135,16 +140,17 @@ if __name__ == "__main__":
     with gr.Blocks() as demo:
         gr.Markdown("DDT")
         with gr.Row():
+            with gr.Column(scale=1):
                 num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=50)
-                guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=3.0)
-                num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=1)
-                label = gr.Slider(minimum=0, maximum=999, step=1, label="label", value=1000)
+                guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
+                num_images = gr.Slider(minimum=1, maximum=10, step=1, label="num images", value=8)
+                label = gr.Slider(minimum=0, maximum=999, step=1, label="label", value=948)
                 seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
-            with gr.Row():
                 state_refresh_rate = gr.Slider(minimum=1, maximum=10, step=1, label="encoder reuse", value=1)
-                guidance_interval_min = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance min", value=0.3)
+                guidance_interval_min = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="interval guidance min", value=0.0)
                 guidance_interval_max = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="interval guidance max", value=1.0)
-                timeshift = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="timeshift", value=1.0)
+                timeshift = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="timeshift", value=1.0)
+            with gr.Column(scale=2):
                 btn = gr.Button("Generate")
                 output = gr.Gallery(label="Images")

@@ -160,4 +166,4 @@ if __name__ == "__main__":
             guidance_interval_max,
             timeshift
         ], outputs=[output])
-    demo.launch()
+    demo.launch(server_name="0.0.0.0", server_port=7861)
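
Note: the main behavioral change in this commit is that the demo now fetches its checkpoint from the Hugging Face Hub instead of requiring a pre-existing local --ckpt_path. Below is a minimal standalone sketch of that bootstrap flow; the argparse defaults mirror the lines added in the diff, while the surrounding script and the print statement are illustrative and not part of app.py.

import os
import argparse

from huggingface_hub import snapshot_download

# Defaults mirror the new arguments added in this commit.
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default="MCG-NJU/DDT-XL-22en6de-R256")
parser.add_argument("--ckpt_path", type=str, default="models")
args = parser.parse_args()

# Download the model snapshot once; later runs reuse the local copy.
if not os.path.exists(args.ckpt_path):
    snapshot_download(repo_id=args.model_id, local_dir=args.ckpt_path)

# app.py then expects the weights inside the downloaded folder at model.ckpt.
ckpt_path = os.path.join(args.ckpt_path, "model.ckpt")
print(f"checkpoint expected at: {ckpt_path}")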