Gemini899 committed
Commit bd1ec8b · verified · 1 Parent(s): 1c6e101

Update app.py

Files changed (1):
  app.py +24 -40
app.py CHANGED
@@ -14,36 +14,32 @@ import numpy as np
 import torch
 from diffusers import FluxImg2ImgPipeline
 
-# Global pipe variable for lazy loading
-pipe = None
-
 # Use float16 instead of bfloat16 for T4 compatibility
 dtype = torch.float16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-def get_pipe():
-    global pipe
-    if pipe is None:
-        # Set more aggressive memory optimization
-        torch.cuda.empty_cache()
-        gc.collect() # Force garbage collection
-
-        pipe = FluxImg2ImgPipeline.from_pretrained(
-            "black-forest-labs/FLUX.1-schnell",
-            torch_dtype=torch.float16,
-            low_cpu_mem_usage=True,
-            use_safetensors=True
-        )
-
-        # Enable attention slicing to reduce memory footprint
-        pipe.enable_attention_slicing(1)
-
-        # Move to device more carefully
-        if torch.cuda.is_available():
-            pipe = pipe.to("cuda:0")
-        else:
-            pipe = pipe.to("cpu")
-    return pipe
+# Initialize the pipe directly during startup
+print("Loading model during startup...")
+torch.cuda.empty_cache()
+gc.collect() # Force garbage collection
+
+pipe = FluxImg2ImgPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell",
+    torch_dtype=torch.float16,
+    low_cpu_mem_usage=True,
+    use_safetensors=True
+)
+
+# Enable attention slicing to reduce memory footprint
+pipe.enable_attention_slicing(1)
+
+# Move to device immediately
+if torch.cuda.is_available():
+    pipe = pipe.to("cuda:0")
+else:
+    pipe = pipe.to("cpu")
+
+print("Model loaded successfully")
 
 def sanitize_prompt(prompt):
     # Allow only alphanumeric characters, spaces, and basic punctuation
@@ -80,12 +76,6 @@ def resize_image(image: Image.Image, max_dim: int = 384) -> Image.Image:
     image = image.resize((new_w, new_h), Image.LANCZOS)
     return image
 
-# Try to preload the model at startup
-def preload_model():
-    # Skip preloading to avoid memory issues at startup
-    print("Skipping preload - will load model on first request")
-    pass
-
 # Increase the timeout to 4 minutes
 @spaces.GPU(duration=740)
 def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=2, progress=gr.Progress(track_tqdm=True)):
@@ -94,10 +84,7 @@ def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step
     torch.cuda.empty_cache()
     gc.collect()
 
-    progress(5, desc="Loading model")
-    # Get the model using lazy loading
-    model = get_pipe()
-    progress(15, desc="Model loaded")
+    progress(15, desc="Processing")
 
     def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
         if image is None:
@@ -123,7 +110,7 @@ def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step
         # Use autocast for better memory efficiency
         with torch.cuda.amp.autocast(dtype=torch.float16):
            with torch.no_grad():
-                output = model(
+                output = pipe(
                    prompt=prompt,
                    image=image,
                    generator=generator,
@@ -181,9 +168,6 @@ css="""
 }
 """
 
-# Try to preload the model
-preload_model()
-
 with gr.Blocks(css=css, elem_id="demo-container") as demo:
     with gr.Column():
         gr.HTML(read_file("demo_header.html"))
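Net effect of the commit: the lazy get_pipe() loader and the no-op preload_model() stub are removed, and the FluxImg2ImgPipeline is now constructed once at import time, so no request pays the model-load cost inside the @spaces.GPU(duration=740) window. A minimal sketch of the new startup flow, assuming only what the diff itself shows (model id, dtype, and memory settings); it is illustrative, not a drop-in replacement for app.py:

    import gc
    import torch
    from diffusers import FluxImg2ImgPipeline

    # Release any cached GPU memory before the large allocation.
    torch.cuda.empty_cache()
    gc.collect()

    # Build the pipeline eagerly at import time (float16 for T4 compatibility).
    pipe = FluxImg2ImgPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    pipe.enable_attention_slicing(1)  # smaller attention memory footprint, some speed cost
    pipe = pipe.to("cuda:0" if torch.cuda.is_available() else "cpu")

The tradeoff is a longer Space cold start in exchange for predictable per-request latency: previously the first caller absorbed the full load time, which is what the removed progress(5, desc="Loading model") step was tracking.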
 
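On the inference side only the callee changes: output = model(...) becomes output = pipe(...), referring to the module-level pipeline. A hedged sketch of that call path, with strength and num_inference_steps taken from the process_img2img signature in the diff; the helper name run_img2img and the output.images[0] indexing (the usual diffusers output convention) are assumptions:

    import torch
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"

    def run_img2img(pipe, image: Image.Image, prompt="a person",
                    strength=0.75, seed=0, num_inference_steps=4) -> Image.Image:
        # Seeded generator so identical inputs reproduce identical outputs.
        generator = torch.Generator(device).manual_seed(seed)
        # Autocast to float16 plus no_grad keeps activation memory low,
        # mirroring the guards in process_img2img.
        with torch.cuda.amp.autocast(dtype=torch.float16):
            with torch.no_grad():
                output = pipe(
                    prompt=prompt,
                    image=image,
                    generator=generator,
                    strength=strength,
                    num_inference_steps=num_inference_steps,
                )
        return output.images[0]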