blanchon commited on
Commit
771c08c
·
1 Parent(s): 4a4a63e

🔥 Add app.py

Browse files
Files changed (1) hide show
  1. app.py +225 -0
app.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import gradio as gr
4
+ import qrcode
5
+ import torch
6
+ from diffusers import (
7
+ ControlNetModel,
8
+ EulerAncestralDiscreteScheduler,
9
+ StableDiffusionControlNetPipeline,
10
+ )
11
+ from gradio.components import Image, Radio, Slider, Textbox, Number
12
+ from PIL import Image as PilImage
13
+ from typing_extensions import Literal
14
+
15
+
16
+ def main():
17
+ device = (
18
+ 'cuda' if torch.cuda.is_available()
19
+ else 'mps' if torch.backends.mps.is_available()
20
+ else 'cpu'
21
+ )
22
+
23
+ controlnet_tile = ControlNetModel.from_pretrained(
24
+ "lllyasviel/control_v11f1e_sd15_tile",
25
+ torch_dtype=torch.float16,
26
+ use_safetensors=False
27
+ ).to(device)
28
+
29
+ controlnet_brightness = ControlNetModel.from_pretrained(
30
+ "ioclab/control_v1p_sd15_brightness",
31
+ torch_dtype=torch.float16,
32
+ use_safetensors=True
33
+ ).to(device)
34
+
35
+ def make_pipe(hf_repo: str, device: str) -> StableDiffusionControlNetPipeline:
36
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
37
+ hf_repo,
38
+ controlnet=[controlnet_tile, controlnet_brightness],
39
+ torch_dtype=torch.float16,
40
+ )
41
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
42
+ # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
43
+ return pipe.to(device)
44
+
45
+ pipes = {
46
+ "DreamShaper": make_pipe("Lykon/DreamShaper", "cpu"),
47
+ # "Realistic Vision V1.4": make_pipe("SG161222/Realistic_Vision_V1.4", "cpu"),
48
+ # "OpenJourney": make_pipe("prompthero/openjourney", "cpu"),
49
+ # "Anything V3": make_pipe("Linaqruf/anything-v3.0", "cpu"),
50
+ }
51
+
52
+ def move_pipe(hf_repo: str):
53
+ for pipe_name, pipe in pipes.items():
54
+ if pipe_name != hf_repo:
55
+ pipe.to("cpu")
56
+ return pipes[hf_repo].to(device)
57
+
58
+ def predict(
59
+ model: Literal[
60
+ "DreamShaper",
61
+ # "Realistic Vision V1.4",
62
+ # "OpenJourney",
63
+ # "Anything V3"
64
+ ],
65
+ qrcode_data: str,
66
+ prompt: str,
67
+ negative_prompt: Optional[str] = None,
68
+ num_inference_steps: int = 100,
69
+ guidance_scale: int = 9,
70
+ controlnet_conditioning_tile: float = 0.25,
71
+ controlnet_conditioning_brightness: float = 0.45,
72
+ seed: int = 1331,
73
+ ) -> PilImage:
74
+ generator = torch.Generator(device="cuda").manual_seed(seed)
75
+ if model == "DreamShaper":
76
+ pipe = move_pipe("DreamShaper")
77
+ # elif model == "Realistic Vision V1.4":
78
+ # pipe = move_pipe("Realistic Vision V1.4")
79
+ # elif model == "OpenJourney":
80
+ # pipe = move_pipe("OpenJourney")
81
+ # elif model == "Anything V3":
82
+ # pipe = move_pipe("Anything V3")
83
+
84
+
85
+ qr = qrcode.QRCode(
86
+ error_correction=qrcode.constants.ERROR_CORRECT_H,
87
+ box_size=11,
88
+ border=9,
89
+ )
90
+ qr.add_data(qrcode_data)
91
+ qr.make(fit=True)
92
+ qrcode_image = qr.make_image(
93
+ fill_color="black",
94
+ back_color="white"
95
+ ).convert("RGB")
96
+ qrcode_image = qrcode_image.resize((512, 512), PilImage.LANCZOS)
97
+
98
+ image = pipe(
99
+ prompt,
100
+ [qrcode_image, qrcode_image],
101
+ num_inference_steps=num_inference_steps,
102
+ generator=generator,
103
+ negative_prompt=negative_prompt,
104
+ guidance_scale=guidance_scale,
105
+ controlnet_conditioning_scale=[
106
+ controlnet_conditioning_tile,
107
+ controlnet_conditioning_brightness
108
+ ]
109
+ ).images[0]
110
+
111
+ return image
112
+
113
+
114
+ ui = gr.Interface(
115
+ fn=predict,
116
+ inputs=[
117
+ Radio(
118
+ value="DreamShaper",
119
+ label="Model",
120
+ choices=[
121
+ "DreamShaper",
122
+ # "Realistic Vision V1.4",
123
+ # "OpenJourney",
124
+ # "Anything V3"
125
+ ],
126
+ ),
127
+ Textbox(
128
+ value="https://twitter.com/JulienBlanchon",
129
+ label="QR Code Data",
130
+ ),
131
+ Textbox(
132
+ value="Japanese ramen with chopsticks, egg and steam, ultra detailed 8k",
133
+ label="Prompt",
134
+ ),
135
+ Textbox(
136
+ value="logo, watermark, signature, text, BadDream, UnrealisticDream",
137
+ label="Negative Prompt",
138
+ optional=True
139
+ ),
140
+ Slider(
141
+ value=100,
142
+ label="Number of Inference Steps",
143
+ minimum=10,
144
+ maximum=400,
145
+ step=1,
146
+ ),
147
+ Slider(
148
+ value=9,
149
+ label="Guidance Scale",
150
+ minimum=1,
151
+ maximum=20,
152
+ step=1,
153
+ ),
154
+ Slider(
155
+ value=0.25,
156
+ label="Controlnet Conditioning Tile",
157
+ minimum=0.0,
158
+ maximum=1.0,
159
+ step=0.05,
160
+
161
+ ),
162
+ Slider(
163
+ value=0.45,
164
+ label="Controlnet Conditioning Brightness",
165
+ minimum=0.0,
166
+ maximum=1.0,
167
+ step=0.05,
168
+ ),
169
+ Number(
170
+ value=1,
171
+ label="Seed",
172
+ precision=0,
173
+ ),
174
+
175
+ ],
176
+ outputs=Image(
177
+ label="Generated Image",
178
+ type="pil",
179
+ ),
180
+ examples=[
181
+ [
182
+ "DreamShaper",
183
+ "https://twitter.com/JulienBlanchon",
184
+ "Japanese ramen with chopsticks, egg and steam, ultra detailed 8k",
185
+ "logo, watermark, signature, text, BadDream, UnrealisticDream",
186
+ 100,
187
+ 9,
188
+ 0.25,
189
+ 0.45,
190
+ 1,
191
+ ],
192
+ # [
193
+ # "Anything V3",
194
+ # "https://twitter.com/JulienBlanchon",
195
+ # "Japanese ramen with chopsticks, egg and steam, ultra detailed 8k",
196
+ # "logo, watermark, signature, text, BadDream, UnrealisticDream",
197
+ # 100,
198
+ # 9,
199
+ # 0.25,
200
+ # 0.60,
201
+ # 1,
202
+ # ],
203
+ [
204
+ "DreamShaper",
205
+ "https://twitter.com/JulienBlanchon",
206
+ "processor, chipset, electricity, black and white board",
207
+ "logo, watermark, signature, text, BadDream, UnrealisticDream",
208
+ 300,
209
+ 9,
210
+ 0.50,
211
+ 0.30,
212
+ 1,
213
+ ],
214
+ ],
215
+ cache_examples=True,
216
+ title="Stable Diffusion QR Code Controlnet",
217
+ description="Generate QR Code with Stable Diffusion and Controlnet",
218
+ allow_flagging="never",
219
+ max_batch_size=1,
220
+ )
221
+
222
+ ui.queue(concurrency_count=10).launch()
223
+
224
+ if __name__ == "__main__":
225
+ main()