RageshAntony committed (verified)
Commit fe79ce9 · 1 Parent(s): a69d563
Files changed (1)
  1. app.py +102 -159
app.py CHANGED
@@ -2,6 +2,8 @@ import gradio as gr
import numpy as np
import random
import torch
+ import torch.multiprocessing as mp
+ from torch.cuda.amp import autocast
from diffusers import (
    DiffusionPipeline, StableDiffusion3Pipeline, FluxPipeline, PixArtSigmaPipeline,
    AuraFlowPipeline, Kandinsky3Pipeline, HunyuanDiTPipeline,
@@ -18,15 +20,20 @@ import time
import glob
from datetime import datetime
from PIL import Image
+ from queue import Queue
+ from concurrent.futures import ThreadPoolExecutor, as_completed

# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+ TORCH_DTYPE = torch.bfloat16
OUTPUT_DIR = "generated_images"
os.makedirs(OUTPUT_DIR, exist_ok=True)

+ # Get available GPU devices
+ AVAILABLE_GPUS = list(range(torch.cuda.device_count()))
+ print(f"Available GPUs: {AVAILABLE_GPUS}")
+
# Model configurations
MODEL_CONFIGS = {
    "FLUX": {
@@ -37,94 +44,44 @@ MODEL_CONFIGS = {
        "repo_id": "stabilityai/stable-diffusion-3.5-large",
        "pipeline_class": StableDiffusion3Pipeline
    }
-
}

- # Dictionary to store model pipelines
- pipes = {}
- model_locks = {model_name: threading.Lock() for model_name in MODEL_CONFIGS.keys()}
-
- def get_process_memory():
-     """Get memory usage of current process in GB"""
-     process = psutil.Process(os.getpid())
-     return process.memory_info().rss / 1024 / 1024 / 1024
-
- def clear_torch_cache():
-     """Clear PyTorch's CUDA cache"""
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()
-         torch.cuda.ipc_collect()
-
- def remove_cache_dir(model_name):
-     """Remove the model's cache directory"""
-     cache_dir = Path.home() / '.cache' / 'huggingface' / 'diffusers' / MODEL_CONFIGS[model_name]['repo_id'].replace('/', '--')
-     if cache_dir.exists():
-         shutil.rmtree(cache_dir, ignore_errors=True)
-
- def deep_cleanup(model_name, pipe):
-     """Perform deep cleanup of model resources"""
-     try:
-         # 1. Move model to CPU first (helps prevent CUDA memory fragmentation)
-         if hasattr(pipe, 'to'):
-             pipe.to('cpu')
-
-         # 2. Delete all model components explicitly
-         for attr_name in list(pipe.__dict__.keys()):
-             if hasattr(pipe, attr_name):
-                 delattr(pipe, attr_name)
-
-         # 3. Remove from pipes dictionary
-         if model_name in pipes:
-             del pipes[model_name]
-
-         # 4. Clear CUDA cache
-         clear_torch_cache()
-
-         # 5. Run garbage collection multiple times
-         for _ in range(3):
-             gc.collect()
+ # GPU allocation queue and model cache
+ gpu_queue = Queue()
+ for gpu_id in AVAILABLE_GPUS:
+     gpu_queue.put(gpu_id)

-         # 6. Remove cached files
-         remove_cache_dir(model_name)
-
-         # 7. Additional CUDA cleanup if available
-         if torch.cuda.is_available():
-             torch.cuda.synchronize()
+ model_cache = {}
+ model_locks = {model_name: threading.Lock() for model_name in MODEL_CONFIGS.keys()}

-         # 8. Wait a small amount of time to ensure cleanup
-         time.sleep(1)
+ def get_next_available_gpu():
+     """Get the next available GPU from the queue"""
+     gpu_id = gpu_queue.get()
+     return gpu_id

-     except Exception as e:
-         print(f"Error during cleanup of {model_name}: {str(e)}")
-
-     finally:
-         # Final garbage collection
-         gc.collect()
-         clear_torch_cache()
+ def release_gpu(gpu_id):
+     """Release GPU back to the queue"""
+     gpu_queue.put(gpu_id)

- def load_pipeline(model_name):
-     """Load model pipeline with memory tracking"""
-     initial_memory = get_process_memory()
+ def load_pipeline_on_gpu(model_name, gpu_id):
+     """Load model pipeline on specific GPU with memory tracking"""
    config = MODEL_CONFIGS[model_name]

-     pipe = config["pipeline_class"].from_pretrained(
-         config["repo_id"],
-         torch_dtype=TORCH_DTYPE
-     )
-     pipe = pipe.to(DEVICE)
-
-     if hasattr(pipe, 'enable_model_cpu_offload'):
-         pipe.enable_model_cpu_offload()
-
-     final_memory = get_process_memory()
-     print(f"Memory used by {model_name}: {final_memory - initial_memory:.2f} GB")
+     with torch.cuda.device(gpu_id):
+         pipe = config["pipeline_class"].from_pretrained(
+             config["repo_id"],
+             torch_dtype=TORCH_DTYPE
+         )
+         pipe = pipe.to(f"cuda:{gpu_id}")
+
+         if hasattr(pipe, 'enable_model_cpu_offload'):
+             pipe.enable_model_cpu_offload()

    return pipe

def save_generated_image(image, model_name, prompt):
    """Save generated image with timestamp and model name"""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-     # Create sanitized filename from prompt (first 30 chars)
    prompt_part = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).strip()
    filename = f"{timestamp}_{model_name}_{prompt_part}.png"
    filepath = os.path.join(OUTPUT_DIR, filename)
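The hunk above swaps the old per-model cleanup helpers for a `Queue`-based GPU pool: `get_next_available_gpu()` blocks until a device is free and `release_gpu()` returns it, so every caller must remember to release on each exit path. A hypothetical hardening (not part of this commit, assuming the `gpu_queue` and `load_pipeline_on_gpu` defined above are in scope) wraps the pair in a context manager so the device is always returned, even when generation raises:

```python
# Hypothetical sketch, not in the commit: a context manager over the same
# gpu_queue so a device is returned even if the body raises.
from contextlib import contextmanager

@contextmanager
def acquire_gpu():
    gpu_id = gpu_queue.get()  # blocks until some GPU is free
    try:
        yield gpu_id
    finally:
        gpu_queue.put(gpu_id)

# Possible usage inside a worker:
# with acquire_gpu() as gpu_id:
#     pipe = load_pipeline_on_gpu(model_name, gpu_id)
#     ...
```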
@@ -134,7 +91,7 @@ def save_generated_image(image, model_name, prompt):
def get_generated_images():
    """Get list of generated images with their details"""
    files = glob.glob(os.path.join(OUTPUT_DIR, "*.png"))
-     files.sort(key=os.path.getctime, reverse=True) # Sort by creation time
+     files.sort(key=os.path.getctime, reverse=True)
    return [
        {
            "path": f,
@@ -145,34 +102,25 @@
        for f in files
    ]

- def generate_image(
-     model_name,
-     prompt,
-     negative_prompt="",
-     seed=42,
-     randomize_seed=False,
-     width=1024,
-     height=1024,
-     guidance_scale=4.5,
-     num_inference_steps=40,
-     progress=gr.Progress(track_tqdm=True)
- ):
-     with model_locks[model_name]:
-         try:
-             #progress(0, desc=f"Loading {model_name} model...")
-
-             if model_name not in pipes:
-                 pipes[model_name] = load_pipeline(model_name)
-
-             pipe = pipes[model_name]
-
-             if randomize_seed:
-                 seed = random.randint(0, MAX_SEED)
-
-             generator = torch.Generator(DEVICE).manual_seed(seed)
-             print(f"Generating image with {model_name}...")
-             #progress(0.3, desc=f"Generating image with {model_name}...")
-
+ def generate_image_on_gpu(args):
+     """Generate image on specific GPU"""
+     model_name, prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps = args
+
+     try:
+         gpu_id = get_next_available_gpu()
+         print(f"Generating {model_name} on GPU {gpu_id}")
+
+         # Load or get cached pipeline
+         cache_key = f"{model_name}_{gpu_id}"
+         if cache_key not in model_cache:
+             with model_locks[model_name]:
+                 model_cache[cache_key] = load_pipeline_on_gpu(model_name, gpu_id)
+
+         pipe = model_cache[cache_key]
+
+         # Generate image
+         with torch.cuda.device(gpu_id), autocast():
+             generator = torch.Generator(f"cuda:{gpu_id}").manual_seed(seed)
            image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
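A note on the cache logic added in the hunk above: the `cache_key` membership test happens before `model_locks[model_name]` is taken, so two threads that miss at the same moment can both load the same pipeline. A hypothetical variant (not part of this commit, reusing the commit's `model_cache`, `model_locks`, and `load_pipeline_on_gpu`) applies the usual double-checked pattern:

```python
# Hypothetical sketch, not in the commit: re-check the cache while holding
# the per-model lock so concurrent misses load the pipeline only once.
def get_or_load_pipeline(model_name, gpu_id):
    cache_key = f"{model_name}_{gpu_id}"
    pipe = model_cache.get(cache_key)
    if pipe is None:
        with model_locks[model_name]:
            pipe = model_cache.get(cache_key)  # second check under the lock
            if pipe is None:
                pipe = load_pipeline_on_gpu(model_name, gpu_id)
                model_cache[cache_key] = pipe
    return pipe
```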
@@ -182,21 +130,52 @@ def generate_image(
                height=height,
                generator=generator,
            ).images[0]
-
-             filepath = save_generated_image(image, model_name, prompt)
-             print(f"Saved image to: {filepath}")
-
-             #progress(0.9, desc=f"Cleaning up {model_name} resources...")
-             #deep_cleanup(model_name, pipe)
-
-             #progress(1.0, desc=f"Generation complete with {model_name}")
-             return image, seed
-
-         except Exception as e:
-             print(f"Error with {model_name}: {str(e)}")
-             if model_name in pipes:
-                 deep_cleanup(model_name, pipes[model_name])
-             raise e
+
+         filepath = save_generated_image(image, model_name, prompt)
+         print(f"Saved image from {model_name} to: {filepath}")
+
+         release_gpu(gpu_id)
+         return image, seed
+
+     except Exception as e:
+         print(f"Error with {model_name} on GPU {gpu_id}: {str(e)}")
+         release_gpu(gpu_id)
+         raise e
+
+ def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress()):
+     outputs = [None] * (len(MODEL_CONFIGS) * 2)
+
+     # Prepare generation tasks
+     tasks = []
+     for model_name in MODEL_CONFIGS.keys():
+         current_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
+         tasks.append((
+             model_name, prompt, negative_prompt, current_seed,
+             width, height, guidance_scale, num_inference_steps
+         ))
+
+     # Run generation in parallel using thread pool
+     with ThreadPoolExecutor(max_workers=len(AVAILABLE_GPUS)) as executor:
+         future_to_model = {
+             executor.submit(generate_image_on_gpu, task): idx
+             for idx, task in enumerate(tasks)
+         }
+
+         for future in as_completed(future_to_model):
+             idx = future_to_model[future]
+             try:
+                 image, used_seed = future.result()
+                 outputs[idx * 2] = image
+                 outputs[idx * 2 + 1] = used_seed
+                 yield outputs + [None]
+             except Exception as e:
+                 print(f"Generation failed for model {idx}: {str(e)}")
+                 outputs[idx * 2] = None
+                 outputs[idx * 2 + 1] = None
+
+     # Update gallery after all generations complete
+     gallery_images = update_gallery()
+     return outputs

# Gradio Interface
css = """
@@ -208,7 +187,7 @@

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
-         gr.Markdown("# Multi-Model Image Generation")
+         gr.Markdown(f"# Multi-GPU Image Generation ({len(AVAILABLE_GPUS)} GPUs Available)")

        with gr.Row():
            prompt = gr.Text(
@@ -269,8 +248,6 @@
                value=40,
            )

-         memory_indicator = gr.Markdown("Current memory usage: 0 GB")
-
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Tabs() as tabs:
@@ -291,7 +268,6 @@
                        height=400
                    )
                    refresh_button = gr.Button("Refresh Gallery")
-

    def update_gallery():
        """Update the file gallery"""
@@ -301,41 +277,6 @@
            for f in files
        ]

-     @spaces.GPU(duration=600)
-     def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress()):
-         outputs = [None] * (len(MODEL_CONFIGS) * 2)
-         for idx, model_name in enumerate(MODEL_CONFIGS.keys()):
-             try:
-                 # Display progress for the specific model
-                 #progress(0, desc=f"Starting generation for {model_name}...")
-                 print(f"IMAGE GENERATING {model_name} ")
-                 image, used_seed = generate_image(
-                     model_name, prompt, negative_prompt, seed,
-                     randomize_seed, width, height, guidance_scale,
-                     num_inference_steps, progress
-                 )
-                 print(f"IMAGE GENERATIED {model_name} ")
-                 # Update the respective model's tab with the generated image
-                 #results[model_name].update(image)
-                 #seeds[model_name].update(used_seed)
-                 outputs[idx * 2] = image # Image slot
-                 outputs[idx * 2 + 1] = seed # Seed slot
-                 #outputs.extend([image, used_seed])
-                 # Add intermediate results to progress * (len(all_outputs) - len(all_outputs))
-                 print("YELID")
-                 yield outputs + [None]
-
-
-             except Exception as e:
-                 print(f"Error generating with {model_name}: {str(e)}")
-                 outputs[idx * 2] = None
-                 outputs[idx * 2 + 1] = None
-
-         # Update the gallery after generation
-         gallery_images = update_gallery()
-         #file_gallery.update(value=gallery_images)
-         return outputs
-
    output_components = []
    for model_name in MODEL_CONFIGS.keys():
        output_components.extend([results[model_name], seeds[model_name]])
@@ -368,4 +309,6 @@
    )

if __name__ == "__main__":
+     # Initialize multiprocessing for PyTorch
+     mp.set_start_method('spawn', force=True)
    demo.launch()
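The final hunk only sets the multiprocessing start method; the parallelism added in this commit runs on threads, so `mp` is otherwise unused. For reference, a minimal, hypothetical example (not part of the commit) of the process-per-GPU pattern that `spawn` exists to support, since a forked child generally cannot re-initialize CUDA while a spawned one can:

```python
# Hypothetical sketch, not in the commit: one spawned worker process per GPU.
import torch
import torch.multiprocessing as mp

def worker(device_index):
    # Each spawned process creates its own CUDA context on its device.
    device = torch.device(f"cuda:{device_index}" if torch.cuda.is_available() else "cpu")
    x = torch.ones(2, 2, device=device)
    print(device, x.sum().item())

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    procs = [mp.Process(target=worker, args=(i,)) for i in range(max(1, torch.cuda.device_count()))]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```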