hari7261 commited on
Commit
4472a1c
·
verified ·
1 Parent(s): a7d86f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +108 -239
app.py CHANGED
@@ -1,36 +1,27 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
 
 
4
  from diffusers import DiffusionPipeline
5
  import torch
6
- from time import time
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
- model_repo_id = "stabilityai/sdxl-turbo"
10
-
11
- # Simplified model loading
12
- try:
13
- if torch.cuda.is_available():
14
- torch_dtype = torch.float16
15
- pipe = DiffusionPipeline.from_pretrained(
16
- model_repo_id,
17
- torch_dtype=torch_dtype,
18
- variant="fp16",
19
- use_safetensors=True
20
- ).to(device)
21
- else:
22
- torch_dtype = torch.float32
23
- pipe = DiffusionPipeline.from_pretrained(
24
- model_repo_id,
25
- torch_dtype=torch_dtype
26
- ).to(device)
27
- except Exception as e:
28
- print(f"Error loading model: {e}")
29
- raise
30
 
31
  MAX_SEED = np.iinfo(np.int32).max
32
  MAX_IMAGE_SIZE = 1024
33
 
 
 
34
  def infer(
35
  prompt,
36
  negative_prompt,
@@ -42,36 +33,28 @@ def infer(
42
  num_inference_steps,
43
  progress=gr.Progress(track_tqdm=True),
44
  ):
45
- try:
46
- start_time = time()
47
-
48
- if randomize_seed:
49
- seed = random.randint(0, MAX_SEED)
50
-
51
- generator = torch.Generator(device=device).manual_seed(seed)
52
-
53
- # Generate image
54
- image = pipe(
55
- prompt=prompt,
56
- negative_prompt=negative_prompt,
57
- guidance_scale=guidance_scale,
58
- num_inference_steps=num_inference_steps,
59
- width=width,
60
- height=height,
61
- generator=generator,
62
- ).images[0]
63
-
64
- gen_time = time() - start_time
65
-
66
- return image, seed, f"Generated in {gen_time:.2f}s"
67
- except Exception as e:
68
- print(f"Error during inference: {e}")
69
- return None, seed, f"Error: {str(e)}"
70
 
71
  examples = [
72
- ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", 1024, 1024],
73
- ["An astronaut riding a green horse", 768, 768],
74
- ["A delicious ceviche cheesecake slice", 896, 896],
75
  ]
76
 
77
  css = """
@@ -83,14 +66,6 @@ css = """
83
  --shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
84
  }
85
 
86
- .dark {
87
- --primary: #a5a5fc;
88
- --secondary: #2d3748;
89
- --accent: #4a5568;
90
- --text: #f7fafc;
91
- --shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
92
- }
93
-
94
  #col-container {
95
  margin: 0 auto;
96
  max-width: 800px;
@@ -109,35 +84,32 @@ css = """
109
  margin-bottom: 10px;
110
  }
111
 
112
- .header p {
113
- color: var(--text);
114
- opacity: 0.8;
115
- }
116
-
117
- .prompt-container, .result-container, .advanced-settings {
118
- background: var(--secondary);
119
  border-radius: 12px;
120
  padding: 20px;
121
  box-shadow: var(--shadow);
122
  margin-bottom: 20px;
123
  }
124
 
125
- .advanced-settings .form {
126
- display: grid;
127
- grid-template-columns: 1fr 1fr;
128
- gap: 16px;
 
 
129
  }
130
 
131
- .control-row {
132
- display: flex;
133
- gap: 10px;
134
- align-items: center;
 
135
  }
136
 
137
  .btn-primary {
138
  background: var(--primary) !important;
139
  border: none !important;
140
- color: white !important;
141
  }
142
 
143
  .btn-primary:hover {
@@ -145,76 +117,17 @@ css = """
145
  }
146
 
147
  .examples {
148
- display: grid;
149
- grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
150
- gap: 12px;
151
  margin-top: 20px;
152
  }
153
-
154
- .example-prompt {
155
- background: var(--secondary);
156
- padding: 12px;
157
- border-radius: 8px;
158
- cursor: pointer;
159
- transition: all 0.2s;
160
- }
161
-
162
- .example-prompt:hover {
163
- transform: translateY(-2px);
164
- box-shadow: var(--shadow);
165
- }
166
-
167
- .theme-toggle {
168
- position: absolute;
169
- top: 20px;
170
- right: 20px;
171
- background: var(--secondary);
172
- border: none;
173
- border-radius: 50%;
174
- width: 40px;
175
- height: 40px;
176
- display: flex;
177
- align-items: center;
178
- justify-content: center;
179
- cursor: pointer;
180
- }
181
-
182
- @media (max-width: 768px) {
183
- .advanced-settings .form {
184
- grid-template-columns: 1fr;
185
- }
186
- }
187
  """
188
 
189
- js = """
190
- function toggleTheme() {
191
- const body = document.body;
192
- body.classList.toggle('dark');
193
- localStorage.setItem('gradio-theme', body.classList.contains('dark') ? 'dark' : 'light');
194
- }
195
-
196
- document.addEventListener('DOMContentLoaded', () => {
197
- const savedTheme = localStorage.getItem('gradio-theme') || 'light';
198
- if (savedTheme === 'dark') {
199
- document.body.classList.add('dark');
200
- }
201
-
202
- const themeToggle = document.createElement('button');
203
- themeToggle.className = 'theme-toggle';
204
- themeToggle.innerHTML = savedTheme === 'dark' ? '☀️' : '🌙';
205
- themeToggle.onclick = toggleTheme;
206
- document.body.appendChild(themeToggle);
207
- });
208
- """
209
-
210
- with gr.Blocks(css=css, js=js, theme=gr.themes.Soft()) as demo:
211
  with gr.Column(elem_id="col-container"):
212
  with gr.Column(visible=True) as header:
213
  gr.Markdown(
214
  """
215
  <div class="header">
216
- <h1>✨ AI Image Generator</h1>
217
- <p>Transform your text into stunning images with SDXL Turbo</p>
218
  </div>
219
  """,
220
  elem_classes="header"
@@ -222,118 +135,74 @@ with gr.Blocks(css=css, js=js, theme=gr.themes.Soft()) as demo:
222
 
223
  with gr.Column(elem_classes="prompt-container"):
224
  with gr.Row():
225
- prompt = gr.Textbox(
226
- label="",
227
  show_label=False,
228
- max_lines=2,
229
- placeholder="Describe the image you want to generate...",
230
  container=False,
231
- scale=5
232
- )
233
- run_button = gr.Button(
234
- "Generate",
235
- scale=1,
236
- variant="primary",
237
- elem_classes="btn-primary"
238
  )
 
239
 
240
  with gr.Column(elem_classes="result-container"):
241
- result = gr.Image(
242
- label="Generated Image",
243
- show_label=False,
244
- height=500
 
 
 
 
 
 
 
 
 
 
 
 
245
  )
 
 
 
246
  with gr.Row():
247
- seed_info = gr.Textbox(
248
- label="Seed",
249
- interactive=False
 
 
 
250
  )
251
- time_info = gr.Textbox(
252
- label="Generation Time",
253
- interactive=False
 
 
 
 
254
  )
255
 
256
- with gr.Accordion("🛠️ Advanced Settings", open=False, elem_classes="advanced-settings"):
257
- with gr.Column(elem_classes="form"):
258
- with gr.Row():
259
- negative_prompt = gr.Textbox(
260
- label="Negative Prompt",
261
- max_lines=1,
262
- placeholder="What you don't want to see in the image",
263
- )
264
-
265
- with gr.Row():
266
- with gr.Column():
267
- seed = gr.Slider(
268
- label="Seed",
269
- minimum=0,
270
- maximum=MAX_SEED,
271
- step=1,
272
- value=0,
273
- )
274
- randomize_seed = gr.Checkbox(
275
- label="Randomize seed",
276
- value=True,
277
- )
278
-
279
- with gr.Column():
280
- width = gr.Slider(
281
- label="Width",
282
- minimum=256,
283
- maximum=MAX_IMAGE_SIZE,
284
- step=32,
285
- value=1024,
286
- )
287
- height = gr.Slider(
288
- label="Height",
289
- minimum=256,
290
- maximum=MAX_IMAGE_SIZE,
291
- step=32,
292
- value=1024,
293
- )
294
-
295
- with gr.Row():
296
- guidance_scale = gr.Slider(
297
- label="Guidance Scale",
298
- minimum=0.0,
299
- maximum=10.0,
300
- step=0.1,
301
- value=0.0,
302
- )
303
- num_inference_steps = gr.Slider(
304
- label="Inference Steps",
305
- minimum=1,
306
- maximum=50,
307
- step=1,
308
- value=2,
309
- )
310
-
311
- gr.Markdown("### Example Prompts")
312
- with gr.Row(elem_classes="examples"):
313
- for example in examples:
314
- with gr.Column(min_width=200):
315
- gr.Examples(
316
- examples=[[example[0], example[1], example[2]]],
317
- inputs=[prompt, width, height],
318
- label="",
319
- examples_per_page=20
320
- )
321
-
322
- run_button.click(
323
- fn=infer,
324
- inputs=[
325
- prompt,
326
- negative_prompt,
327
- seed,
328
- randomize_seed,
329
- width,
330
- height,
331
- guidance_scale,
332
- num_inference_steps,
333
- ],
334
- outputs=[result, seed_info, time_info],
335
- )
336
- prompt.submit(
337
  fn=infer,
338
  inputs=[
339
  prompt,
@@ -345,8 +214,8 @@ with gr.Blocks(css=css, js=js, theme=gr.themes.Soft()) as demo:
345
  guidance_scale,
346
  num_inference_steps,
347
  ],
348
- outputs=[result, seed_info, time_info],
349
  )
350
 
351
  if __name__ == "__main__":
352
- demo.queue(api_open=False).launch()
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
+
5
+ # import spaces #[uncomment to use ZeroGPU]
6
  from diffusers import DiffusionPipeline
7
  import torch
 
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
+
12
+ if torch.cuda.is_available():
13
+ torch_dtype = torch.float16
14
+ else:
15
+ torch_dtype = torch.float32
16
+
17
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
+ pipe = pipe.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
22
 
23
+
24
+ # @spaces.GPU #[uncomment to use ZeroGPU]
25
  def infer(
26
  prompt,
27
  negative_prompt,
 
33
  num_inference_steps,
34
  progress=gr.Progress(track_tqdm=True),
35
  ):
36
+ if randomize_seed:
37
+ seed = random.randint(0, MAX_SEED)
38
+
39
+ generator = torch.Generator().manual_seed(seed)
40
+
41
+ image = pipe(
42
+ prompt=prompt,
43
+ negative_prompt=negative_prompt,
44
+ guidance_scale=guidance_scale,
45
+ num_inference_steps=num_inference_steps,
46
+ width=width,
47
+ height=height,
48
+ generator=generator,
49
+ ).images[0]
50
+
51
+ return image, seed
52
+
 
 
 
 
 
 
 
 
53
 
54
  examples = [
55
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
+ "An astronaut riding a green horse",
57
+ "A delicious ceviche cheesecake slice",
58
  ]
59
 
60
  css = """
 
66
  --shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
67
  }
68
 
 
 
 
 
 
 
 
 
69
  #col-container {
70
  margin: 0 auto;
71
  max-width: 800px;
 
84
  margin-bottom: 10px;
85
  }
86
 
87
+ .prompt-container {
88
+ background: white;
 
 
 
 
 
89
  border-radius: 12px;
90
  padding: 20px;
91
  box-shadow: var(--shadow);
92
  margin-bottom: 20px;
93
  }
94
 
95
+ .result-container {
96
+ background: white;
97
+ border-radius: 12px;
98
+ padding: 20px;
99
+ box-shadow: var(--shadow);
100
+ margin-bottom: 20px;
101
  }
102
 
103
+ .advanced-settings {
104
+ background: white;
105
+ border-radius: 12px;
106
+ padding: 20px;
107
+ box-shadow: var(--shadow);
108
  }
109
 
110
  .btn-primary {
111
  background: var(--primary) !important;
112
  border: none !important;
 
113
  }
114
 
115
  .btn-primary:hover {
 
117
  }
118
 
119
  .examples {
 
 
 
120
  margin-top: 20px;
121
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  """
123
 
124
+ with gr.Blocks(css=css) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  with gr.Column(elem_id="col-container"):
126
  with gr.Column(visible=True) as header:
127
  gr.Markdown(
128
  """
129
  <div class="header">
130
+ <h1>Text-to-Image Generator</h1>
 
131
  </div>
132
  """,
133
  elem_classes="header"
 
135
 
136
  with gr.Column(elem_classes="prompt-container"):
137
  with gr.Row():
138
+ prompt = gr.Text(
139
+ label="Prompt",
140
  show_label=False,
141
+ max_lines=1,
142
+ placeholder="Enter your prompt",
143
  container=False,
 
 
 
 
 
 
 
144
  )
145
+ run_button = gr.Button("Run", scale=0, variant="primary", elem_classes="btn-primary")
146
 
147
  with gr.Column(elem_classes="result-container"):
148
+ result = gr.Image(label="Result", show_label=False)
149
+
150
+ with gr.Accordion("Advanced Settings", open=False, elem_classes="advanced-settings"):
151
+ negative_prompt = gr.Text(
152
+ label="Negative prompt",
153
+ max_lines=1,
154
+ placeholder="Enter a negative prompt",
155
+ visible=False,
156
+ )
157
+
158
+ seed = gr.Slider(
159
+ label="Seed",
160
+ minimum=0,
161
+ maximum=MAX_SEED,
162
+ step=1,
163
+ value=0,
164
  )
165
+
166
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
167
+
168
  with gr.Row():
169
+ width = gr.Slider(
170
+ label="Width",
171
+ minimum=256,
172
+ maximum=MAX_IMAGE_SIZE,
173
+ step=32,
174
+ value=1024,
175
  )
176
+
177
+ height = gr.Slider(
178
+ label="Height",
179
+ minimum=256,
180
+ maximum=MAX_IMAGE_SIZE,
181
+ step=32,
182
+ value=1024,
183
  )
184
 
185
+ with gr.Row():
186
+ guidance_scale = gr.Slider(
187
+ label="Guidance scale",
188
+ minimum=0.0,
189
+ maximum=10.0,
190
+ step=0.1,
191
+ value=0.0,
192
+ )
193
+
194
+ num_inference_steps = gr.Slider(
195
+ label="Number of inference steps",
196
+ minimum=1,
197
+ maximum=50,
198
+ step=1,
199
+ value=2,
200
+ )
201
+
202
+ gr.Examples(examples=examples, inputs=[prompt], elem_classes="examples")
203
+
204
+ gr.on(
205
+ triggers=[run_button.click, prompt.submit],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  fn=infer,
207
  inputs=[
208
  prompt,
 
214
  guidance_scale,
215
  num_inference_steps,
216
  ],
217
+ outputs=[result, seed],
218
  )
219
 
220
  if __name__ == "__main__":
221
+ demo.launch()