from __future__ import annotations
import os
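# Install the taming-transformers and guided_diffusion dependencies from GitHub at startup
# (runtime install, as commonly done for hosted demos such as Hugging Face Spaces).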
os.system("pip install git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers")
os.system("pip install git+https://github.com/alvanli/RDM-Region-Aware-Diffusion-Model.git@main#egg=guided_diffusion")
import math
import random

import gradio as gr
import torch
from PIL import Image, ImageOps
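# Project-local modules: run_edit provides the editing pipeline, cool_models builds the models.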
from run_edit import run_model
from cool_models import make_models

help_text = ""  # rendered below the controls via gr.Markdown; currently empty

def main():
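    # Build all models once at startup; generate() closes over them so they are reused for every request.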
    segmodel, model, diffusion, ldm, bert, clip_model, model_params = make_models()

    def generate(
        input_image: Image.Image,
        from_text: str,
        instruction: str,
        negative_prompt: str,
        randomize_seed: bool,
        seed: int,
        guidance_scale: float,
        clip_guidance_scale: float,
        cutn: int,
        l2_sim_lambda: float
    ):
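        # Draw a fresh seed when "Randomize Seed" was selected in the UI; otherwise keep the provided one.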
        seed = random.randint(0, 100000) if randomize_seed else seed

        if instruction == "":
            # Nothing to edit; return the input image unchanged.
            return [seed, input_image]

        torch.manual_seed(seed)  # seed the global RNG; the seed is also passed to run_model below

        edited_image_1 = run_model(
            segmodel, model, diffusion, ldm, bert, clip_model, model_params,
            from_text, instruction, negative_prompt, input_image.convert('RGB'), seed, guidance_scale, clip_guidance_scale, cutn, l2_sim_lambda
        )

        return [seed, edited_image_1]

    def reset():
        # Defaults in the same order as the reset_button outputs:
        # randomize_seed, seed, edited_image_1, guidance_scale,
        # clip_guidance_scale, cutn, l2_sim_lambda.
        return [
            "Randomize Seed", 1371, None, 5.0,
            150, 16, 10000
        ]

    with gr.Blocks() as demo:
        gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">
   RDM: Region-Aware Diffusion for Zero-shot Text-driven Image Editing
</h1>
<p>For faster inference without waiting in the queue, you can duplicate this Space and upgrade it to a GPU in the settings.
</p>""")
        with gr.Row():
            with gr.Column(scale=1, min_width=100):
                generate_button = gr.Button("Generate")
            # with gr.Column(scale=1, min_width=100):
            #     load_button = gr.Button("Load Example")
            with gr.Column(scale=1, min_width=100):
                reset_button = gr.Button("Reset")
            with gr.Column(scale=3):
                from_text = gr.Textbox(lines=1, label="From Text", interactive=True)
                instruction = gr.Textbox(lines=1, label="Edit Instruction", interactive=True)
                negative_prompt = gr.Textbox(lines=1, label="Negative Prompt", interactive=True)

        with gr.Row():
            input_image = gr.Image(label="Input Image", type="pil", interactive=True)
            edited_image_1 = gr.Image(label="Edited Image", type="pil", interactive=False)
            # edited_image_2 = gr.Image(label=f"Edited Image", type="pil", interactive=False)
            input_image.style(height=512, width=512)
            edited_image_1.style(height=512, width=512)
            # edited_image_2.style(height=512, width=512)

        with gr.Row():
            # steps = gr.Number(value=50, precision=0, label="Steps", interactive=True)
            seed = gr.Number(value=1371, precision=0, label="Seed", interactive=True)
            guidance_scale = gr.Number(value=5.0, precision=1, label="Guidance Scale", interactive=True)
            clip_guidance_scale = gr.Number(value=150, precision=1, label="Clip Guidance Scale", interactive=True)
            cutn = gr.Number(value=16, precision=1, label="Number of Cuts", interactive=True)
            l2_sim_lambda = gr.Number(value=10000, precision=1, label="L2 Similarity to Original Image", interactive=True)

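            # type="index": "Fix Seed" -> 0, "Randomize Seed" -> 1; generate() treats this value as a boolean.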
            randomize_seed = gr.Radio(
                ["Fix Seed", "Randomize Seed"],
                value="Randomize Seed",
                type="index",
                show_label=False,
                interactive=True,
            )
            # use_ddim = gr.Checkbox(label="Use 50-step DDIM?", value=True)
            # use_ddpm = gr.Checkbox(label="Use 50-step DDPM?", value=True)
        
        gr.Markdown(help_text)

        generate_button.click(
            fn=generate,
            inputs=[
                input_image,
                from_text,
                instruction,
                negative_prompt,
                randomize_seed,
                seed,
                guidance_scale,
                clip_guidance_scale,
                cutn,
                l2_sim_lambda
            ],
            outputs=[seed, edited_image_1],
        )
        reset_button.click(
            fn=reset,
            inputs=[],
            outputs=[
                randomize_seed, seed, edited_image_1, guidance_scale,
                clip_guidance_scale, cutn, l2_sim_lambda
            ],
        )

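    # Serve one request at a time and listen on all network interfaces.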
    demo.queue(concurrency_count=1)
    demo.launch(share=False, server_name="0.0.0.0")


if __name__ == "__main__":
    main()