# DMM / app.py — commit 74ec624 by songtianhui ("add model details"), 3.37 kB.
import random
import spaces
import torch
import gradio as gr
from modeling.dmm_pipeline import StableDiffusionDMMPipeline
from huggingface_hub import snapshot_download
# Fetch the DMM checkpoint from the Hugging Face Hub into the local "ckpt"
# directory, then load it as a float16 pipeline (safetensors weights) and
# move it to the CUDA device once at import time.
ckpt_path = "ckpt"
snapshot_download(
    local_dir=ckpt_path,
    repo_id="MCG-NJU/DMM",
)
pipe = StableDiffusionDMMPipeline.from_pretrained(
    ckpt_path, use_safetensors=True, torch_dtype=torch.float16
)
pipe.to("cuda")
def _run_pipe(prompt: str,
              negative_prompt: str,
              model_id: int,
              seed: int,
              height: int,
              width: int):
    """Run the DMM pipeline once with sub-model ``model_id`` and return its first image.

    A fresh CPU ``torch.Generator`` seeded with ``seed`` is created for every
    call, so each sub-model run starts from an identically seeded generator.
    """
    return pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_inference_steps=25,
        guidance_scale=7,
        model_id=model_id,
        generator=torch.Generator().manual_seed(seed),
    ).images[0]


@spaces.GPU
def generate(prompt: str,
             negative_prompt: str,
             model_id: int,
             seed: int = 1234,
             height: int = 512,
             width: int = 512,
             all: bool = True):  # noqa: A002 - name kept: Gradio passes inputs positionally
    """Generate images with one or every sub-model of the DMM pipeline.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        model_id: Index of the sub-model to use when ``all`` is False. The
            Gradio dropdown passes ``None`` when nothing is selected.
        seed: Seed for the per-run torch generator.
        height: Output image height in pixels.
        width: Output image width in pixels.
        all: When True, ignore ``model_id`` and run every sub-model reported
            by ``pipe.unet.get_num_models()``.

    Returns:
        A list of images: one per sub-model when ``all`` is True, otherwise a
        single-element list.

    Raises:
        gr.Error: If a required input is empty (no model selected, or a
            cleared seed/size field).
    """
    if seed is None or height is None or width is None:
        raise gr.Error("Seed, height and width must all be set.")
    # gr.Number yields floats unless precision=0 (e.g. 768.0); the generator
    # seed must be an int and the pipeline expects integer image sizes.
    seed, height, width = int(seed), int(height), int(width)

    if all:
        model_ids = range(pipe.unet.get_num_models())
    else:
        if model_id is None:
            # A cleared dropdown (type="index") sends None; int(None) would
            # otherwise crash with an opaque TypeError.
            raise gr.Error("Please select a model, or tick 'All'.")
        model_ids = [int(model_id)]

    return [
        _run_pipe(prompt, negative_prompt, i, seed, height, width)
        for i in model_ids
    ]
# Dropdown labels for the DMM sub-models. The numeric prefix is generated from
# the list position, so it always matches the index the dropdown (type="index")
# hands to generate() as ``model_id``.
candidates = [
    f"{index}. [{checkpoint}] {style}"
    for index, (checkpoint, style) in enumerate((
        ("JuggernautReborn", "realistic"),
        ("MajicmixRealisticV7", "realistic, Asia portrait"),
        ("EpicRealismV5", "realistic"),
        ("RealisticVisionV5", "realistic"),
        ("MajicmixFantasyV3", "animation"),
        ("MinimalismV2", "illustration"),
        ("RealCartoon3dV17", "cartoon 3d"),
        ("AWPaintingV1.4", "animation"),
    ))
]
def main():
    """Build the Gradio Blocks UI for the DMM demo and launch it.

    The UI has a model dropdown (or an "All" checkbox), prompt and negative
    prompt boxes, a seed field with a randomize button, height/width fields,
    and a gallery that shows the images returned by ``generate``.
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # DMM Demo
            The checkpoint is https://huggingface.co/MCG-NJU/DMM.
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Column():
                    # type="index" makes the dropdown emit the position of the
                    # selected label, which generate() uses as the model id.
                    model_id = gr.Dropdown(candidates, label="Model Index", type="index")
                    all_check = gr.Checkbox(label="All (ignore the selection above)")
                prompt = gr.Textbox("portrait photo of a girl, long golden hair, flowers, best quality", label="Prompt")
                negative_prompt = gr.Textbox("worst quality,low quality,normal quality,lowres,watermark,nsfw", label="Negative Prompt")
                with gr.Row():
                    seed = gr.Number(0, label="Seed", precision=0, scale=3)
                    update_seed_btn = gr.Button("🎲", scale=1)
                with gr.Row():
                    # precision=0 makes gr.Number deliver ints; without it the
                    # sizes reach the pipeline as floats (e.g. 768.0).
                    height = gr.Number(768, step=8, precision=0, label="Height (suggest 512~768)")
                    width = gr.Number(512, step=8, precision=0, label="Width")
                submit_btn = gr.Button("Submit", variant="primary")
            output = gr.Gallery(label="images")
        # Gradio passes `inputs` positionally, so this order must match
        # generate(prompt, negative_prompt, model_id, seed, height, width, all).
        submit_btn.click(generate,
                         inputs=[prompt, negative_prompt, model_id, seed, height, width, all_check],
                         outputs=[output])
        update_seed_btn.click(lambda: random.randint(0, 1000000),
                              outputs=[seed])
    demo.launch()
# Script entry point: build and serve the demo only when run directly,
# not when this module is imported.
if __name__ == "__main__":
    main()