hari7261 commited on
Commit
b9acbf8
·
verified ·
1 Parent(s): 4296d01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -52
app.py CHANGED
@@ -1,19 +1,27 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
- from diffusers import DiffusionPipeline
5
  import torch
 
6
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
  model_repo_id = "stabilityai/sdxl-turbo"
9
 
 
10
  if torch.cuda.is_available():
11
  torch_dtype = torch.float16
 
 
 
 
 
 
 
 
12
  else:
13
  torch_dtype = torch.float32
14
-
15
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
16
- pipe = pipe.to(device)
17
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
  MAX_IMAGE_SIZE = 1024
@@ -29,22 +37,28 @@ def infer(
29
  num_inference_steps,
30
  progress=gr.Progress(track_tqdm=True),
31
  ):
 
 
32
  if randomize_seed:
33
  seed = random.randint(0, MAX_SEED)
34
 
35
- generator = torch.Generator().manual_seed(seed)
36
 
37
- image = pipe(
38
- prompt=prompt,
39
- negative_prompt=negative_prompt,
40
- guidance_scale=guidance_scale,
41
- num_inference_steps=num_inference_steps,
42
- width=width,
43
- height=height,
44
- generator=generator,
45
- ).images[0]
 
 
46
 
47
- return image, seed
 
 
48
 
49
  examples = [
50
  ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", 1024, 1024],
@@ -54,11 +68,50 @@ examples = [
54
 
55
  css = """
56
  :root {
57
- --primary: #6e6af0;
58
- --secondary: #f5f5f7;
59
- --accent: #f5f5f7;
60
- --text: #1e1e1e;
61
- --shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  }
63
 
64
  #col-container {
@@ -75,36 +128,28 @@ css = """
75
  .header h1 {
76
  font-size: 2.5rem;
77
  font-weight: 700;
78
- color: var(--primary);
79
  margin-bottom: 10px;
80
  }
81
 
82
  .header p {
83
- color: var(--text);
84
- opacity: 0.8;
85
  }
86
 
87
- .prompt-container {
88
- background: white;
89
  border-radius: 12px;
90
  padding: 20px;
91
- box-shadow: var(--shadow);
92
  margin-bottom: 20px;
 
93
  }
94
 
95
- .result-container {
96
- background: white;
97
- border-radius: 12px;
98
- padding: 20px;
99
- box-shadow: var(--shadow);
100
- margin-bottom: 20px;
101
- }
102
-
103
- .advanced-settings {
104
- background: white;
105
- border-radius: 12px;
106
- padding: 20px;
107
- box-shadow: var(--shadow);
108
  }
109
 
110
  .advanced-settings .form {
@@ -124,12 +169,13 @@ css = """
124
  }
125
 
126
  .btn-primary {
127
- background: var(--primary) !important;
128
  border: none !important;
 
129
  }
130
 
131
  .btn-primary:hover {
132
- opacity: 0.9 !important;
133
  }
134
 
135
  .examples {
@@ -140,16 +186,28 @@ css = """
140
  }
141
 
142
  .example-prompt {
143
- background: var(--secondary);
144
  padding: 12px;
145
  border-radius: 8px;
146
  cursor: pointer;
147
  transition: all 0.2s;
 
 
 
 
 
 
148
  }
149
 
150
  .example-prompt:hover {
151
- background: #e0e0e8;
152
  transform: translateY(-2px);
 
 
 
 
 
 
153
  }
154
 
155
  .example-img {
@@ -160,14 +218,93 @@ css = """
160
  margin-bottom: 8px;
161
  }
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  @media (max-width: 768px) {
164
  .advanced-settings .form {
165
  grid-template-columns: 1fr;
166
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  }
168
  """
169
 
170
- with gr.Blocks(css=css) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  with gr.Column(elem_id="col-container"):
172
  with gr.Column(visible=True) as header:
173
  gr.Markdown(
@@ -204,11 +341,15 @@ with gr.Blocks(css=css) as demo:
204
  height=500,
205
  elem_id="output-image"
206
  )
207
- seed_info = gr.Textbox(
208
- label="Seed",
209
- interactive=False,
210
- visible=False
211
- )
 
 
 
 
212
 
213
  with gr.Accordion("🛠️ Advanced Settings", open=False, elem_classes="advanced-settings"):
214
  with gr.Column(elem_classes="form"):
@@ -290,7 +431,7 @@ with gr.Blocks(css=css) as demo:
290
  guidance_scale,
291
  num_inference_steps,
292
  ],
293
- outputs=[result, seed_info],
294
  )
295
  prompt.submit(
296
  fn=infer,
@@ -304,8 +445,8 @@ with gr.Blocks(css=css) as demo:
304
  guidance_scale,
305
  num_inference_steps,
306
  ],
307
- outputs=[result, seed_info],
308
  )
309
 
310
  if __name__ == "__main__":
311
- demo.launch()
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
+ from diffusers import AutoPipelineForText2Image
5
  import torch
6
+ from time import time
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
  model_repo_id = "stabilityai/sdxl-turbo"
10
 
11
+ # Load model with optimizations
12
  if torch.cuda.is_available():
13
  torch_dtype = torch.float16
14
+ pipe = AutoPipelineForText2Image.from_pretrained(
15
+ model_repo_id,
16
+ torch_dtype=torch_dtype,
17
+ variant="fp16",
18
+ use_safetensors=True
19
+ )
20
+ pipe.enable_model_cpu_offload() # For better memory management
21
+ pipe.enable_xformers_memory_efficient_attention() # Faster attention
22
  else:
23
  torch_dtype = torch.float32
24
+ pipe = AutoPipelineForText2Image.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
 
 
25
 
26
  MAX_SEED = np.iinfo(np.int32).max
27
  MAX_IMAGE_SIZE = 1024
 
37
  num_inference_steps,
38
  progress=gr.Progress(track_tqdm=True),
39
  ):
40
+ start_time = time()
41
+
42
  if randomize_seed:
43
  seed = random.randint(0, MAX_SEED)
44
 
45
+ generator = torch.Generator(device=device).manual_seed(seed)
46
 
47
+ # Generate image with progress updates
48
+ with torch.no_grad():
49
+ image = pipe(
50
+ prompt=prompt,
51
+ negative_prompt=negative_prompt,
52
+ guidance_scale=guidance_scale,
53
+ num_inference_steps=num_inference_steps,
54
+ width=width,
55
+ height=height,
56
+ generator=generator,
57
+ ).images[0]
58
 
59
+ gen_time = time() - start_time
60
+
61
+ return image, seed, f"Generated in {gen_time:.2f}s"
62
 
63
  examples = [
64
  ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", 1024, 1024],
 
68
 
69
  css = """
70
  :root {
71
+ --primary-50: #f0f0ff;
72
+ --primary-100: #e0e0ff;
73
+ --primary-200: #c7c7fe;
74
+ --primary-300: #a5a5fc;
75
+ --primary-400: #8181f8;
76
+ --primary-500: #6e6af0;
77
+ --primary-600: #5a56e4;
78
+ --primary-700: #4a46c9;
79
+ --primary-800: #3e3ba3;
80
+ --primary-900: #383682;
81
+ --primary-950: #211f4d;
82
+
83
+ --surface-0: 255 255 255;
84
+ --surface-50: 248 250 252;
85
+ --surface-100: 241 245 249;
86
+ --surface-200: 226 232 240;
87
+ --surface-300: 203 213 225;
88
+ --surface-400: 148 163 184;
89
+ --surface-500: 100 116 139;
90
+ --surface-600: 71 85 105;
91
+ --surface-700: 45 55 72;
92
+ --surface-800: 30 41 59;
93
+ --surface-900: 15 23 42;
94
+ --surface-950: 3 6 23;
95
+
96
+ --text-primary: rgb(var(--surface-900));
97
+ --text-secondary: rgb(var(--surface-600));
98
+ }
99
+
100
+ .dark {
101
+ --primary-50: #211f4d;
102
+ --primary-100: #383682;
103
+ --primary-200: #3e3ba3;
104
+ --primary-300: #4a46c9;
105
+ --primary-400: #5a56e4;
106
+ --primary-500: #6e6af0;
107
+ --primary-600: #8181f8;
108
+ --primary-700: #a5a5fc;
109
+ --primary-800: #c7c7fe;
110
+ --primary-900: #e0e0ff;
111
+ --primary-950: #f0f0ff;
112
+
113
+ --text-primary: rgb(var(--surface-100));
114
+ --text-secondary: rgb(var(--surface-300));
115
  }
116
 
117
  #col-container {
 
128
  .header h1 {
129
  font-size: 2.5rem;
130
  font-weight: 700;
131
+ color: var(--primary-500);
132
  margin-bottom: 10px;
133
  }
134
 
135
  .header p {
136
+ color: var(--text-secondary);
 
137
  }
138
 
139
+ .prompt-container, .result-container, .advanced-settings {
140
+ background-color: rgb(var(--surface-50));
141
  border-radius: 12px;
142
  padding: 20px;
143
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
144
  margin-bottom: 20px;
145
+ border: 1px solid rgb(var(--surface-200));
146
  }
147
 
148
+ .dark .prompt-container,
149
+ .dark .result-container,
150
+ .dark .advanced-settings {
151
+ background-color: rgb(var(--surface-800));
152
+ border-color: rgb(var(--surface-700));
 
 
 
 
 
 
 
 
153
  }
154
 
155
  .advanced-settings .form {
 
169
  }
170
 
171
  .btn-primary {
172
+ background: var(--primary-500) !important;
173
  border: none !important;
174
+ color: white !important;
175
  }
176
 
177
  .btn-primary:hover {
178
+ background: var(--primary-600) !important;
179
  }
180
 
181
  .examples {
 
186
  }
187
 
188
  .example-prompt {
189
+ background: rgb(var(--surface-100));
190
  padding: 12px;
191
  border-radius: 8px;
192
  cursor: pointer;
193
  transition: all 0.2s;
194
+ border: 1px solid rgb(var(--surface-200));
195
+ }
196
+
197
+ .dark .example-prompt {
198
+ background: rgb(var(--surface-700));
199
+ border-color: rgb(var(--surface-600));
200
  }
201
 
202
  .example-prompt:hover {
203
+ background: var(--primary-100);
204
  transform: translateY(-2px);
205
+ border-color: var(--primary-300);
206
+ }
207
+
208
+ .dark .example-prompt:hover {
209
+ background: var(--primary-800);
210
+ border-color: var(--primary-600);
211
  }
212
 
213
  .example-img {
 
218
  margin-bottom: 8px;
219
  }
220
 
221
+ /* Theme toggle button */
222
+ .theme-toggle {
223
+ position: absolute;
224
+ top: 20px;
225
+ right: 20px;
226
+ background: var(--primary-100);
227
+ border: none;
228
+ border-radius: 50%;
229
+ width: 40px;
230
+ height: 40px;
231
+ display: flex;
232
+ align-items: center;
233
+ justify-content: center;
234
+ cursor: pointer;
235
+ transition: all 0.2s;
236
+ }
237
+
238
+ .theme-toggle:hover {
239
+ background: var(--primary-200);
240
+ }
241
+
242
+ .dark .theme-toggle {
243
+ background: var(--primary-800);
244
+ }
245
+
246
+ .dark .theme-toggle:hover {
247
+ background: var(--primary-700);
248
+ }
249
+
250
  @media (max-width: 768px) {
251
  .advanced-settings .form {
252
  grid-template-columns: 1fr;
253
  }
254
+
255
+ .theme-toggle {
256
+ top: 10px;
257
+ right: 10px;
258
+ }
259
+ }
260
+
261
+ /* Loading animation */
262
+ @keyframes spin {
263
+ 0% { transform: rotate(0deg); }
264
+ 100% { transform: rotate(360deg); }
265
+ }
266
+
267
+ .loading-spinner {
268
+ display: inline-block;
269
+ width: 20px;
270
+ height: 20px;
271
+ border: 3px solid rgba(255, 255, 255, 0.3);
272
+ border-radius: 50%;
273
+ border-top-color: white;
274
+ animation: spin 1s ease-in-out infinite;
275
+ margin-right: 8px;
276
  }
277
  """
278
 
279
+ js = """
280
+ function toggleTheme() {
281
+ const body = document.body;
282
+ body.classList.toggle('dark');
283
+ localStorage.setItem('gradio-theme', body.classList.contains('dark') ? 'dark' : 'light');
284
+ }
285
+
286
+ document.addEventListener('DOMContentLoaded', () => {
287
+ const savedTheme = localStorage.getItem('gradio-theme') || 'light';
288
+ if (savedTheme === 'dark') {
289
+ document.body.classList.add('dark');
290
+ }
291
+
292
+ const themeToggle = document.createElement('button');
293
+ themeToggle.className = 'theme-toggle';
294
+ themeToggle.innerHTML = savedTheme === 'dark' ? '☀️' : '🌙';
295
+ themeToggle.onclick = toggleTheme;
296
+ document.body.appendChild(themeToggle);
297
+
298
+ // Update icon when theme changes
299
+ document.body.addEventListener('click', (e) => {
300
+ if (e.target === themeToggle) {
301
+ themeToggle.innerHTML = document.body.classList.contains('dark') ? '☀️' : '🌙';
302
+ }
303
+ });
304
+ });
305
+ """
306
+
307
+ with gr.Blocks(css=css, js=js, theme=gr.themes.Soft()) as demo:
308
  with gr.Column(elem_id="col-container"):
309
  with gr.Column(visible=True) as header:
310
  gr.Markdown(
 
341
  height=500,
342
  elem_id="output-image"
343
  )
344
+ with gr.Row():
345
+ seed_info = gr.Textbox(
346
+ label="Seed",
347
+ interactive=False
348
+ )
349
+ time_info = gr.Textbox(
350
+ label="Generation Time",
351
+ interactive=False
352
+ )
353
 
354
  with gr.Accordion("🛠️ Advanced Settings", open=False, elem_classes="advanced-settings"):
355
  with gr.Column(elem_classes="form"):
 
431
  guidance_scale,
432
  num_inference_steps,
433
  ],
434
+ outputs=[result, seed_info, time_info],
435
  )
436
  prompt.submit(
437
  fn=infer,
 
445
  guidance_scale,
446
  num_inference_steps,
447
  ],
448
+ outputs=[result, seed_info, time_info],
449
  )
450
 
451
  if __name__ == "__main__":
452
+ demo.queue(api_open=False).launch()