File size: 3,372 Bytes
b15b2ae
439279b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b15b2ae
439279b
 
 
 
 
b15b2ae
 
439279b
 
 
 
 
 
 
b15b2ae
 
439279b
 
 
 
 
 
 
 
 
 
 
b15b2ae
 
439279b
 
 
 
 
 
 
 
74ec624
 
 
 
 
 
 
 
 
 
439279b
 
 
74ec624
 
 
 
 
 
439279b
 
b15b2ae
74ec624
b15b2ae
439279b
 
b15b2ae
 
 
 
74ec624
b15b2ae
 
439279b
 
b15b2ae
 
 
 
 
439279b
74ec624
439279b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import random
import spaces
import torch
import gradio as gr

from modeling.dmm_pipeline import StableDiffusionDMMPipeline
from huggingface_hub import snapshot_download


# Fetch a snapshot of the "MCG-NJU/DMM" Hub repository into the local
# directory "ckpt". This runs at import time, before any UI is built.
ckpt_path = "ckpt"
snapshot_download(repo_id="MCG-NJU/DMM", local_dir=ckpt_path)

# Load the DMM pipeline from the downloaded checkpoint in float16, reading
# safetensors weights, and move it to the CUDA device once at module level.
# `pipe` is the shared pipeline instance used by `generate` below.
pipe = StableDiffusionDMMPipeline.from_pretrained(
    ckpt_path,
    torch_dtype=torch.float16, 
    use_safetensors=True
)
pipe.to("cuda")


@spaces.GPU
def generate(prompt: str,
             negative_prompt: str,
             model_id: int,
             seed: int = 1234,
             height: int = 512,
             width: int = 512,
             all: bool = True):
    """Generate one image per requested DMM sub-model.

    Every image is produced with a fresh ``torch.Generator`` seeded with
    ``seed``, so outputs of different sub-models are directly comparable.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative text prompt.
        model_id: Index of the sub-model to use. Ignored when ``all`` is
            true; may be ``None`` when nothing is selected in the UI.
        seed: Seed for the generator (floats from ``gr.Number`` are
            accepted and truncated to ``int``).
        height: Output height in pixels (cast to ``int``).
        width: Output width in pixels (cast to ``int``).
        all: If true, run every sub-model reported by
            ``pipe.unet.get_num_models()``. The name shadows the builtin but
            is kept so existing keyword callers keep working.

    Returns:
        A list holding ``.images[0]`` of each pipeline call, in model order.

    Raises:
        gr.Error: If ``all`` is false and ``model_id`` is ``None``.
    """
    if all:
        model_ids = list(range(pipe.unet.get_num_models()))
    elif model_id is None:
        # int(None) would raise an opaque TypeError; show a readable
        # message in the Gradio UI instead.
        raise gr.Error("Please select a model index or tick 'All'.")
    else:
        model_ids = [int(model_id)]

    # gr.Number without precision delivers floats; the pipeline sizes and
    # torch.Generator.manual_seed need integers.
    seed, height, width = int(seed), int(height), int(width)

    def _run(index: int):
        """Run the pipeline once with sub-model ``index``; return its image."""
        return pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_inference_steps=25,
            guidance_scale=7,
            model_id=index,
            generator=torch.Generator().manual_seed(seed),
        ).images[0]

    return [_run(i) for i in model_ids]


# Dropdown labels for the DMM sub-models, "<index>. [<name>] <tags>".
# The position in this list is the model_id handed to the pipeline, so the
# numeric prefix is derived from the position rather than typed by hand.
candidates = [
    f"{index}. [{name}] {tags}"
    for index, (name, tags) in enumerate([
        ("JuggernautReborn", "realistic"),
        ("MajicmixRealisticV7", "realistic, Asia portrait"),
        ("EpicRealismV5", "realistic"),
        ("RealisticVisionV5", "realistic"),
        ("MajicmixFantasyV3", "animation"),
        ("MinimalismV2", "illustration"),
        ("RealCartoon3dV17", "cartoon 3d"),
        ("AWPaintingV1.4", "animation"),
    ])
]

def main():
    """Build the Gradio UI for the DMM demo and launch it.

    The form collects prompt, negative prompt, sub-model selection, seed and
    output size, passes them positionally to ``generate`` and shows the
    returned images in a gallery.
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # DMM Demo
            The checkpoint is https://huggingface.co/MCG-NJU/DMM.
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Column():
                    # type="index" delivers the position of the chosen entry.
                    # Pre-select the first entry so the value is never None
                    # when "All" is unchecked (generate cannot handle None).
                    model_id = gr.Dropdown(candidates, value=candidates[0],
                                           label="Model Index", type="index")
                    all_check = gr.Checkbox(label="All (ignore the selection above)")
                prompt = gr.Textbox("portrait photo of a girl, long golden hair, flowers, best quality", label="Prompt")
                negative_prompt = gr.Textbox("worst quality,low quality,normal quality,lowres,watermark,nsfw", label="Negative Prompt")
                with gr.Row():
                    seed = gr.Number(0, label="Seed", precision=0, scale=3)
                    update_seed_btn = gr.Button("🎲", scale=1)
                with gr.Row():
                    # precision=0 makes Gradio deliver ints instead of floats.
                    height = gr.Number(768, step=8, precision=0, label="Height (suggest 512~768)")
                    width = gr.Number(512, step=8, precision=0, label="Width")
                submit_btn = gr.Button("Submit", variant="primary")
            output = gr.Gallery(label="images")

        # The order of `inputs` must match generate's positional parameters.
        submit_btn.click(generate,
                         inputs=[prompt, negative_prompt, model_id, seed, height, width, all_check],
                         outputs=[output])
        update_seed_btn.click(lambda: random.randint(0, 1000000),
                              outputs=[seed])

    demo.launch()


if __name__ == "__main__":
    main()