RageshAntony commited on
Commit
e113c61
·
verified ·
1 Parent(s): d11b704

added multi gpu config

Browse files
Files changed (1) hide show
  1. app.py +31 -1
app.py CHANGED
@@ -23,6 +23,36 @@ from PIL import Image
23
  from queue import Queue
24
  from concurrent.futures import ThreadPoolExecutor, as_completed
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # Constants
27
  MAX_SEED = np.iinfo(np.int32).max
28
  MAX_IMAGE_SIZE = 1024
@@ -142,7 +172,7 @@ def generate_image_on_gpu(args):
142
  release_gpu(gpu_id)
143
  raise e
144
 
145
- @spaces.GPU(duration=400)
146
  def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress()):
147
  outputs = [None] * (len(MODEL_CONFIGS) * 2)
148
 
 
23
  from queue import Queue
24
  from concurrent.futures import ThreadPoolExecutor, as_completed
25
 
26
+ from dataclasses import dataclass
27
+ from typing import Optional, List
28
+
29
+ @dataclass
30
+ class MultiGPUConfig:
31
+ count: int = 2 # Number of GPUs to request
32
+ memory: int = 16 # Memory per GPU in GB
33
+ duration: int = 3600 # Duration in seconds
34
+
35
+ class SpacesMultiGPU:
36
+ def __init__(self, config: Optional[MultiGPUConfig] = None):
37
+ self.config = config or MultiGPUConfig()
38
+
39
+ def __call__(self, func):
40
+ # Apply multiple GPU decorators
41
+ decorated_func = func
42
+ for gpu_idx in range(self.config.count):
43
+ decorated_func = spaces.GPU(
44
+ device=gpu_idx, # Specify which GPU to request
45
+ memory=self.config.memory,
46
+ duration=self.config.duration
47
+ )(decorated_func)
48
+ return decorated_func
49
+
50
+ # Example usage in your generation code
51
+ gpu_config = MultiGPUConfig(
52
+ count=2, # Request 2 GPUs
53
+ duration=400 # 1 hour duration
54
+ )
55
+
56
  # Constants
57
  MAX_SEED = np.iinfo(np.int32).max
58
  MAX_IMAGE_SIZE = 1024
 
172
  release_gpu(gpu_id)
173
  raise e
174
 
175
+ @SpacesMultiGPU(gpu_config)
176
  def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress()):
177
  outputs = [None] * (len(MODEL_CONFIGS) * 2)
178