RageshAntony committed
Commit a8c6b1a · verified · 1 Parent(s): 02919e4

added deep cleanup

Files changed (1):
  1. app.py +133 -28

app.py CHANGED
@@ -8,6 +8,13 @@ from diffusers import (
     LuminaText2ImgPipeline
 )
 import spaces
+import gc
+import os
+import psutil
+import threading
+from pathlib import Path
+import shutil
+import time
 
 # Constants
 MAX_SEED = np.iinfo(np.int32).max
@@ -47,18 +54,85 @@ MODEL_CONFIGS = {
     }
 }
 
-# Initialize model pipelines
+# Dictionary to store model pipelines
 pipes = {}
+model_locks = {model_name: threading.Lock() for model_name in MODEL_CONFIGS.keys()}
+
+def get_process_memory():
+    """Get memory usage of current process in GB"""
+    process = psutil.Process(os.getpid())
+    return process.memory_info().rss / 1024 / 1024 / 1024
+
+def clear_torch_cache():
+    """Clear PyTorch's CUDA cache"""
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+
+def remove_cache_dir(model_name):
+    """Remove the model's cache directory"""
+    cache_dir = Path.home() / '.cache' / 'huggingface' / 'diffusers' / MODEL_CONFIGS[model_name]['repo_id'].replace('/', '--')
+    if cache_dir.exists():
+        shutil.rmtree(cache_dir, ignore_errors=True)
+
+def deep_cleanup(model_name, pipe):
+    """Perform deep cleanup of model resources"""
+    try:
+        # 1. Move model to CPU first (helps prevent CUDA memory fragmentation)
+        if hasattr(pipe, 'to'):
+            pipe.to('cpu')
+
+        # 2. Delete all model components explicitly
+        for attr_name in list(pipe.__dict__.keys()):
+            if hasattr(pipe, attr_name):
+                delattr(pipe, attr_name)
+
+        # 3. Remove from pipes dictionary
+        if model_name in pipes:
+            del pipes[model_name]
+
+        # 4. Clear CUDA cache
+        clear_torch_cache()
+
+        # 5. Run garbage collection multiple times
+        for _ in range(3):
+            gc.collect()
+
+        # 6. Remove cached files
+        remove_cache_dir(model_name)
+
+        # 7. Additional CUDA cleanup if available
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+
+        # 8. Wait a small amount of time to ensure cleanup
+        time.sleep(1)
+
+    except Exception as e:
+        print(f"Error during cleanup of {model_name}: {str(e)}")
+
+    finally:
+        # Final garbage collection
+        gc.collect()
+        clear_torch_cache()
 
 def load_pipeline(model_name):
+    """Load model pipeline with memory tracking"""
+    initial_memory = get_process_memory()
     config = MODEL_CONFIGS[model_name]
+
     pipe = config["pipeline_class"].from_pretrained(
         config["repo_id"],
         torch_dtype=TORCH_DTYPE
     )
     pipe = pipe.to(DEVICE)
+
     if hasattr(pipe, 'enable_model_cpu_offload'):
         pipe.enable_model_cpu_offload()
+
+    final_memory = get_process_memory()
+    print(f"Memory used by {model_name}: {final_memory - initial_memory:.2f} GB")
+
     return pipe
 
 @spaces.GPU(duration=180)
@@ -74,33 +148,48 @@ def generate_image(
     num_inference_steps=40,
     progress=gr.Progress(track_tqdm=True)
 ):
-    progress(0, desc=f"Loading {model_name} model...")
-
-    # Load model if not already loaded
-    if model_name not in pipes:
-        pipes[model_name] = load_pipeline(model_name)
-
-    pipe = pipes[model_name]
-
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator(DEVICE).manual_seed(seed)
-
-    progress(0.3, desc=f"Generating image with {model_name}...")
-
-    image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator,
-    ).images[0]
-
-    progress(1.0, desc=f"Generation complete with {model_name}")
-    return image, seed
+    with model_locks[model_name]:
+        try:
+            progress(0, desc=f"Loading {model_name} model...")
+
+            # Load model if not already loaded
+            if model_name not in pipes:
+                pipes[model_name] = load_pipeline(model_name)
+
+            pipe = pipes[model_name]
+
+            if randomize_seed:
+                seed = random.randint(0, MAX_SEED)
+
+            generator = torch.Generator(DEVICE).manual_seed(seed)
+
+            progress(0.3, desc=f"Generating image with {model_name}...")
+
+            # Generate image
+            image = pipe(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_inference_steps,
+                width=width,
+                height=height,
+                generator=generator,
+            ).images[0]
+
+            progress(0.9, desc=f"Cleaning up {model_name} resources...")
+
+            # Cleanup after generation
+            deep_cleanup(model_name, pipe)
+
+            progress(1.0, desc=f"Generation complete with {model_name}")
+            return image, seed
+
+        except Exception as e:
+            print(f"Error with {model_name}: {str(e)}")
+            # Ensure cleanup happens even if generation fails
+            if model_name in pipes:
+                deep_cleanup(model_name, pipes[model_name])
+            raise e
 
 # Gradio Interface
 css = """
@@ -173,6 +262,9 @@ with gr.Blocks(css=css) as demo:
             value=40,
         )
 
+    # Memory usage indicator
+    memory_indicator = gr.Markdown("Current memory usage: 0 GB")
+
     # Create tabs for each model
     with gr.Tabs() as tabs:
         results = {}
@@ -188,6 +280,14 @@ with gr.Blocks(css=css) as demo:
     ]
     gr.Examples(examples=examples, inputs=[prompt])
 
+    def update_memory_usage():
+        """Update memory usage display"""
+        memory_gb = get_process_memory()
+        if torch.cuda.is_available():
+            cuda_memory_gb = torch.cuda.memory_allocated() / 1024 / 1024 / 1024
+            return f"Current memory usage: System RAM: {memory_gb:.2f} GB, CUDA: {cuda_memory_gb:.2f} GB"
+        return f"Current memory usage: System RAM: {memory_gb:.2f} GB"
+
     # Handle generation for each model
     def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress()):
         outputs = []
@@ -199,9 +299,14 @@ with gr.Blocks(css=css) as demo:
                     num_inference_steps, progress
                 )
                 outputs.extend([image, used_seed])
+
+                # Update memory usage after each model
+                memory_indicator.update(update_memory_usage())
+
             except Exception as e:
                 outputs.extend([None, None])
                 print(f"Error generating with {model_name}: {str(e)}")
+
         return outputs
 
     # Set up the generation trigger
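For reference (not part of this commit): a minimal standalone sketch of the release pattern that deep_cleanup applies, checked by comparing allocated CUDA memory before and after dropping a pipeline. It assumes a CUDA device with torch and diffusers installed; the model id below is only an illustrative example and is not one of the models configured in app.py.

# Illustrative check of the cleanup pattern; the model id is an arbitrary example.
import gc
import torch
from diffusers import DiffusionPipeline

def cuda_allocated_gb():
    """Currently allocated CUDA memory in GB."""
    return torch.cuda.memory_allocated() / 1024 ** 3

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")
print(f"after load:    {cuda_allocated_gb():.2f} GB")

# Same release steps as deep_cleanup: move to CPU, drop references, collect, clear cache.
pipe.to("cpu")
del pipe
gc.collect()
torch.cuda.empty_cache()
print(f"after cleanup: {cuda_allocated_gb():.2f} GB")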
 