Gemini899 committed
Commit 1f71182 · verified · Parent(s): f6b6080

Update app.py

Files changed (1): app.py (+276, −121)
app.py CHANGED
@@ -1,4 +1,3 @@
-# ---- Imports ----
 import logging
 import random
 import warnings
@@ -7,9 +6,9 @@ import io
 import base64
 import gradio as gr
 import numpy as np
-import spaces
+import spaces  # Ensure spaces is imported for the decorator
 import torch
-# --- Add this if you want to try Solution 2 later ---
+# --- Optional: Add this if trying Solution 2 (suppress_errors) later ---
 # import torch._dynamo
 from diffusers import FluxControlNetModel
 from diffusers.pipelines import FluxControlNetPipeline
@@ -18,7 +17,6 @@ from PIL import Image, ImageOps
 from huggingface_hub import snapshot_download
 
 # --- Setup Logging and Device ---
-# ... (rest of setup code remains the same) ...
 logging.basicConfig(level=logging.INFO)
 warnings.filterwarnings("ignore")
 
@@ -28,9 +26,13 @@ css = """
 """
 
 if torch.cuda.is_available():
-    power_device = "GPU"; device = "cuda"; torch_dtype = torch.bfloat16
+    power_device = "GPU"
+    device = "cuda"
+    torch_dtype = torch.bfloat16
 else:
-    power_device = "CPU"; device = "cpu"; torch_dtype = torch.float32
+    power_device = "CPU"
+    device = "cpu"
+    torch_dtype = torch.float32
 logging.info(f"Selected device: {device} | Data type: {torch_dtype}")
 
 # --- Authentication and Model Download ---
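The hunk above hard-codes `bfloat16` whenever CUDA is available. Not every CUDA device supports bf16, so a more defensive dtype pick is possible; the following standalone sketch (not part of this commit) uses PyTorch's stock `torch.cuda.is_bf16_supported()` helper:

```python
import torch

# Sketch: choose a dtype defensively instead of assuming bf16 support.
if torch.cuda.is_available():
    device = "cuda"
    # Fall back to float16 on GPUs without bfloat16 support.
    torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32

print(f"device={device}, dtype={torch_dtype}")
```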
@@ -41,25 +43,36 @@ local_model_dir = flux_model_id.split('/')[-1]
 pipe = None
 
 try:
-    # ... (model download code remains the same) ...
     logging.info(f"Downloading base model: {flux_model_id}")
-    model_path = snapshot_download(repo_id=flux_model_id, repo_type="model", ignore_patterns=["*.md", "*.gitattributes"], local_dir=local_model_dir, token=huggingface_token)
+    model_path = snapshot_download(
+        repo_id=flux_model_id,
+        repo_type="model",
+        ignore_patterns=["*.md", "*.gitattributes"],
+        local_dir=local_model_dir,
+        token=huggingface_token
+    )
     logging.info(f"Base model downloaded/verified in: {model_path}")
 
     logging.info(f"Loading ControlNet model: {controlnet_model_id}")
-    controlnet = FluxControlNetModel.from_pretrained(controlnet_model_id, torch_dtype=torch_dtype).to(device)
+    controlnet = FluxControlNetModel.from_pretrained(
+        controlnet_model_id, torch_dtype=torch_dtype
+    ).to(device)
     logging.info("ControlNet model loaded.")
 
     logging.info("Loading FluxControlNetPipeline...")
-    pipe = FluxControlNetPipeline.from_pretrained(model_path, controlnet=controlnet, torch_dtype=torch_dtype)
+    pipe = FluxControlNetPipeline.from_pretrained(
+        model_path,
+        controlnet=controlnet,
+        torch_dtype=torch_dtype
+    )
     pipe.to(device)
     logging.info("Pipeline loaded and moved to device.")
 
     # --- OPTIMIZATION: Attempt torch.compile (PyTorch 2.0+) ---
     if device == "cuda" and hasattr(torch, "compile"):
-        # --- TRY THIS FIRST: Change mode to 'default' ---
+        # --- Using "default" mode as first attempt to fix compile errors ---
        compile_mode = "default"
-        # --- Alternative (Solution 2): Uncomment these lines ---
+        # --- Alternative (Solution 2): Uncomment these lines if "default" fails ---
        # import torch._dynamo
        # torch._dynamo.config.suppress_errors = True
        # compile_mode = "max-autotune" # or "default" even with suppress_errors
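The compile block in the hunk above follows a try-and-fall-back pattern. Note that many Dynamo problems only surface on the first forward pass rather than at the `torch.compile(...)` call itself, which is why the commented-out "Solution 2" (`torch._dynamo.config.suppress_errors = True`) makes PyTorch fall back to eager execution at run time. A condensed standalone sketch of the same pattern (illustrative only, not part of this commit):

```python
import torch

def try_compile(module, mode: str = "default"):
    """Return a compiled module when torch.compile is usable, else the original."""
    if not hasattr(torch, "compile"):
        return module  # PyTorch < 2.0: no compiler available
    try:
        return torch.compile(module, mode=mode, fullgraph=True)
    except Exception as exc:
        # Compilation problems should not take the whole app down.
        print(f"torch.compile failed (mode={mode!r}): {exc}; running eager.")
        return module
```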
@@ -69,261 +82,403 @@ try:
         try:
             pipe.transformer = torch.compile(pipe.transformer, mode=compile_mode, fullgraph=True)
             logging.info("Pipeline transformer compiled successfully.")
-            # Optional dummy inference run can go here
+            # Optional dummy inference run can go here for pre-compilation
         except Exception as e:
             logging.warning(f"torch.compile failed (mode='{compile_mode}'): {e}. Running unoptimized.")
-            # --- Solution 3: If compilation fails consistently, comment out the compile line above ---
+            # --- Fallback (Solution 3): If compilation fails consistently, comment out the compile line above ---
             # pipe.transformer = torch.compile(pipe.transformer, mode=compile_mode, fullgraph=True) # <-- Comment this out
     else:
         logging.info("torch.compile not available or not on CUDA, skipping compilation.")
 
-    # --- (Optional xformers code would go here if used) ---
+    # --- (Optional xformers code could be added here if used) ---
 
     logging.info("Pipeline ready for inference.")
 
 except Exception as e:
+    # Log the full error traceback for debugging
     logging.error(f"FATAL: Error during model loading or setup: {e}", exc_info=True)
+    # Print a simple message to console as well
     print(f"FATAL ERROR DURING MODEL LOAD/SETUP: {e}")
+    # Exit if models can't load - Gradio app won't work anyway
     raise SystemExit(f"Model loading/setup failed: {e}")
 
-
 # --- Constants ---
 MAX_SEED = 2**32 - 1
-MAX_PIXEL_BUDGET = 1280 * 1280
-INTERNAL_PROCESSING_FACTOR = 4
+MAX_PIXEL_BUDGET = 1280 * 1280  # Max pixels for the *intermediate* image
+INTERNAL_PROCESSING_FACTOR = 4  # Factor used for quality generation
 
 # --- Image Processing Function (process_input) ---
-# ... (process_input function remains the same) ...
 def process_input(input_image):
-    if input_image is None: raise gr.Error("Input image is missing!")
+    """Processes input image: handles orientation, converts to RGB, checks budget, rounds dimensions."""
+    if input_image is None:
+        raise gr.Error("Input image is missing!")
     try:
+        # Ensure it's a PIL Image and handle EXIF orientation
         input_image = ImageOps.exif_transpose(input_image)
-        if input_image.mode != 'RGB': input_image = input_image.convert('RGB')
+        # Convert to RGB if needed
+        if input_image.mode != 'RGB':
+            logging.info(f"Converting input image from {input_image.mode} to RGB")
+            input_image = input_image.convert('RGB')
         w, h = input_image.size
-    except AttributeError: raise gr.Error("Invalid input image format.")
-    except Exception as img_err: raise gr.Error(f"Could not process input image: {img_err}")
+    except AttributeError:
+        # Catch cases where input_image might not be a valid PIL Image object
+        raise gr.Error("Invalid input image format. Please provide a valid image file.")
+    except Exception as img_err:
+        # Catch other potential PIL errors
+        raise gr.Error(f"Could not process input image: {img_err}")
 
     w_original, h_original = w, h
-    if w == 0 or h == 0: raise gr.Error("Input image has zero dimensions.")
+    if w == 0 or h == 0:
+        raise gr.Error("Input image has zero width or height.")
 
+    # Calculate target intermediate size based on INTERNAL factor for budget check
     target_w_internal = w * INTERNAL_PROCESSING_FACTOR
     target_h_internal = h * INTERNAL_PROCESSING_FACTOR
     target_pixels_internal = target_w_internal * target_h_internal
+
     was_resized = False
-    input_image_to_process = input_image.copy()
+    input_image_to_process = input_image.copy()  # Work on a copy
 
+    # Check if the *intermediate* size exceeds the budget
     if target_pixels_internal > MAX_PIXEL_BUDGET:
+        # Calculate the maximum allowed input pixels for the given internal factor
         max_input_pixels = MAX_PIXEL_BUDGET / (INTERNAL_PROCESSING_FACTOR**2)
         current_input_pixels = w * h
+
         if current_input_pixels > max_input_pixels:
+            # Calculate scaling factor to fit budget and resize input
             input_scale_factor = (max_input_pixels / current_input_pixels) ** 0.5
-            input_w_resized = max(8, int(w * input_scale_factor))
-            input_h_resized = max(8, int(h * input_scale_factor))
+            input_w_resized = max(8, int(w * input_scale_factor))  # Ensure min size 8
+            input_h_resized = max(8, int(h * input_scale_factor))  # Ensure min size 8
             intermediate_w = input_w_resized * INTERNAL_PROCESSING_FACTOR
             intermediate_h = input_h_resized * INTERNAL_PROCESSING_FACTOR
-            logging.warning(f"Requested {INTERNAL_PROCESSING_FACTOR}x intermediate exceeds budget. Resizing input {w}x{h} -> {input_w_resized}x{input_h_resized}.")
-            gr.Info(f"Intermediate {INTERNAL_PROCESSING_FACTOR}x size exceeds budget. Input resized to {input_w_resized}x{input_h_resized} -> model generates ~{int(intermediate_w)}x{int(intermediate_h)}.")
+
+            logging.warning(
+                f"Requested {INTERNAL_PROCESSING_FACTOR}x intermediate output ({target_w_internal}x{target_h_internal}) exceeds budget. "
+                f"Resizing input from {w}x{h} to {input_w_resized}x{input_h_resized}."
+            )
+            gr.Info(
+                f"Intermediate {INTERNAL_PROCESSING_FACTOR}x size exceeds budget. Input resized to {input_w_resized}x{input_h_resized} "
+                f"-> model generates ~{int(intermediate_w)}x{int(intermediate_h)}."
+            )
             input_image_to_process = input_image_to_process.resize((input_w_resized, input_h_resized), Image.Resampling.LANCZOS)
-            was_resized = True
+            was_resized = True  # Flag that original dimensions were lost for final scaling
 
+    # Round processed input dimensions down to nearest multiple of 8 (required by some models)
     w_proc, h_proc = input_image_to_process.size
-    w_final_proc = max(8, w_proc - w_proc % 8)
-    h_final_proc = max(8, h_proc - h_proc % 8)
+    w_final_proc = max(8, w_proc - w_proc % 8)  # Ensure minimum 8x8
+    h_final_proc = max(8, h_proc - h_proc % 8)  # Ensure minimum 8x8
+
     if (w_proc, h_proc) != (w_final_proc, h_final_proc):
-        logging.info(f"Rounding processed input dims {w_proc}x{h_proc} -> {w_final_proc}x{h_final_proc}")
+        logging.info(f"Rounding processed input dimensions from {w_proc}x{h_proc} to {w_final_proc}x{h_final_proc}")
         input_image_to_process = input_image_to_process.resize((w_final_proc, h_final_proc), Image.Resampling.LANCZOS)
 
     return input_image_to_process, w_original, h_original, was_resized
 
 
 # --- Inference Function (infer) ---
-@spaces.GPU(duration=180)
+# --- MODIFIED GPU DURATION ---
+@spaces.GPU(duration=75)
 def infer(
-    seed, randomize_seed, input_image, num_inference_steps,
-    final_upscale_factor, controlnet_conditioning_scale,
-    progress=gr.Progress(track_tqdm=True),
+    seed,
+    randomize_seed,
+    input_image,
+    num_inference_steps,
+    final_upscale_factor,
+    controlnet_conditioning_scale,
+    progress=gr.Progress(track_tqdm=True),  # Gradio progress tracking
 ):
+    """Runs the Flux ControlNet upscaling pipeline."""
     global pipe
-    # --- IMPROVED ERROR HANDLING: Define default return early ---
-    default_return = [[input_image, None], int(seed) if seed is not None else 0, None]
+    # --- Define default return structure for error cases ---
+    # [[Input Image or None, Output Image or None], Seed Int, Base64 String or None]
+    current_seed = int(seed) if seed is not None else 0
+    default_return = [[input_image, None], current_seed, None]
 
     if pipe is None:
         gr.Error("Pipeline not loaded. Cannot perform inference.")
-        return default_return # Use default return
+        return default_return
 
-    original_input_pil = input_image # Keep ref even if None initially
+    original_input_pil = input_image  # Keep reference to original input
 
+    # Handle missing input image
     if input_image is None:
         gr.Warning("Please provide an input image.")
-        # Update seed in default return if randomized
-        if randomize_seed: seed = random.randint(0, MAX_SEED)
-        else: seed = int(seed) if seed is not None else 0
-        default_return[1] = seed
-        return default_return # Use default return
-
-    if randomize_seed: seed = random.randint(0, MAX_SEED)
-    seed = int(seed)
-    # --- UPDATE DEFAULT RETURN SEED ---
+        # Update seed in default return if randomized, keep original image as None
+        if randomize_seed: current_seed = random.randint(0, MAX_SEED)
+        default_return[1] = current_seed
+        default_return[0][0] = None  # Explicitly set original image part to None
+        return default_return
+
+    # Determine seed
+    if randomize_seed:
+        current_seed = random.randint(0, MAX_SEED)
+    seed = int(current_seed)  # Use the final determined seed
+    # Update default return with final seed and original image
     default_return[1] = seed
-    # --- Ensure original image is in the default return ---
     default_return[0][0] = original_input_pil
 
-
+    # Ensure numerical inputs are integers
     final_upscale_factor = int(final_upscale_factor)
     num_inference_steps = int(num_inference_steps)
 
+    # Clamp final factor if needed
     if final_upscale_factor > INTERNAL_PROCESSING_FACTOR:
         gr.Warning(f"Clamping final upscale factor {final_upscale_factor}x to internal {INTERNAL_PROCESSING_FACTOR}x.")
         final_upscale_factor = INTERNAL_PROCESSING_FACTOR
 
-    logging.info(f"Starting inference: seed={seed}, internal={INTERNAL_PROCESSING_FACTOR}x, final={final_upscale_factor}x, steps={num_inference_steps}, cnet_scale={controlnet_conditioning_scale}")
+    logging.info(
+        f"Starting inference: seed={seed}, internal={INTERNAL_PROCESSING_FACTOR}x, "
+        f"final={final_upscale_factor}x, steps={num_inference_steps}, "
+        f"cnet_scale={controlnet_conditioning_scale}"
+    )
 
+    # Process the input image
     try:
-        processed_input_image, w_original, h_original, was_input_resized = process_input(input_image)
+        processed_input_image, w_original, h_original, was_input_resized = process_input(
+            input_image
+        )
     except Exception as e:
         logging.error(f"Error processing input image: {e}", exc_info=True)
         gr.Error(f"Error processing input image: {e}")
-        return default_return # Use default return with correct seed
+        return default_return  # Use default return structure
 
+    # Calculate intermediate dimensions for the model
     w_proc, h_proc = processed_input_image.size
     control_image_w = w_proc * INTERNAL_PROCESSING_FACTOR
     control_image_h = h_proc * INTERNAL_PROCESSING_FACTOR
 
-    # Failsafe clamp (remains the same)
-    if control_image_w * control_image_h > MAX_PIXEL_BUDGET * 1.05:
+    # Failsafe clamp if budget is still somehow exceeded after input processing
+    if control_image_w * control_image_h > MAX_PIXEL_BUDGET * 1.05:  # Add 5% margin just in case
         scale_factor = (MAX_PIXEL_BUDGET / (control_image_w * control_image_h)) ** 0.5
-        control_image_w = max(8, int(control_image_w * scale_factor)); control_image_w -= control_image_w % 8
-        control_image_h = max(8, int(control_image_h * scale_factor)); control_image_h -= control_image_h % 8
-        logging.warning(f"Control image dims clamped post-processing: {control_image_w}x{control_image_h}.")
+        control_image_w = max(8, int(control_image_w * scale_factor))
+        control_image_h = max(8, int(control_image_h * scale_factor))
+        # Ensure multiple of 8 after scaling
+        control_image_w = max(8, control_image_w - control_image_w % 8)
+        control_image_h = max(8, control_image_h - control_image_h % 8)
+        logging.warning(f"Control image dimensions clamped post-processing to {control_image_w}x{control_image_h} to fit budget.")
         gr.Warning(f"Control image dimensions further clamped to {control_image_w}x{control_image_h}.")
 
+    # Prepare control image (resized input for ControlNet)
     logging.info(f"Resizing processed input {w_proc}x{h_proc} to control image {control_image_w}x{control_image_h}")
     try:
         control_image = processed_input_image.resize((control_image_w, control_image_h), Image.Resampling.LANCZOS)
     except ValueError as resize_err:
-        logging.error(f"Error resizing to control image: {resize_err}")
+        logging.error(f"Error resizing processed input to control image: {resize_err}")
         gr.Error(f"Failed to prepare control image: {resize_err}")
-        return default_return # Use default return
+        return default_return
 
+    # Setup generator for reproducibility
     generator = torch.Generator(device=device).manual_seed(seed)
 
-    gr.Info(f"Generating intermediate image ({control_image_w}x{control_image_h}, {num_inference_steps} steps)...")
+    # --- Run the Diffusion Pipeline ---
+    gr.Info(f"Generating intermediate image at {INTERNAL_PROCESSING_FACTOR}x quality ({control_image_w}x{control_image_h}) with {num_inference_steps} steps...")
     logging.info(f"Running pipeline: size={control_image_w}x{control_image_h}, steps={num_inference_steps}")
-    intermediate_result_image = None # Initialize
+    intermediate_result_image = None  # Initialize to None
     try:
         with torch.inference_mode():
-            # Progress bar integration can be added here if needed
+            # The actual model inference call
             intermediate_result_image = pipe(
-                prompt="", control_image=control_image,
+                prompt="",  # No text prompt needed for this upscaler
+                control_image=control_image,
                 controlnet_conditioning_scale=float(controlnet_conditioning_scale),
-                num_inference_steps=num_inference_steps, guidance_scale=0.0,
-                height=control_image_h, width=control_image_w, generator=generator,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=0.0,  # Guidance scale typically 0 for ControlNet-only tasks
+                height=control_image_h,  # Target height for the model
+                width=control_image_w,  # Target width for the model
+                generator=generator,
+                # Can add callback for step progress if needed:
+                # callback_on_step_end=lambda step, t, latents: progress(step / num_inference_steps)
             ).images[0]
-        logging.info(f"Pipeline finished. Intermediate size: {intermediate_result_image.size if intermediate_result_image else 'None'}")
+        logging.info(f"Pipeline execution finished. Intermediate image size: {intermediate_result_image.size if intermediate_result_image else 'None'}")
 
-    # --- Catch specific errors if needed, otherwise general Exception ---
     except torch.cuda.OutOfMemoryError as oom_error:
-        logging.error(f"OOM during pipeline: {oom_error}", exc_info=True)
-        gr.Error(f"OOM generating {control_image_w}x{control_image_h}. Try smaller input/factor.")
-        if device == 'cuda': torch.cuda.empty_cache()
-        return default_return # Use default return
-    except Exception as e: # Catches the torch.compile error too
-        logging.error(f"Error during pipeline execution: {e}", exc_info=True) # Log full traceback
-        # Provide a more specific error message if it's the known compile issue
+        # Handle specific OOM error
+        logging.error(f"CUDA Out of Memory during pipeline execution: {oom_error}", exc_info=True)
+        gr.Error(f"Ran out of GPU memory trying to generate intermediate {control_image_w}x{control_image_h}. Try reducing the Final Upscale Factor or using a smaller input image.")
+        if device == 'cuda': torch.cuda.empty_cache()  # Try to clear cache
+        return default_return
+    except Exception as e:
+        # Handle other pipeline errors, including potential torch.compile issues
+        logging.error(f"Error during pipeline execution: {e}", exc_info=True)
+        # Check if it looks like the known compile error
         if "dynamic shape operator" in str(e) or "Unsupported" in str(e.__class__):
-            gr.Error(f"Inference failed: torch.compile issue encountered ({type(e).__name__}). Try restarting the Space or disabling compilation if persistent.")
+            gr.Error(f"Inference failed: torch.compile issue encountered ({type(e).__name__}). The model attempted to run unoptimized. If this persists, compilation might need to be disabled in the code.")
         else:
-            gr.Error(f"Inference failed: {e}")
-        return default_return # Use default return
+            gr.Error(f"Inference failed: {e}")
+        return default_return
 
-    # --- Check if intermediate image was actually created ---
+    # --- Post-Pipeline Checks and Resizing ---
     if not intermediate_result_image:
-        logging.error("Intermediate result is None after pipeline (but no exception caught).")
+        # Should ideally not happen if no exception was caught, but check anyway
+        logging.error("Intermediate result image is None after pipeline execution without exception.")
         gr.Error("Inference produced no result image unexpectedly.")
-        return default_return # Use default return
+        return default_return
 
-    # --- Final Resizing (remains the same logic) ---
+    # Calculate final target dimensions based on ORIGINAL input size and FINAL upscale factor
     if was_input_resized:
+        # Base final size on the downscaled input that was processed
         final_target_w = w_proc * final_upscale_factor
         final_target_h = h_proc * final_upscale_factor
-        logging.warning(f"Input downscaled. Final size based on processed: {w_proc}x{h_proc}*{final_upscale_factor}x -> {final_target_w}x{final_target_h}")
-        gr.Info(f"Input downscaled. Final size approx {final_target_w}x{final_target_h}.")
+        logging.warning(f"Input was downscaled. Final size based on processed input: {w_proc}x{h_proc} * {final_upscale_factor}x -> {final_target_w}x{final_target_h}")
+        gr.Info(f"Input was downscaled. Final size target approx {final_target_w}x{final_target_h}.")
     else:
+        # Base final size on the original input size
         final_target_w = w_original * final_upscale_factor
         final_target_h = h_original * final_upscale_factor
 
+    # Perform final resize from intermediate to target size
     final_result_image = intermediate_result_image
     current_w, current_h = intermediate_result_image.size
 
     if (current_w, current_h) != (final_target_w, final_target_h):
-        logging.info(f"Resizing intermediate {current_w}x{current_h} to final {final_target_w}x{final_target_h}")
+        logging.info(f"Resizing intermediate image from {current_w}x{current_h} to final target size {final_target_w}x{final_target_h} (using {final_upscale_factor}x factor)")
         gr.Info(f"Resizing from intermediate {current_w}x{current_h} to final {final_target_w}x{final_target_h}...")
         try:
+            # Ensure target dimensions are valid before resizing
             if final_target_w > 0 and final_target_h > 0:
+                # Use high-quality LANCZOS for downsampling or general resizing
                 final_result_image = intermediate_result_image.resize((final_target_w, final_target_h), Image.Resampling.LANCZOS)
             else:
-                gr.Warning(f"Invalid final target ({final_target_w}x{final_target_h}). Skipping resize.")
-                final_result_image = intermediate_result_image
+                # Avoid resizing if target dimensions are invalid
+                gr.Warning(f"Invalid final target dimensions ({final_target_w}x{final_target_h}). Skipping final resize.")
+                final_result_image = intermediate_result_image  # Keep intermediate
         except Exception as resize_e:
-            logging.error(f"Could not resize final image: {resize_e}")
-            gr.Warning(f"Failed final resize. Returning intermediate {current_w}x{current_h}.")
-            final_result_image = intermediate_result_image
+            # Handle potential errors during final resize
+            logging.error(f"Could not resize intermediate image to final size: {resize_e}")
+            gr.Warning(f"Failed to resize to final {final_upscale_factor}x. Returning intermediate {INTERNAL_PROCESSING_FACTOR}x result ({current_w}x{current_h}).")
+            final_result_image = intermediate_result_image  # Fallback to intermediate
     else:
-        logging.info("Intermediate size matches final target. No final resize needed.")
+        # No resize needed if intermediate matches final target
+        logging.info(f"Intermediate size {current_w}x{current_h} matches final target size. No final resize needed.")
 
     logging.info(f"Inference successful. Final output size: {final_result_image.size}")
 
-    # --- Base64 Encoding (remains the same) ---
+    # --- Base64 Encoding for API output ---
     base64_string = None
     if final_result_image:
         try:
-            buffered = io.BytesIO(); final_result_image.save(buffered, format="WEBP", quality=90)
+            buffered = io.BytesIO()
+            # Save as WEBP for potentially smaller size, adjust quality as needed
+            final_result_image.save(buffered, format="WEBP", quality=90)
             img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+            # Format as a data URL
             base64_string = f"data:image/webp;base64,{img_str}"
-            logging.info(f"Encoded result to Base64 (len: {len(base64_string)}).")
+            logging.info(f"Encoded result image to Base64 string (length: {len(base64_string)} chars).")
         except Exception as enc_err:
-            logging.error(f"Failed Base64 encoding: {enc_err}", exc_info=True)
+            # Log if encoding fails but don't stop the process
+            logging.error(f"Failed to encode result image to Base64: {enc_err}", exc_info=True)
+            base64_string = None  # Ensure it's None if encoding failed
 
-    # --- SUCCESS RETURN ---
+    # --- Return results for Gradio ---
     return [[original_input_pil, final_result_image], seed, base64_string]
 
-# --- Gradio Interface (Gradio UI Definition) ---
-# ... (Gradio definition remains the same, ensure inputs/outputs match infer) ...
+
+# --- Gradio Interface Definition ---
 with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as demo:
-    gr.Markdown(f"""
+    gr.Markdown(
+        f"""
     # ⚡ Flux.1-dev Upscaler ControlNet ⚡
-    Upscale images using Flux.1-dev Upscaler ControlNet on **{power_device}**.
-    Internal processing at **{INTERNAL_PROCESSING_FACTOR}x** quality, resized to **Final Upscale Factor**.
-    **Speed Up:** Reduce `Inference Steps` (try 10-15). `torch.compile` enabled (may fail, see logs).
-    *Limit*: ~**{MAX_PIXEL_BUDGET/1_000_000:.1f} megapixels** intermediate size.
-    """)
+    Upscale images using the [Flux.1-dev Upscaler ControlNet](https://huggingface.co/jasperai/Flux.1-dev-Controlnet-Upscaler) model based on [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
+    Currently running on **{power_device}**. Hardware provided by Hugging Face 🤗.
+
+    **How it works:** This demo uses an internal processing scale of **{INTERNAL_PROCESSING_FACTOR}x** for potentially higher detail generation (slower),
+    then resizes the result to your selected **Final Upscale Factor**.
+
+    **To Speed Up:**
+    1. **Reduce `Inference Steps`:** Fewer steps = faster generation (try 10-15).
+    2. **(Code Change Needed):** Reduce `INTERNAL_PROCESSING_FACTOR` in the script (e.g., to 3) for direct computation reduction (may lower detail).
+    3. `torch.compile` is attempted (`mode="default"`) which might provide speedup after the first run (check logs for success/failure).
+
+    *Limit*: Intermediate processing resolution capped around **{MAX_PIXEL_BUDGET/1_000_000:.1f} megapixels** ({int(MAX_PIXEL_BUDGET**0.5)}x{int(MAX_PIXEL_BUDGET**0.5)}).
+    """
+    )
+
     with gr.Row():
-        with gr.Column(scale=2): input_im = gr.Image(label="Input Image", type="pil", height=350, sources=["upload", "clipboard"])
+        with gr.Column(scale=2):
+            input_im = gr.Image(
+                label="Input Image",
+                type="pil",
+                height=350,
+                sources=["upload", "clipboard"],
+            )
         with gr.Column(scale=1):
-            upscale_factor_slider = gr.Slider(label="Final Upscale Factor", info=f"Output size. Internal uses {INTERNAL_PROCESSING_FACTOR}x quality.", minimum=1, maximum=INTERNAL_PROCESSING_FACTOR, step=1, value=min(2, INTERNAL_PROCESSING_FACTOR))
-            num_inference_steps = gr.Slider(label="Inference Steps", info="Fewer=faster (try 10-15).", minimum=4, maximum=50, step=1, value=15)
-            controlnet_conditioning_scale = gr.Slider(label="ControlNet Scale", info="Guidance strength.", minimum=0.0, maximum=1.5, step=0.05, value=0.6)
+            upscale_factor_slider = gr.Slider(
+                label="Final Upscale Factor",
+                info=f"Output size relative to input. Internal processing uses {INTERNAL_PROCESSING_FACTOR}x quality.",
+                minimum=1,
+                maximum=INTERNAL_PROCESSING_FACTOR,  # Max limited by internal factor
+                step=1,
+                value=min(2, INTERNAL_PROCESSING_FACTOR)  # Default to 2x or internal factor if smaller
+            )
+            num_inference_steps = gr.Slider(
+                label="Inference Steps",
+                info="Fewer steps = faster. Try 10-15.",
+                minimum=4, maximum=50, step=1, value=15  # Defaulting to 15 for speed
+            )
+            controlnet_conditioning_scale = gr.Slider(
+                label="ControlNet Conditioning Scale",
+                info="Strength of ControlNet guidance.",
+                minimum=0.0, maximum=1.5, step=0.05, value=0.6
+            )
             with gr.Row():
                 seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                 randomize_seed = gr.Checkbox(label="Random", value=True, scale=0, min_width=80)
                 run_button = gr.Button("⚡ Upscale Image", variant="primary", scale=1)
-    with gr.Row(): result_slider = ImageSlider(label="Input / Output Comparison", type="pil", interactive=False, show_label=True, position=0.5)
+
+    with gr.Row():
+        result_slider = ImageSlider(
+            label="Input / Output Comparison",
+            type="pil",
+            interactive=False,  # Output only
+            show_label=True,
+            position=0.5  # Start slider in the middle
+        )
+
     output_seed = gr.Textbox(label="Seed Used", interactive=False, visible=True, scale=1)
+    # Hidden output for API usage
     api_base64_output = gr.Textbox(label="API Base64 Output", interactive=False, visible=False)
 
-    example_dir = "examples"; example_files = ["image_2.jpg", "image_4.jpg", "low_res_face.png", "low_res_landscape.png"]
+    # --- Examples ---
+    example_dir = "examples"
+    example_files = ["image_2.jpg", "image_4.jpg", "low_res_face.png", "low_res_landscape.png"]
     example_paths = [os.path.join(example_dir, f) for f in example_files if os.path.exists(os.path.join(example_dir, f))]
+
     if example_paths:
         gr.Examples(
+            # Examples use the new defaults: final factor 2x (or less), steps 15
             examples=[ [path, min(2, INTERNAL_PROCESSING_FACTOR), 15, 0.6, random.randint(0,MAX_SEED), True] for path in example_paths ],
             inputs=[ input_im, upscale_factor_slider, num_inference_steps, controlnet_conditioning_scale, seed, randomize_seed, ],
-            outputs=[result_slider, output_seed], fn=infer, cache_examples="lazy",
-            label=f"Examples (Click: {min(2, INTERNAL_PROCESSING_FACTOR)}x Output, 15 Steps)", run_on_click=True)
-    else: gr.Markdown(f"*No example images found in '{example_dir}'.*")
-    gr.Markdown("---"); gr.Markdown("**Disclaimer:** For illustrative purposes.")
-
-    run_button.click(fn=infer,
-        inputs=[seed, randomize_seed, input_im, num_inference_steps, upscale_factor_slider, controlnet_conditioning_scale],
-        outputs=[result_slider, output_seed, api_base64_output], api_name="upscale")
-
+            # Map outputs to UI components; base64 ignored by Examples UI
+            outputs=[result_slider, output_seed],
+            fn=infer,  # Function to call when example is clicked
+            cache_examples="lazy",  # Cache results for examples
+            label=f"Example Images (Click to Run with {min(2, INTERNAL_PROCESSING_FACTOR)}x Output, 15 Steps)",
+            run_on_click=True
+        )
+    else:
+        gr.Markdown(f"*No example images found in '{example_dir}' directory.*")
+
+    gr.Markdown("---")
+    gr.Markdown("**Disclaimer:** Demo for illustrative purposes. Users are responsible for generated content.")
+
+    # Connect button click to the inference function
+    run_button.click(
+        fn=infer,
+        inputs=[
+            seed,
+            randomize_seed,
+            input_im,
+            num_inference_steps,
+            upscale_factor_slider,
+            controlnet_conditioning_scale,
+        ],
+        # Map all return values from infer to the correct output components
+        outputs=[result_slider, output_seed, api_base64_output],
+        api_name="upscale"  # Define API endpoint name
+    )
+
+# Launch the Gradio app
+# queue manages concurrent users/requests
+# share=False means no public link generated by default
 demo.queue(max_size=10).launch(share=False, show_api=True)
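For reference, the size clamping that `process_input` performs in the diff above can be reproduced in isolation. The standalone sketch below (using the same constants as app.py) shows the arithmetic: the input is scaled so the 4x intermediate stays within the pixel budget, then rounded down to multiples of 8:

```python
MAX_PIXEL_BUDGET = 1280 * 1280  # same budget as in app.py
INTERNAL_PROCESSING_FACTOR = 4  # same internal factor as in app.py

def clamp_input_size(w: int, h: int) -> tuple[int, int]:
    """Shrink (w, h) so the 4x intermediate fits the budget; round down to multiples of 8."""
    max_input_pixels = MAX_PIXEL_BUDGET / INTERNAL_PROCESSING_FACTOR**2
    if w * h > max_input_pixels:
        scale = (max_input_pixels / (w * h)) ** 0.5
        w, h = max(8, int(w * scale)), max(8, int(h * scale))
    return max(8, w - w % 8), max(8, h - h % 8)

# A 1000x800 input would need a 4000x3200 intermediate (12.8 MP), far over the
# ~1.6 MP budget, so it is pre-downscaled before the pipeline runs:
print(clamp_input_size(1000, 800))  # -> (352, 280)
```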
 
 
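Because the click handler registers `api_name="upscale"` and the app launches with `show_api=True`, the endpoint can also be driven programmatically. A minimal client-side sketch using `gradio_client` (assumes a recent `gradio_client`; the Space id and file name are placeholders, and the third return value is the `data:image/webp;base64,...` string built by the encoding step in the diff):

```python
import base64
import io

from gradio_client import Client, handle_file
from PIL import Image

client = Client("user/flux-upscaler-demo")  # placeholder Space id
images, used_seed, b64 = client.predict(
    42,                          # seed
    True,                        # randomize_seed
    handle_file("low_res.png"),  # input_image (placeholder file)
    15,                          # num_inference_steps
    2,                           # final_upscale_factor
    0.6,                         # controlnet_conditioning_scale
    api_name="/upscale",
)

# Decode the base64 data URL back into a PIL image and save it.
payload = b64.split(",", 1)[1]
Image.open(io.BytesIO(base64.b64decode(payload))).save("upscaled.webp")
```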