import time
import gradio as gr
import torch
import diffusers
from utils import patch_attention_proc
import numpy as np
from PIL import Image

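# Load the Stable Diffusion checkpoint directly in fp16 on the GPU. xformers
# memory-efficient attention is the stock processor every run starts from; the
# Euler scheduler and a disabled safety checker keep the demo fast and unfiltered.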
pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper", torch_dtype=torch.float16).to("cuda")
pipe.enable_xformers_memory_efficient_attention()
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.safety_checker = None

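# UI: the demo renders a baseline image and a token-merged image side by side,
# shuffled left/right at random so the comparison is blind.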
with gr.Blocks() as demo:
    prompt = gr.Textbox(interactive=True, label="prompt")
    negative_prompt = gr.Textbox(interactive=True, label="negative_prompt")
    method = gr.Dropdown(["todo", "tome"], value="todo", label="method", info="Choose the merging method (default: todo)")
    height_width = gr.Dropdown([1024, 1536, 2048], value=1024, label="height/width", info="Choose the output height/width in pixels (default: 1024)")
    # height = gr.Number(label="height", value=1024, precision=0)
    # width = gr.Number(label="width", value=1024, precision=0)
    guidance_scale = gr.Number(label="guidance_scale", value=7.5, precision=1)
    steps = gr.Number(label="steps", value=20, precision=0)
    seed = gr.Number(label="seed", value=1, precision=0)
    result = gr.Textbox(label="Result")

    output_image = gr.Image(label="output_image", type="pil", interactive=False)

    gen = gr.Button("generate")

    def which_image(img, target_val=253, width=1024):
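        """Report which half of the side-by-side image is the merged one.

        The merged image carries a single marker pixel whose alpha equals
        target_val (253 rather than 255); its column index reveals which
        half it landed in after the random left/right shuffle.
        """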
        npimg = np.array(img)
        # Column index of the single pixel whose alpha equals target_val.
        loc = np.where(npimg[:, :, 3] == target_val)[1].item()
        if loc > width:
            print("Right image is the merged one!")
        else:
            print("Left image is the merged one!")


    def generate(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):

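        # Restore the stock xformers attention processors, undoing the
        # token-merging patch left behind by a previous click.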
        pipe.enable_xformers_memory_efficient_attention()

        # ToDo merges keys/values by spatial downsampling; ToMe merges all
        # tokens by similarity matching.
        merge_method = "downsample" if method == "todo" else "similarity"
        merge_tokens = "keys/values" if method == "todo" else "all"

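        # Per-resolution merging hyperparameters: larger canvases tolerate
        # (and need) more aggressive merging, and 2048px also merges at a
        # second level.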
        if height_width == 1024:
            downsample_factor = 2
            ratio = 0.75
            downsample_factor_level_2 = 1
            ratio_level_2 = 0.0
        elif height_width == 1536:
            downsample_factor = 3
            ratio = 0.89
            downsample_factor_level_2 = 1
            ratio_level_2 = 0.0
        elif height_width == 2048:
            downsample_factor = 4
            ratio = 0.9375
            downsample_factor_level_2 = 2
            ratio_level_2 = 0.75

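        # Arguments consumed by patch_attention_proc; the *_level_2 entries
        # presumably target the next, deeper UNet downsampling level.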
        token_merge_args = {
            "ratio": ratio,
            "merge_tokens": merge_tokens,
            "merge_method": merge_method,
            "downsample_method": "nearest",
            "downsample_factor": downsample_factor,
            "timestep_threshold_switch": 0.0,
            "timestep_threshold_stop": 0.0,
            "downsample_factor_level_2": downsample_factor_level_2,
            "ratio_level_2": ratio_level_2,
        }

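        # Draw the left/right coin flip before re-seeding so it does not
        # perturb the seeded RNG stream used for image generation.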
        l_r = torch.rand(1).item()
        torch.manual_seed(seed)
        start_time_base = time.time()
        base_img = pipe(prompt,
                        num_inference_steps=steps, height=height_width, width=height_width,
                        negative_prompt=negative_prompt,
                        guidance_scale=guidance_scale).images[0]
        end_time_base = time.time()

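        # Patch the UNet attention with the chosen token-merging scheme for
        # the second, timed pass.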
        patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)

        torch.manual_seed(seed)
        start_time_merge = time.time()
        merged_img = pipe(prompt,
                          num_inference_steps=steps, height=height_width, width=height_width,
                          negative_prompt=negative_prompt,
                          guidance_scale=guidance_scale).images[0]
        end_time_merge = time.time()

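        # Compose the two results side by side on one RGBA canvas.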
        base_img = base_img.convert("RGBA")
        merged_img = merged_img.convert("RGBA")
        merged_img = np.array(merged_img)
        halfh, halfw = height_width // 2, height_width // 2
        merged_img[halfh, halfw, 3] = 253 # set the center pixel of the merged image to be ever so slightly below 255 in alpha channel
        merged_img = Image.fromarray(merged_img)
        final_img = Image.new(size=(height_width * 2, height_width), mode="RGBA")

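        # Shuffle which side the merged image lands on.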
        if l_r > 0.5:
            left_img = base_img
            right_img = merged_img
        else:
            left_img = merged_img
            right_img = base_img

        final_img.paste(left_img, (0, 0))
        final_img.paste(right_img, (height_width, 0))

        which_image(final_img, width=height_width)


        result = f"Baseline image: {end_time_base-start_time_base:.2f} sec | {'ToDo' if method == 'todo' else 'ToMe'} image: {end_time_merge-start_time_merge:.2f} sec"

        return final_img, result


    gen.click(generate, inputs=[prompt, seed, steps, height_width, negative_prompt,
                                guidance_scale, method], outputs=[output_image, result])

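# share=True additionally serves the demo through a temporary public Gradio link.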
demo.launch(share=True)