RageshAntony committed on
Commit 80e6066 · verified · 1 Parent(s): 0e6b892
Files changed (1): app.py (+75 −28)
app.py CHANGED
@@ -15,12 +15,17 @@ import threading
 from pathlib import Path
 import shutil
 import time
+import glob
+from datetime import datetime
+from PIL import Image
 
 # Constants
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+OUTPUT_DIR = "generated_images"
+os.makedirs(OUTPUT_DIR, exist_ok=True)
 
 # Model configurations
 MODEL_CONFIGS = {
@@ -135,7 +140,30 @@ def load_pipeline(model_name):
 
     return pipe
 
-#@spaces.GPU(duration=180)
+def save_generated_image(image, model_name, prompt):
+    """Save generated image with timestamp and model name"""
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    # Create sanitized filename from prompt (first 30 chars)
+    prompt_part = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).strip()
+    filename = f"{timestamp}_{model_name}_{prompt_part}.png"
+    filepath = os.path.join(OUTPUT_DIR, filename)
+    image.save(filepath)
+    return filepath
+
+def get_generated_images():
+    """Get list of generated images with their details"""
+    files = glob.glob(os.path.join(OUTPUT_DIR, "*.png"))
+    files.sort(key=os.path.getctime, reverse=True)  # Sort by creation time
+    return [
+        {
+            "path": f,
+            "name": os.path.basename(f),
+            "date": datetime.fromtimestamp(os.path.getctime(f)).strftime("%Y-%m-%d %H:%M:%S"),
+            "size": f"{os.path.getsize(f) / 1024:.1f} KB"
+        }
+        for f in files
+    ]
+
 def generate_image(
     model_name,
     prompt,
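Note: the sanitizer in save_generated_image keeps only alphanumerics, spaces, hyphens and underscores from the first 30 characters of the prompt. A quick standalone check of that exact expression, using the demo's own example prompt:

    prompt = "A capybara wearing a suit holding a sign that reads Hello World"
    prompt_part = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).strip()
    print(prompt_part)  # -> A capybara wearing a suit hold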
@@ -152,7 +180,6 @@ def generate_image(
     try:
         progress(0, desc=f"Loading {model_name} model...")
 
-        # Load model if not already loaded
         if model_name not in pipes:
            pipes[model_name] = load_pipeline(model_name)
 
@@ -165,7 +192,6 @@ def generate_image(
         print(f"Generating image with {model_name}...")
         progress(0.3, desc=f"Generating image with {model_name}...")
 
-        # Generate image
         image = pipe(
             prompt=prompt,
             negative_prompt=negative_prompt,
@@ -176,9 +202,10 @@ def generate_image(
             generator=generator,
         ).images[0]
 
-        progress(0.9, desc=f"Cleaning up {model_name} resources...")
+        filepath = save_generated_image(image, model_name, prompt)
+        print(f"Saved image to: {filepath}")
 
-        # Cleanup after generation
+        progress(0.9, desc=f"Cleaning up {model_name} resources...")
         deep_cleanup(model_name, pipe)
 
         progress(1.0, desc=f"Generation complete with {model_name}")
@@ -186,7 +213,6 @@ def generate_image(
 
     except Exception as e:
         print(f"Error with {model_name}: {str(e)}")
-        # Ensure cleanup happens even if generation fails
        if model_name in pipes:
             deep_cleanup(model_name, pipes[model_name])
         raise e
@@ -198,7 +224,6 @@ css = """
     max-width: 1024px;
 }
 """
-#run_test_safe.zerogpu = True
 
 with gr.Blocks(css=css) as demo:
     with gr.Column(elem_id="col-container"):
@@ -263,17 +288,28 @@ with gr.Blocks(css=css) as demo:
                 value=40,
             )
 
-        # Memory usage indicator
         memory_indicator = gr.Markdown("Current memory usage: 0 GB")
 
-        # Create tabs for each model
-        with gr.Tabs() as tabs:
-            results = {}
-            seeds = {}
-            for model_name in MODEL_CONFIGS.keys():
-                with gr.Tab(model_name):
-                    results[model_name] = gr.Image(label=f"{model_name} Result")
-                    seeds[model_name] = gr.Number(label="Seed used", visible=False)
+        with gr.Row():
+            with gr.Column(scale=2):
+                with gr.Tabs() as tabs:
+                    results = {}
+                    seeds = {}
+                    for model_name in MODEL_CONFIGS.keys():
+                        with gr.Tab(model_name):
+                            results[model_name] = gr.Image(label=f"{model_name} Result")
+                            seeds[model_name] = gr.Number(label="Seed used", visible=False)
+
+            with gr.Column(scale=1):
+                gr.Markdown("### Generated Images")
+                file_gallery = gr.Gallery(
+                    label="Generated Images",
+                    show_label=False,
+                    elem_id="file_gallery",
+                    columns=2,
+                    height=400
+                )
+                refresh_button = gr.Button("Refresh Gallery")
 
         examples = [
             "A capybara wearing a suit holding a sign that reads Hello World",
@@ -281,15 +317,14 @@ with gr.Blocks(css=css) as demo:
         ]
         gr.Examples(examples=examples, inputs=[prompt])
 
-        def update_memory_usage():
-            """Update memory usage display"""
-            memory_gb = get_process_memory()
-            if torch.cuda.is_available():
-                cuda_memory_gb = torch.cuda.memory_allocated() / 1024 / 1024 / 1024
-                return f"Current memory usage: System RAM: {memory_gb:.2f} GB, CUDA: {cuda_memory_gb:.2f} GB"
-            return f"Current memory usage: System RAM: {memory_gb:.2f} GB"
+        def update_gallery():
+            """Update the file gallery"""
+            files = get_generated_images()
+            return [
+                (f["path"], f"{f['name']}\n{f['date']}")
+                for f in files
+            ]
 
-        # Handle generation for each model
         @spaces.GPU(duration=600)
         def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress()):
             outputs = []
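Note: an unchanged context line in the next hunk still calls update_memory_usage(), whose inline definition this hunk deletes, so the helper presumably remains defined elsewhere in app.py. If it does not, a stand-in matching the removed code would be needed (a sketch, not part of this commit; get_process_memory is the helper defined earlier in app.py):

    def update_memory_usage():
        """Update memory usage display"""
        memory_gb = get_process_memory()  # assumes the earlier helper is still present
        if torch.cuda.is_available():
            cuda_memory_gb = torch.cuda.memory_allocated() / 1024 ** 3
            return f"Current memory usage: System RAM: {memory_gb:.2f} GB, CUDA: {cuda_memory_gb:.2f} GB"
        return f"Current memory usage: System RAM: {memory_gb:.2f} GB"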
@@ -304,16 +339,16 @@ with gr.Blocks(css=css) as demo:
                     print(f"IMAGE GENERATED {model_name} {update_memory_usage()}")
                     outputs.extend([image, used_seed])
 
-                    # Update memory usage after each model
-                    #memory_indicator.update(update_memory_usage())
-
                 except Exception as e:
                     outputs.extend([None, None])
                     print(f"Error generating with {model_name}: {str(e)}")
 
+            # Update the gallery after generation
+            gallery_images = update_gallery()
+            file_gallery.update(value=gallery_images)
+
             return outputs
 
-        # Set up the generation trigger
         output_components = []
         for model_name in MODEL_CONFIGS.keys():
             output_components.extend([results[model_name], seeds[model_name]])
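A caveat on the file_gallery.update(value=...) call added above: in Gradio, update is not an instance method that pushes state from inside a running handler (Gradio 3 expects update dicts to be returned from the handler, and Gradio 4 removed Component.update entirely), so this call likely leaves the browser-side gallery unchanged. A sketch of the event-chaining alternative; run_button and the input component names are assumptions here, since the trigger's definition sits outside this diff:

    # Refresh the gallery as a second step of the same click event,
    # instead of calling file_gallery.update(...) inside generate_all().
    run_button.click(            # assumed name of the generate button
        fn=generate_all,
        inputs=[prompt, negative_prompt, seed, randomize_seed,
                width, height, guidance_scale, num_inference_steps],
        outputs=output_components,
    ).then(
        fn=update_gallery,       # same helper the Refresh Gallery button uses
        inputs=[],
        outputs=[file_gallery],
    )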
@@ -333,5 +368,17 @@ with gr.Blocks(css=css) as demo:
             outputs=output_components,
         )
 
+        refresh_button.click(
+            fn=update_gallery,
+            inputs=[],
+            outputs=[file_gallery],
+        )
+
+        demo.load(
+            fn=update_gallery,
+            inputs=[],
+            outputs=[file_gallery],
+        )
+
 if __name__ == "__main__":
     demo.launch()
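The (path, caption) pairs returned by update_gallery are one of the value forms gr.Gallery accepts, which is why the same helper can serve both the Refresh Gallery button and the demo.load hook that fills the gallery on page load. Roughly the shape the gallery receives (hypothetical file, for illustration only):

    [("generated_images/20250101_120000_FLUX_A capybara wearing a suit hold.png",
      "20250101_120000_FLUX_A capybara wearing a suit hold.png\n2025-01-01 12:00:00")]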
 