hari7261 commited on
Commit
a7d86f5
·
verified ·
1 Parent(s): b9acbf8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -155
app.py CHANGED
@@ -1,27 +1,32 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
- from diffusers import AutoPipelineForText2Image
5
  import torch
6
  from time import time
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
  model_repo_id = "stabilityai/sdxl-turbo"
10
 
11
- # Load model with optimizations
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- pipe = AutoPipelineForText2Image.from_pretrained(
15
- model_repo_id,
16
- torch_dtype=torch_dtype,
17
- variant="fp16",
18
- use_safetensors=True
19
- )
20
- pipe.enable_model_cpu_offload() # For better memory management
21
- pipe.enable_xformers_memory_efficient_attention() # Faster attention
22
- else:
23
- torch_dtype = torch.float32
24
- pipe = AutoPipelineForText2Image.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
 
 
 
 
 
25
 
26
  MAX_SEED = np.iinfo(np.int32).max
27
  MAX_IMAGE_SIZE = 1024
@@ -37,15 +42,15 @@ def infer(
37
  num_inference_steps,
38
  progress=gr.Progress(track_tqdm=True),
39
  ):
40
- start_time = time()
41
-
42
- if randomize_seed:
43
- seed = random.randint(0, MAX_SEED)
 
44
 
45
- generator = torch.Generator(device=device).manual_seed(seed)
46
 
47
- # Generate image with progress updates
48
- with torch.no_grad():
49
  image = pipe(
50
  prompt=prompt,
51
  negative_prompt=negative_prompt,
@@ -56,9 +61,12 @@ def infer(
56
  generator=generator,
57
  ).images[0]
58
 
59
- gen_time = time() - start_time
60
-
61
- return image, seed, f"Generated in {gen_time:.2f}s"
 
 
 
62
 
63
  examples = [
64
  ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", 1024, 1024],
@@ -68,50 +76,19 @@ examples = [
68
 
69
  css = """
70
  :root {
71
- --primary-50: #f0f0ff;
72
- --primary-100: #e0e0ff;
73
- --primary-200: #c7c7fe;
74
- --primary-300: #a5a5fc;
75
- --primary-400: #8181f8;
76
- --primary-500: #6e6af0;
77
- --primary-600: #5a56e4;
78
- --primary-700: #4a46c9;
79
- --primary-800: #3e3ba3;
80
- --primary-900: #383682;
81
- --primary-950: #211f4d;
82
-
83
- --surface-0: 255 255 255;
84
- --surface-50: 248 250 252;
85
- --surface-100: 241 245 249;
86
- --surface-200: 226 232 240;
87
- --surface-300: 203 213 225;
88
- --surface-400: 148 163 184;
89
- --surface-500: 100 116 139;
90
- --surface-600: 71 85 105;
91
- --surface-700: 45 55 72;
92
- --surface-800: 30 41 59;
93
- --surface-900: 15 23 42;
94
- --surface-950: 3 6 23;
95
-
96
- --text-primary: rgb(var(--surface-900));
97
- --text-secondary: rgb(var(--surface-600));
98
  }
99
 
100
  .dark {
101
- --primary-50: #211f4d;
102
- --primary-100: #383682;
103
- --primary-200: #3e3ba3;
104
- --primary-300: #4a46c9;
105
- --primary-400: #5a56e4;
106
- --primary-500: #6e6af0;
107
- --primary-600: #8181f8;
108
- --primary-700: #a5a5fc;
109
- --primary-800: #c7c7fe;
110
- --primary-900: #e0e0ff;
111
- --primary-950: #f0f0ff;
112
-
113
- --text-primary: rgb(var(--surface-100));
114
- --text-secondary: rgb(var(--surface-300));
115
  }
116
 
117
  #col-container {
@@ -128,28 +105,21 @@ css = """
128
  .header h1 {
129
  font-size: 2.5rem;
130
  font-weight: 700;
131
- color: var(--primary-500);
132
  margin-bottom: 10px;
133
  }
134
 
135
  .header p {
136
- color: var(--text-secondary);
 
137
  }
138
 
139
  .prompt-container, .result-container, .advanced-settings {
140
- background-color: rgb(var(--surface-50));
141
  border-radius: 12px;
142
  padding: 20px;
143
- box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
144
  margin-bottom: 20px;
145
- border: 1px solid rgb(var(--surface-200));
146
- }
147
-
148
- .dark .prompt-container,
149
- .dark .result-container,
150
- .dark .advanced-settings {
151
- background-color: rgb(var(--surface-800));
152
- border-color: rgb(var(--surface-700));
153
  }
154
 
155
  .advanced-settings .form {
@@ -158,10 +128,6 @@ css = """
158
  gap: 16px;
159
  }
160
 
161
- .advanced-settings .form > * {
162
- margin-bottom: 0 !important;
163
- }
164
-
165
  .control-row {
166
  display: flex;
167
  gap: 10px;
@@ -169,13 +135,13 @@ css = """
169
  }
170
 
171
  .btn-primary {
172
- background: var(--primary-500) !important;
173
  border: none !important;
174
  color: white !important;
175
  }
176
 
177
  .btn-primary:hover {
178
- background: var(--primary-600) !important;
179
  }
180
 
181
  .examples {
@@ -186,44 +152,23 @@ css = """
186
  }
187
 
188
  .example-prompt {
189
- background: rgb(var(--surface-100));
190
  padding: 12px;
191
  border-radius: 8px;
192
  cursor: pointer;
193
  transition: all 0.2s;
194
- border: 1px solid rgb(var(--surface-200));
195
- }
196
-
197
- .dark .example-prompt {
198
- background: rgb(var(--surface-700));
199
- border-color: rgb(var(--surface-600));
200
  }
201
 
202
  .example-prompt:hover {
203
- background: var(--primary-100);
204
  transform: translateY(-2px);
205
- border-color: var(--primary-300);
206
- }
207
-
208
- .dark .example-prompt:hover {
209
- background: var(--primary-800);
210
- border-color: var(--primary-600);
211
- }
212
-
213
- .example-img {
214
- width: 100%;
215
- height: 120px;
216
- object-fit: cover;
217
- border-radius: 6px;
218
- margin-bottom: 8px;
219
  }
220
 
221
- /* Theme toggle button */
222
  .theme-toggle {
223
  position: absolute;
224
  top: 20px;
225
  right: 20px;
226
- background: var(--primary-100);
227
  border: none;
228
  border-radius: 50%;
229
  width: 40px;
@@ -232,47 +177,12 @@ css = """
232
  align-items: center;
233
  justify-content: center;
234
  cursor: pointer;
235
- transition: all 0.2s;
236
- }
237
-
238
- .theme-toggle:hover {
239
- background: var(--primary-200);
240
- }
241
-
242
- .dark .theme-toggle {
243
- background: var(--primary-800);
244
- }
245
-
246
- .dark .theme-toggle:hover {
247
- background: var(--primary-700);
248
  }
249
 
250
  @media (max-width: 768px) {
251
  .advanced-settings .form {
252
  grid-template-columns: 1fr;
253
  }
254
-
255
- .theme-toggle {
256
- top: 10px;
257
- right: 10px;
258
- }
259
- }
260
-
261
- /* Loading animation */
262
- @keyframes spin {
263
- 0% { transform: rotate(0deg); }
264
- 100% { transform: rotate(360deg); }
265
- }
266
-
267
- .loading-spinner {
268
- display: inline-block;
269
- width: 20px;
270
- height: 20px;
271
- border: 3px solid rgba(255, 255, 255, 0.3);
272
- border-radius: 50%;
273
- border-top-color: white;
274
- animation: spin 1s ease-in-out infinite;
275
- margin-right: 8px;
276
  }
277
  """
278
 
@@ -294,13 +204,6 @@ document.addEventListener('DOMContentLoaded', () => {
294
  themeToggle.innerHTML = savedTheme === 'dark' ? '☀️' : '🌙';
295
  themeToggle.onclick = toggleTheme;
296
  document.body.appendChild(themeToggle);
297
-
298
- // Update icon when theme changes
299
- document.body.addEventListener('click', (e) => {
300
- if (e.target === themeToggle) {
301
- themeToggle.innerHTML = document.body.classList.contains('dark') ? '☀️' : '🌙';
302
- }
303
- });
304
  });
305
  """
306
 
@@ -338,8 +241,7 @@ with gr.Blocks(css=css, js=js, theme=gr.themes.Soft()) as demo:
338
  result = gr.Image(
339
  label="Generated Image",
340
  show_label=False,
341
- height=500,
342
- elem_id="output-image"
343
  )
344
  with gr.Row():
345
  seed_info = gr.Textbox(
@@ -372,7 +274,6 @@ with gr.Blocks(css=css, js=js, theme=gr.themes.Soft()) as demo:
372
  randomize_seed = gr.Checkbox(
373
  label="Randomize seed",
374
  value=True,
375
- interactive=True
376
  )
377
 
378
  with gr.Column():
@@ -415,8 +316,7 @@ with gr.Blocks(css=css, js=js, theme=gr.themes.Soft()) as demo:
415
  examples=[[example[0], example[1], example[2]]],
416
  inputs=[prompt, width, height],
417
  label="",
418
- examples_per_page=20,
419
- elem_id=f"example-{example[0][:10]}"
420
  )
421
 
422
  run_button.click(
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
+ from diffusers import DiffusionPipeline
5
  import torch
6
  from time import time
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
  model_repo_id = "stabilityai/sdxl-turbo"
10
 
11
+ # Simplified model loading
12
+ try:
13
+ if torch.cuda.is_available():
14
+ torch_dtype = torch.float16
15
+ pipe = DiffusionPipeline.from_pretrained(
16
+ model_repo_id,
17
+ torch_dtype=torch_dtype,
18
+ variant="fp16",
19
+ use_safetensors=True
20
+ ).to(device)
21
+ else:
22
+ torch_dtype = torch.float32
23
+ pipe = DiffusionPipeline.from_pretrained(
24
+ model_repo_id,
25
+ torch_dtype=torch_dtype
26
+ ).to(device)
27
+ except Exception as e:
28
+ print(f"Error loading model: {e}")
29
+ raise
30
 
31
  MAX_SEED = np.iinfo(np.int32).max
32
  MAX_IMAGE_SIZE = 1024
 
42
  num_inference_steps,
43
  progress=gr.Progress(track_tqdm=True),
44
  ):
45
+ try:
46
+ start_time = time()
47
+
48
+ if randomize_seed:
49
+ seed = random.randint(0, MAX_SEED)
50
 
51
+ generator = torch.Generator(device=device).manual_seed(seed)
52
 
53
+ # Generate image
 
54
  image = pipe(
55
  prompt=prompt,
56
  negative_prompt=negative_prompt,
 
61
  generator=generator,
62
  ).images[0]
63
 
64
+ gen_time = time() - start_time
65
+
66
+ return image, seed, f"Generated in {gen_time:.2f}s"
67
+ except Exception as e:
68
+ print(f"Error during inference: {e}")
69
+ return None, seed, f"Error: {str(e)}"
70
 
71
  examples = [
72
  ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", 1024, 1024],
 
76
 
77
  css = """
78
  :root {
79
+ --primary: #6e6af0;
80
+ --secondary: #f5f5f7;
81
+ --accent: #f5f5f7;
82
+ --text: #1e1e1e;
83
+ --shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  }
85
 
86
  .dark {
87
+ --primary: #a5a5fc;
88
+ --secondary: #2d3748;
89
+ --accent: #4a5568;
90
+ --text: #f7fafc;
91
+ --shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
 
 
 
 
 
 
 
 
 
92
  }
93
 
94
  #col-container {
 
105
  .header h1 {
106
  font-size: 2.5rem;
107
  font-weight: 700;
108
+ color: var(--primary);
109
  margin-bottom: 10px;
110
  }
111
 
112
  .header p {
113
+ color: var(--text);
114
+ opacity: 0.8;
115
  }
116
 
117
  .prompt-container, .result-container, .advanced-settings {
118
+ background: var(--secondary);
119
  border-radius: 12px;
120
  padding: 20px;
121
+ box-shadow: var(--shadow);
122
  margin-bottom: 20px;
 
 
 
 
 
 
 
 
123
  }
124
 
125
  .advanced-settings .form {
 
128
  gap: 16px;
129
  }
130
 
 
 
 
 
131
  .control-row {
132
  display: flex;
133
  gap: 10px;
 
135
  }
136
 
137
  .btn-primary {
138
+ background: var(--primary) !important;
139
  border: none !important;
140
  color: white !important;
141
  }
142
 
143
  .btn-primary:hover {
144
+ opacity: 0.9 !important;
145
  }
146
 
147
  .examples {
 
152
  }
153
 
154
  .example-prompt {
155
+ background: var(--secondary);
156
  padding: 12px;
157
  border-radius: 8px;
158
  cursor: pointer;
159
  transition: all 0.2s;
 
 
 
 
 
 
160
  }
161
 
162
  .example-prompt:hover {
 
163
  transform: translateY(-2px);
164
+ box-shadow: var(--shadow);
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  }
166
 
 
167
  .theme-toggle {
168
  position: absolute;
169
  top: 20px;
170
  right: 20px;
171
+ background: var(--secondary);
172
  border: none;
173
  border-radius: 50%;
174
  width: 40px;
 
177
  align-items: center;
178
  justify-content: center;
179
  cursor: pointer;
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  }
181
 
182
  @media (max-width: 768px) {
183
  .advanced-settings .form {
184
  grid-template-columns: 1fr;
185
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  }
187
  """
188
 
 
204
  themeToggle.innerHTML = savedTheme === 'dark' ? '☀️' : '🌙';
205
  themeToggle.onclick = toggleTheme;
206
  document.body.appendChild(themeToggle);
 
 
 
 
 
 
 
207
  });
208
  """
209
 
 
241
  result = gr.Image(
242
  label="Generated Image",
243
  show_label=False,
244
+ height=500
 
245
  )
246
  with gr.Row():
247
  seed_info = gr.Textbox(
 
274
  randomize_seed = gr.Checkbox(
275
  label="Randomize seed",
276
  value=True,
 
277
  )
278
 
279
  with gr.Column():
 
316
  examples=[[example[0], example[1], example[2]]],
317
  inputs=[prompt, width, height],
318
  label="",
319
+ examples_per_page=20
 
320
  )
321
 
322
  run_button.click(