Gemini899 committed
Commit ae94738 · verified · 1 Parent(s): 1f71182

Update app.py

Files changed (1): app.py +114 -190
app.py CHANGED
@@ -6,14 +6,12 @@ import io
6
  import base64
7
  import gradio as gr
8
  import numpy as np
9
- import spaces # Ensure spaces is imported for the decorator
10
  import torch
11
- # --- Optional: Add this if trying Solution 2 (suppress_errors) later ---
12
- # import torch._dynamo
13
  from diffusers import FluxControlNetModel
14
  from diffusers.pipelines import FluxControlNetPipeline
15
- from gradio_imageslider import ImageSlider
16
- from PIL import Image, ImageOps
17
  from huggingface_hub import snapshot_download
18
 
19
  # --- Setup Logging and Device ---
@@ -21,25 +19,34 @@ logging.basicConfig(level=logging.INFO)
21
  warnings.filterwarnings("ignore")
22
 
23
  css = """
24
- #col-container { margin: 0 auto; max-width: 512px; }
25
- .gradio-container { max-width: 900px !important; margin: auto !important; }
26
  """
27
 
28
  if torch.cuda.is_available():
29
  power_device = "GPU"
30
  device = "cuda"
31
- torch_dtype = torch.bfloat16
32
  else:
33
  power_device = "CPU"
34
  device = "cpu"
35
- torch_dtype = torch.float32
 
36
  logging.info(f"Selected device: {device} | Data type: {torch_dtype}")
37
 
38
  # --- Authentication and Model Download ---
39
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
 
 
40
  flux_model_id = "black-forest-labs/FLUX.1-dev"
41
  controlnet_model_id = "jasperai/Flux.1-dev-Controlnet-Upscaler"
42
- local_model_dir = flux_model_id.split('/')[-1]
43
  pipe = None
44
 
45
  try:
@@ -49,7 +56,7 @@ try:
49
  repo_type="model",
50
  ignore_patterns=["*.md", "*.gitattributes"],
51
  local_dir=local_model_dir,
52
- token=huggingface_token
53
  )
54
  logging.info(f"Base model downloaded/verified in: {model_path}")
55
 
@@ -68,88 +75,60 @@ try:
68
  pipe.to(device)
69
  logging.info("Pipeline loaded and moved to device.")
70
 
71
- # --- OPTIMIZATION: Attempt torch.compile (PyTorch 2.0+) ---
72
- if device == "cuda" and hasattr(torch, "compile"):
73
- # --- Using "default" mode as first attempt to fix compile errors ---
74
- compile_mode = "default"
75
- # --- Alternative (Solution 2): Uncomment these lines if "default" fails ---
76
- # import torch._dynamo
77
- # torch._dynamo.config.suppress_errors = True
78
- # compile_mode = "max-autotune" # or "default" even with suppress_errors
79
- # --- End Alternative ---
80
-
81
- logging.info(f"Attempting to compile the pipeline transformer with torch.compile (mode='{compile_mode}')...")
82
- try:
83
- pipe.transformer = torch.compile(pipe.transformer, mode=compile_mode, fullgraph=True)
84
- logging.info("Pipeline transformer compiled successfully.")
85
- # Optional dummy inference run can go here for pre-compilation
86
- except Exception as e:
87
- logging.warning(f"torch.compile failed (mode='{compile_mode}'): {e}. Running unoptimized.")
88
- # --- Fallback (Solution 3): If compilation fails consistently, comment out the compile line above ---
89
- # pipe.transformer = torch.compile(pipe.transformer, mode=compile_mode, fullgraph=True) # <-- Comment this out
90
- else:
91
- logging.info("torch.compile not available or not on CUDA, skipping compilation.")
92
-
93
- # --- (Optional xformers code could be added here if used) ---
94
-
95
- logging.info("Pipeline ready for inference.")
96
-
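The block removed above wrapped `torch.compile` in a try/except and mentioned `torch._dynamo.config.suppress_errors` as an alternative. A minimal standalone sketch of that guarded-compile pattern (`maybe_compile` is an illustrative name, not code from this commit):

```python
import logging
import torch

def maybe_compile(module, mode="default"):
    """Best-effort torch.compile: return the eager module if compilation fails."""
    if not (torch.cuda.is_available() and hasattr(torch, "compile")):
        logging.info("torch.compile unavailable or not on CUDA; skipping.")
        return module
    try:
        # Escape hatch mentioned in the removed comments: fall back to eager
        # execution on graph breaks instead of raising.
        import torch._dynamo
        torch._dynamo.config.suppress_errors = True
        return torch.compile(module, mode=mode)
    except Exception as e:
        logging.warning(f"torch.compile failed (mode='{mode}'): {e}. Running eager.")
        return module

# In this app it would be applied as: pipe.transformer = maybe_compile(pipe.transformer)
```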
97
  except Exception as e:
98
- # Log the full error traceback for debugging
99
- logging.error(f"FATAL: Error during model loading or setup: {e}", exc_info=True)
100
- # Print a simple message to console as well
101
- print(f"FATAL ERROR DURING MODEL LOAD/SETUP: {e}")
102
- # Exit if models can't load - Gradio app won't work anyway
103
- raise SystemExit(f"Model loading/setup failed: {e}")
104
 
105
  # --- Constants ---
106
  MAX_SEED = 2**32 - 1
107
- MAX_PIXEL_BUDGET = 1280 * 1280 # Max pixels for the *intermediate* image
108
- INTERNAL_PROCESSING_FACTOR = 4 # Factor used for quality generation
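To make the budget concrete: with the 4x internal factor, an input may hold at most MAX_PIXEL_BUDGET / 4^2 = 1,638,400 / 16 = 102,400 pixels (roughly 320x320) before process_input has to downscale it. A worked sketch of that arithmetic:

```python
# Illustrative rerun of process_input's budget math; not committed code.
MAX_PIXEL_BUDGET = 1280 * 1280                  # 1,638,400 intermediate pixels
INTERNAL_PROCESSING_FACTOR = 4

max_input_pixels = MAX_PIXEL_BUDGET / INTERNAL_PROCESSING_FACTOR**2  # 102,400.0

w, h = 640, 480                                 # hypothetical input (307,200 px)
if w * h > max_input_pixels:
    scale = (max_input_pixels / (w * h)) ** 0.5             # ~0.577
    w, h = max(8, int(w * scale)), max(8, int(h * scale))   # -> 369 x 277
# The model then generates at 4x the (possibly downscaled) input:
print(w * INTERNAL_PROCESSING_FACTOR, h * INTERNAL_PROCESSING_FACTOR)  # 1476 1108
```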
 
109
 
110
- # --- Image Processing Function (process_input) ---
111
  def process_input(input_image):
112
- """Processes input image: handles orientation, converts to RGB, checks budget, rounds dimensions."""
 
113
  if input_image is None:
114
  raise gr.Error("Input image is missing!")
115
  try:
116
- # Ensure it's a PIL Image and handle EXIF orientation
117
  input_image = ImageOps.exif_transpose(input_image)
118
- # Convert to RGB if needed
119
  if input_image.mode != 'RGB':
120
  logging.info(f"Converting input image from {input_image.mode} to RGB")
121
  input_image = input_image.convert('RGB')
122
  w, h = input_image.size
123
  except AttributeError:
124
- # Catch cases where input_image might not be a valid PIL Image object
125
  raise gr.Error("Invalid input image format. Please provide a valid image file.")
126
  except Exception as img_err:
127
- # Catch other potential PIL errors
128
  raise gr.Error(f"Could not process input image: {img_err}")
129
 
130
  w_original, h_original = w, h
131
  if w == 0 or h == 0:
132
  raise gr.Error("Input image has zero width or height.")
133
 
134
- # Calculate target intermediate size based on INTERNAL factor for budget check
135
  target_w_internal = w * INTERNAL_PROCESSING_FACTOR
136
  target_h_internal = h * INTERNAL_PROCESSING_FACTOR
137
  target_pixels_internal = target_w_internal * target_h_internal
138
 
139
  was_resized = False
140
- input_image_to_process = input_image.copy() # Work on a copy
141
 
142
  # Check if the *intermediate* size exceeds the budget
143
  if target_pixels_internal > MAX_PIXEL_BUDGET:
144
- # Calculate the maximum allowed input pixels for the given internal factor
145
  max_input_pixels = MAX_PIXEL_BUDGET / (INTERNAL_PROCESSING_FACTOR**2)
146
  current_input_pixels = w * h
147
 
148
  if current_input_pixels > max_input_pixels:
149
- # Calculate scaling factor to fit budget and resize input
150
  input_scale_factor = (max_input_pixels / current_input_pixels) ** 0.5
151
- input_w_resized = max(8, int(w * input_scale_factor)) # Ensure min size 8
152
- input_h_resized = max(8, int(h * input_scale_factor)) # Ensure min size 8
153
  intermediate_w = input_w_resized * INTERNAL_PROCESSING_FACTOR
154
  intermediate_h = input_h_resized * INTERNAL_PROCESSING_FACTOR
155
 
@@ -162,12 +141,12 @@ def process_input(input_image):
162
  f"-> model generates ~{int(intermediate_w)}x{int(intermediate_h)}."
163
  )
164
  input_image_to_process = input_image_to_process.resize((input_w_resized, input_h_resized), Image.Resampling.LANCZOS)
165
- was_resized = True # Flag that original dimensions were lost for final scaling
166
 
167
- # Round processed input dimensions down to nearest multiple of 8 (required by some models)
168
  w_proc, h_proc = input_image_to_process.size
169
- w_final_proc = max(8, w_proc - w_proc % 8) # Ensure minimum 8x8
170
- h_final_proc = max(8, h_proc - h_proc % 8) # Ensure minimum 8x8
171
 
172
  if (w_proc, h_proc) != (w_final_proc, h_final_proc):
173
  logging.info(f"Rounding processed input dimensions from {w_proc}x{h_proc} to {w_final_proc}x{h_final_proc}")
@@ -175,147 +154,120 @@ def process_input(input_image):
175
 
176
  return input_image_to_process, w_original, h_original, was_resized
177
 
178
-
179
- # --- Inference Function (infer) ---
180
- # --- MODIFIED GPU DURATION ---
181
  @spaces.GPU(duration=75)
182
  def infer(
183
  seed,
184
  randomize_seed,
185
  input_image,
186
  num_inference_steps,
187
- final_upscale_factor,
188
  controlnet_conditioning_scale,
189
- progress=gr.Progress(track_tqdm=True), # Gradio progress tracking
190
  ):
191
- """Runs the Flux ControlNet upscaling pipeline."""
192
  global pipe
193
- # --- Define default return structure for error cases ---
194
- # [[Input Image or None, Output Image or None], Seed Int, Base64 String or None]
195
- current_seed = int(seed) if seed is not None else 0
196
- default_return = [[input_image, None], current_seed, None]
197
-
198
  if pipe is None:
199
  gr.Error("Pipeline not loaded. Cannot perform inference.")
200
- return default_return
201
 
202
- original_input_pil = input_image # Keep reference to original input
203
 
204
- # Handle missing input image
205
  if input_image is None:
206
  gr.Warning("Please provide an input image.")
207
- # Update seed in default return if randomized, keep original image as None
208
- if randomize_seed: current_seed = random.randint(0, MAX_SEED)
209
- default_return[1] = current_seed
210
- default_return[0][0] = None # Explicitly set original image part to None
211
- return default_return
212
 
213
- # Determine seed
214
  if randomize_seed:
215
- current_seed = random.randint(0, MAX_SEED)
216
- seed = int(current_seed) # Use the final determined seed
217
- # Update default return with final seed and original image
218
- default_return[1] = seed
219
- default_return[0][0] = original_input_pil
220
 
221
- # Ensure numerical inputs are integers
222
  final_upscale_factor = int(final_upscale_factor)
223
- num_inference_steps = int(num_inference_steps)
224
-
225
- # Clamp final factor if needed
226
  if final_upscale_factor > INTERNAL_PROCESSING_FACTOR:
227
- gr.Warning(f"Clamping final upscale factor {final_upscale_factor}x to internal {INTERNAL_PROCESSING_FACTOR}x.")
228
- final_upscale_factor = INTERNAL_PROCESSING_FACTOR
 
229
 
230
  logging.info(
231
- f"Starting inference: seed={seed}, internal={INTERNAL_PROCESSING_FACTOR}x, "
232
- f"final={final_upscale_factor}x, steps={num_inference_steps}, "
233
- f"cnet_scale={controlnet_conditioning_scale}"
 
234
  )
235
 
236
- # Process the input image
237
  try:
 
238
  processed_input_image, w_original, h_original, was_input_resized = process_input(
239
  input_image
240
  )
241
  except Exception as e:
242
  logging.error(f"Error processing input image: {e}", exc_info=True)
243
  gr.Error(f"Error processing input image: {e}")
244
- return default_return # Use default return structure
245
 
246
- # Calculate intermediate dimensions for the model
247
  w_proc, h_proc = processed_input_image.size
 
 
248
  control_image_w = w_proc * INTERNAL_PROCESSING_FACTOR
249
  control_image_h = h_proc * INTERNAL_PROCESSING_FACTOR
250
 
251
- # Failsafe clamp if budget is still somehow exceeded after input processing
252
- if control_image_w * control_image_h > MAX_PIXEL_BUDGET * 1.05: # Add 5% margin just in case
 
253
  scale_factor = (MAX_PIXEL_BUDGET / (control_image_w * control_image_h)) ** 0.5
254
  control_image_w = max(8, int(control_image_w * scale_factor))
255
  control_image_h = max(8, int(control_image_h * scale_factor))
256
- # Ensure multiple of 8 after scaling
257
  control_image_w = max(8, control_image_w - control_image_w % 8)
258
  control_image_h = max(8, control_image_h - control_image_h % 8)
259
- logging.warning(f"Control image dimensions clamped post-processing to {control_image_w}x{control_image_h} to fit budget.")
260
  gr.Warning(f"Control image dimensions further clamped to {control_image_w}x{control_image_h}.")
261
 
262
- # Prepare control image (resized input for ControlNet)
263
- logging.info(f"Resizing processed input {w_proc}x{h_proc} to control image {control_image_w}x{control_image_h}")
264
  try:
 
265
  control_image = processed_input_image.resize((control_image_w, control_image_h), Image.Resampling.LANCZOS)
266
  except ValueError as resize_err:
267
  logging.error(f"Error resizing processed input to control image: {resize_err}")
268
  gr.Error(f"Failed to prepare control image: {resize_err}")
269
- return default_return
270
 
271
- # Setup generator for reproducibility
272
  generator = torch.Generator(device=device).manual_seed(seed)
273
 
274
- # --- Run the Diffusion Pipeline ---
275
- gr.Info(f"Generating intermediate image at {INTERNAL_PROCESSING_FACTOR}x quality ({control_image_w}x{control_image_h}) with {num_inference_steps} steps...")
276
- logging.info(f"Running pipeline: size={control_image_w}x{control_image_h}, steps={num_inference_steps}")
277
- intermediate_result_image = None # Initialize to None
278
  try:
279
  with torch.inference_mode():
280
- # The actual model inference call
281
  intermediate_result_image = pipe(
282
- prompt="", # No text prompt needed for this upscaler
283
- control_image=control_image,
284
  controlnet_conditioning_scale=float(controlnet_conditioning_scale),
285
- num_inference_steps=num_inference_steps,
286
- guidance_scale=0.0, # Guidance scale typically 0 for ControlNet-only tasks
287
  height=control_image_h, # Target height for the model
288
  width=control_image_w, # Target width for the model
289
  generator=generator,
290
- # Can add callback for step progress if needed:
291
- # callback_on_step_end=lambda step, t, latents: progress(step / num_inference_steps)
292
  ).images[0]
293
  logging.info(f"Pipeline execution finished. Intermediate image size: {intermediate_result_image.size if intermediate_result_image else 'None'}")
294
 
295
  except torch.cuda.OutOfMemoryError as oom_error:
296
- # Handle specific OOM error
297
  logging.error(f"CUDA Out of Memory during pipeline execution: {oom_error}", exc_info=True)
298
- gr.Error(f"Ran out of GPU memory trying to generate intermediate {control_image_w}x{control_image_h}. Try reducing the Final Upscale Factor or using a smaller input image.")
299
- if device == 'cuda': torch.cuda.empty_cache() # Try to clear cache
300
- return default_return
301
  except Exception as e:
302
- # Handle other pipeline errors, including potential torch.compile issues
303
  logging.error(f"Error during pipeline execution: {e}", exc_info=True)
304
- # Check if it looks like the known compile error
305
- if "dynamic shape operator" in str(e) or "Unsupported" in str(e.__class__):
306
- gr.Error(f"Inference failed: torch.compile issue encountered ({type(e).__name__}). The model attempted to run unoptimized. If this persists, compilation might need to be disabled in the code.")
307
- else:
308
- gr.Error(f"Inference failed: {e}")
309
- return default_return
310
-
311
- # --- Post-Pipeline Checks and Resizing ---
312
  if not intermediate_result_image:
313
- # Should ideally not happen if no exception was caught, but check anyway
314
- logging.error("Intermediate result image is None after pipeline execution without exception.")
315
- gr.Error("Inference produced no result image unexpectedly.")
316
- return default_return
317
 
 
318
  # Calculate final target dimensions based on ORIGINAL input size and FINAL upscale factor
 
319
  if was_input_resized:
320
  # Base final size on the downscaled input that was processed
321
  final_target_w = w_proc * final_upscale_factor
@@ -327,54 +279,48 @@ def infer(
327
  final_target_w = w_original * final_upscale_factor
328
  final_target_h = h_original * final_upscale_factor
329
 
330
- # Perform final resize from intermediate to target size
331
  final_result_image = intermediate_result_image
332
  current_w, current_h = intermediate_result_image.size
333
 
 
334
  if (current_w, current_h) != (final_target_w, final_target_h):
335
  logging.info(f"Resizing intermediate image from {current_w}x{current_h} to final target size {final_target_w}x{final_target_h} (using {final_upscale_factor}x factor)")
336
  gr.Info(f"Resizing from intermediate {current_w}x{current_h} to final {final_target_w}x{final_target_h}...")
 
337
  try:
338
- # Ensure target dimensions are valid before resizing
339
  if final_target_w > 0 and final_target_h > 0:
340
- # Use high-quality LANCZOS for downsampling or general resizing
341
  final_result_image = intermediate_result_image.resize((final_target_w, final_target_h), Image.Resampling.LANCZOS)
342
  else:
343
- # Avoid resizing if target dimensions are invalid
344
  gr.Warning(f"Invalid final target dimensions ({final_target_w}x{final_target_h}). Skipping final resize.")
345
  final_result_image = intermediate_result_image # Keep intermediate
346
  except Exception as resize_e:
347
- # Handle potential errors during final resize
348
  logging.error(f"Could not resize intermediate image to final size: {resize_e}")
349
  gr.Warning(f"Failed to resize to final {final_upscale_factor}x. Returning intermediate {INTERNAL_PROCESSING_FACTOR}x result ({current_w}x{current_h}).")
350
  final_result_image = intermediate_result_image # Fallback to intermediate
351
  else:
352
- # No resize needed if intermediate matches final target
353
  logging.info(f"Intermediate size {current_w}x{current_h} matches final target size. No final resize needed.")
354
 
 
355
  logging.info(f"Inference successful. Final output size: {final_result_image.size}")
356
 
357
- # --- Base64 Encoding for API output ---
358
  base64_string = None
359
  if final_result_image:
360
  try:
361
  buffered = io.BytesIO()
362
- # Save as WEBP for potentially smaller size, adjust quality as needed
363
  final_result_image.save(buffered, format="WEBP", quality=90)
364
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
365
- # Format as a data URL
366
  base64_string = f"data:image/webp;base64,{img_str}"
367
  logging.info(f"Encoded result image to Base64 string (length: {len(base64_string)} chars).")
368
  except Exception as enc_err:
369
- # Log if encoding fails but don't stop the process
370
  logging.error(f"Failed to encode result image to Base64: {enc_err}", exc_info=True)
371
- base64_string = None # Ensure it's None if encoding failed
372
 
373
- # --- Return results for Gradio ---
374
  return [[original_input_pil, final_result_image], seed, base64_string]
375
 
376
 
377
- # --- Gradio Interface Definition ---
378
  with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as demo:
379
  gr.Markdown(
380
  f"""
@@ -382,15 +328,11 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as d
382
  Upscale images using the [Flux.1-dev Upscaler ControlNet](https://huggingface.co/jasperai/Flux.1-dev-Controlnet-Upscaler) model based on [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
383
  Currently running on **{power_device}**. Hardware provided by Hugging Face 🤗.
384
 
385
- **How it works:** This demo uses an internal processing scale of **{INTERNAL_PROCESSING_FACTOR}x** for potentially higher detail generation (slower),
386
- then resizes the result to your selected **Final Upscale Factor**.
387
 
388
- **To Speed Up:**
389
- 1. **Reduce `Inference Steps`:** Fewer steps = faster generation (try 10-15).
390
- 2. **(Code Change Needed):** Reduce `INTERNAL_PROCESSING_FACTOR` in the script (e.g., to 3) for direct computation reduction (may lower detail).
391
- 3. `torch.compile` is attempted (`mode="default"`) which might provide speedup after the first run (check logs for success/failure).
392
-
393
- *Limit*: Intermediate processing resolution capped around **{MAX_PIXEL_BUDGET/1_000_000:.1f} megapixels** ({int(MAX_PIXEL_BUDGET**0.5)}x{int(MAX_PIXEL_BUDGET**0.5)}).
394
  """
395
  )
396
 
@@ -403,24 +345,10 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as d
403
  sources=["upload", "clipboard"],
404
  )
405
  with gr.Column(scale=1):
406
- upscale_factor_slider = gr.Slider(
407
- label="Final Upscale Factor",
408
- info=f"Output size relative to input. Internal processing uses {INTERNAL_PROCESSING_FACTOR}x quality.",
409
- minimum=1,
410
- maximum=INTERNAL_PROCESSING_FACTOR, # Max limited by internal factor
411
- step=1,
412
- value=min(2, INTERNAL_PROCESSING_FACTOR) # Default to 2x or internal factor if smaller
413
- )
414
- num_inference_steps = gr.Slider(
415
- label="Inference Steps",
416
- info="Fewer steps = faster. Try 10-15.",
417
- minimum=4, maximum=50, step=1, value=15 # Defaulting to 15 for speed
418
- )
419
- controlnet_conditioning_scale = gr.Slider(
420
- label="ControlNet Conditioning Scale",
421
- info="Strength of ControlNet guidance.",
422
- minimum=0.0, maximum=1.5, step=0.05, value=0.6
423
- )
424
  with gr.Row():
425
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
426
  randomize_seed = gr.Checkbox(label="Random", value=True, scale=0, min_width=80)
@@ -430,30 +358,29 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as d
430
  result_slider = ImageSlider(
431
  label="Input / Output Comparison",
432
  type="pil",
433
- interactive=False, # Output only
434
  show_label=True,
435
- position=0.5 # Start slider in the middle
436
  )
437
 
438
  output_seed = gr.Textbox(label="Seed Used", interactive=False, visible=True, scale=1)
439
- # Hidden output for API usage
440
  api_base64_output = gr.Textbox(label="API Base64 Output", interactive=False, visible=False)
441
 
442
- # --- Examples ---
443
  example_dir = "examples"
444
  example_files = ["image_2.jpg", "image_4.jpg", "low_res_face.png", "low_res_landscape.png"]
445
  example_paths = [os.path.join(example_dir, f) for f in example_files if os.path.exists(os.path.join(example_dir, f))]
446
 
447
  if example_paths:
448
  gr.Examples(
449
- # Examples use the new defaults: final factor 2x (or less), steps 15
450
- examples=[ [path, min(2, INTERNAL_PROCESSING_FACTOR), 15, 0.6, random.randint(0,MAX_SEED), True] for path in example_paths ],
 
451
  inputs=[ input_im, upscale_factor_slider, num_inference_steps, controlnet_conditioning_scale, seed, randomize_seed, ],
452
- # Map outputs to UI components; base64 ignored by Examples UI
453
- outputs=[result_slider, output_seed],
454
- fn=infer, # Function to call when example is clicked
455
- cache_examples="lazy", # Cache results for examples
456
- label=f"Example Images (Click to Run with {min(2, INTERNAL_PROCESSING_FACTOR)}x Output, 15 Steps)",
457
  run_on_click=True
458
  )
459
  else:
@@ -462,7 +389,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as d
462
  gr.Markdown("---")
463
  gr.Markdown("**Disclaimer:** Demo for illustrative purposes. Users are responsible for generated content.")
464
 
465
- # Connect button click to the inference function
466
  run_button.click(
467
  fn=infer,
468
  inputs=[
@@ -470,15 +397,12 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as d
470
  randomize_seed,
471
  input_im,
472
  num_inference_steps,
473
- upscale_factor_slider,
474
  controlnet_conditioning_scale,
475
  ],
476
- # Map all return values from infer to the correct output components
477
  outputs=[result_slider, output_seed, api_base64_output],
478
- api_name="upscale" # Define API endpoint name
479
  )
480
 
481
  # Launch the Gradio app
482
- # queue manages concurrent users/requests
483
- # share=False means no public link generated by default
484
  demo.queue(max_size=10).launch(share=False, show_api=True)
 
6
  import base64
7
  import gradio as gr
8
  import numpy as np
9
+ import spaces
10
  import torch
 
 
11
  from diffusers import FluxControlNetModel
12
  from diffusers.pipelines import FluxControlNetPipeline
13
+ from gradio_imageslider import ImageSlider # Ensure this is installed: pip install gradio_imageslider
14
+ from PIL import Image, ImageOps # Import ImageOps for exif transpose
15
  from huggingface_hub import snapshot_download
16
 
17
  # --- Setup Logging and Device ---
 
19
  warnings.filterwarnings("ignore")
20
 
21
  css = """
22
+ #col-container {
23
+ margin: 0 auto;
24
+ max-width: 512px; /* Increased max-width slightly for better layout */
25
+ }
26
+ .gradio-container {
27
+ max-width: 900px !important; /* Control overall container width */
28
+ margin: auto !important;
29
+ }
30
  """
31
 
32
  if torch.cuda.is_available():
33
  power_device = "GPU"
34
  device = "cuda"
35
+ torch_dtype = torch.bfloat16 # Use bfloat16 for GPU for better performance/memory
36
  else:
37
  power_device = "CPU"
38
  device = "cpu"
39
+ torch_dtype = torch.float32 # Use float32 for CPU
40
+
41
  logging.info(f"Selected device: {device} | Data type: {torch_dtype}")
42
 
43
  # --- Authentication and Model Download ---
44
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
45
+
46
+ # Define model IDs
47
  flux_model_id = "black-forest-labs/FLUX.1-dev"
48
  controlnet_model_id = "jasperai/Flux.1-dev-Controlnet-Upscaler"
49
+ local_model_dir = flux_model_id.split('/')[-1] # e.g., "FLUX.1-dev"
50
  pipe = None
51
 
52
  try:
 
56
  repo_type="model",
57
  ignore_patterns=["*.md", "*.gitattributes"],
58
  local_dir=local_model_dir,
59
+ token=huggingface_token,
60
  )
61
  logging.info(f"Base model downloaded/verified in: {model_path}")
62
 
 
75
  pipe.to(device)
76
  logging.info("Pipeline loaded and moved to device.")
77
 
78
  except Exception as e:
79
+ logging.error(f"FATAL: Error during model loading: {e}", exc_info=True)
80
+ # --- Simplified Error Handling for Brevity ---
81
+ print(f"FATAL ERROR DURING MODEL LOAD: {e}")
82
+ raise SystemExit(f"Model loading failed: {e}")
83
+
 
84
 
85
  # --- Constants ---
86
  MAX_SEED = 2**32 - 1
87
+ MAX_PIXEL_BUDGET = 1280 * 1280
88
+ # --- NEW: Define the internal factor for quality ---
89
+ INTERNAL_PROCESSING_FACTOR = 4
90
 
91
+ # --- Image Processing Function (Modified) ---
92
  def process_input(input_image):
93
+ """Processes the input image for the pipeline.
94
+ The pixel budget check uses the fixed INTERNAL_PROCESSING_FACTOR."""
95
  if input_image is None:
96
  raise gr.Error("Input image is missing!")
97
  try:
 
98
  input_image = ImageOps.exif_transpose(input_image)
 
99
  if input_image.mode != 'RGB':
100
  logging.info(f"Converting input image from {input_image.mode} to RGB")
101
  input_image = input_image.convert('RGB')
102
  w, h = input_image.size
103
  except AttributeError:
 
104
  raise gr.Error("Invalid input image format. Please provide a valid image file.")
105
  except Exception as img_err:
 
106
  raise gr.Error(f"Could not process input image: {img_err}")
107
 
108
  w_original, h_original = w, h
109
  if w == 0 or h == 0:
110
  raise gr.Error("Input image has zero width or height.")
111
 
112
+ # Calculate target based on INTERNAL factor for budget check
113
  target_w_internal = w * INTERNAL_PROCESSING_FACTOR
114
  target_h_internal = h * INTERNAL_PROCESSING_FACTOR
115
  target_pixels_internal = target_w_internal * target_h_internal
116
 
117
  was_resized = False
118
+ input_image_to_process = input_image.copy()
119
 
120
  # Check if the *intermediate* size exceeds the budget
121
  if target_pixels_internal > MAX_PIXEL_BUDGET:
 
122
  max_input_pixels = MAX_PIXEL_BUDGET / (INTERNAL_PROCESSING_FACTOR**2)
123
  current_input_pixels = w * h
124
 
125
  if current_input_pixels > max_input_pixels:
 
126
  input_scale_factor = (max_input_pixels / current_input_pixels) ** 0.5
127
+ input_w_resized = int(w * input_scale_factor)
128
+ input_h_resized = int(h * input_scale_factor)
129
+ # Ensure minimum size of 8x8
130
+ input_w_resized = max(8, input_w_resized)
131
+ input_h_resized = max(8, input_h_resized)
132
  intermediate_w = input_w_resized * INTERNAL_PROCESSING_FACTOR
133
  intermediate_h = input_h_resized * INTERNAL_PROCESSING_FACTOR
134
 
 
141
  f"-> model generates ~{int(intermediate_w)}x{int(intermediate_h)}."
142
  )
143
  input_image_to_process = input_image_to_process.resize((input_w_resized, input_h_resized), Image.Resampling.LANCZOS)
144
+ was_resized = True # Flag that original dimensions are lost for precise final scaling
145
 
146
+ # Round processed input dimensions to be multiple of 8
147
  w_proc, h_proc = input_image_to_process.size
148
+ w_final_proc = max(8, w_proc - w_proc % 8)
149
+ h_final_proc = max(8, h_proc - h_proc % 8)
150
 
151
  if (w_proc, h_proc) != (w_final_proc, h_final_proc):
152
  logging.info(f"Rounding processed input dimensions from {w_proc}x{h_proc} to {w_final_proc}x{h_final_proc}")
 
154
 
155
  return input_image_to_process, w_original, h_original, was_resized
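The rounding above (down to a multiple of 8, floored at 8) matches the requirement the earlier comment attributes to some models. As a tiny illustrative helper (the function name is not from the commit):

```python
def round_down_to_multiple_of_8(x: int) -> int:
    """Round down to a multiple of 8, never below 8 (illustrative helper)."""
    return max(8, x - x % 8)

assert round_down_to_multiple_of_8(1023) == 1016
assert round_down_to_multiple_of_8(517) == 512
assert round_down_to_multiple_of_8(5) == 8
```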
156
 
157
+ # --- Inference Function (Modified) ---
 
 
158
  @spaces.GPU(duration=75)
159
  def infer(
160
  seed,
161
  randomize_seed,
162
  input_image,
163
  num_inference_steps,
164
+ final_upscale_factor, # Renamed for clarity internally
165
  controlnet_conditioning_scale,
166
+ progress=gr.Progress(track_tqdm=True),
167
  ):
 
168
  global pipe
169
  if pipe is None:
170
  gr.Error("Pipeline not loaded. Cannot perform inference.")
171
+ return [[None, None], 0, None]
172
 
173
+ original_input_pil = input_image # Keep ref for slider
174
 
 
175
  if input_image is None:
176
  gr.Warning("Please provide an input image.")
177
+ return [[None, None], seed or 0, None]
178
 
 
179
  if randomize_seed:
180
+ seed = random.randint(0, MAX_SEED)
181
+ seed = int(seed)
182
 
183
+ # Ensure final_upscale_factor is an integer
184
  final_upscale_factor = int(final_upscale_factor)
185
  if final_upscale_factor > INTERNAL_PROCESSING_FACTOR:
186
+ gr.Warning(f"Selected upscale factor ({final_upscale_factor}x) is larger than internal processing factor ({INTERNAL_PROCESSING_FACTOR}x). "
187
+ f"Results might not be optimal. Clamping final factor to {INTERNAL_PROCESSING_FACTOR}x for this run.")
188
+ final_upscale_factor = INTERNAL_PROCESSING_FACTOR # Prevent upscaling *beyond* internal processing
189
 
190
  logging.info(
191
+ f"Starting inference with seed: {seed}, "
192
+ f"Internal Processing Factor: {INTERNAL_PROCESSING_FACTOR}x, "
193
+ f"Final Output Factor: {final_upscale_factor}x, "
194
+ f"Steps: {num_inference_steps}, CNet Scale: {controlnet_conditioning_scale}"
195
  )
196
 
 
197
  try:
198
+ # process_input now implicitly uses INTERNAL_PROCESSING_FACTOR for budget checks
199
  processed_input_image, w_original, h_original, was_input_resized = process_input(
200
  input_image
201
  )
202
  except Exception as e:
203
  logging.error(f"Error processing input image: {e}", exc_info=True)
204
  gr.Error(f"Error processing input image: {e}")
205
+ return [[original_input_pil, None], seed, None]
206
 
 
207
  w_proc, h_proc = processed_input_image.size
208
+
209
+ # Calculate control image dimensions using INTERNAL_PROCESSING_FACTOR
210
  control_image_w = w_proc * INTERNAL_PROCESSING_FACTOR
211
  control_image_h = h_proc * INTERNAL_PROCESSING_FACTOR
212
 
213
+ # Clamp control image size if it *still* exceeds budget (e.g., due to rounding or small inputs)
214
+ # This check should technically be redundant if process_input worked correctly, but good failsafe.
215
+ if control_image_w * control_image_h > MAX_PIXEL_BUDGET * 1.05: # Add a small margin
216
  scale_factor = (MAX_PIXEL_BUDGET / (control_image_w * control_image_h)) ** 0.5
217
  control_image_w = max(8, int(control_image_w * scale_factor))
218
  control_image_h = max(8, int(control_image_h * scale_factor))
 
219
  control_image_w = max(8, control_image_w - control_image_w % 8)
220
  control_image_h = max(8, control_image_h - control_image_h % 8)
221
+ logging.warning(f"Control image dimensions clamped to {control_image_w}x{control_image_h} post-processing to fit budget.")
222
  gr.Warning(f"Control image dimensions further clamped to {control_image_w}x{control_image_h}.")
223
 
224
+ logging.info(f"Resizing processed input {w_proc}x{h_proc} to control image {control_image_w}x{control_image_h} (using {INTERNAL_PROCESSING_FACTOR}x factor)")
 
225
  try:
226
+ # Use the processed input image for control, resized to the intermediate size
227
  control_image = processed_input_image.resize((control_image_w, control_image_h), Image.Resampling.LANCZOS)
228
  except ValueError as resize_err:
229
  logging.error(f"Error resizing processed input to control image: {resize_err}")
230
  gr.Error(f"Failed to prepare control image: {resize_err}")
231
+ return [[original_input_pil, None], seed, None]
232
 
 
233
  generator = torch.Generator(device=device).manual_seed(seed)
234
 
235
+ # --- Run the Pipeline at INTERNAL_PROCESSING_FACTOR scale ---
236
+ gr.Info(f"Generating intermediate image at {INTERNAL_PROCESSING_FACTOR}x quality ({control_image_w}x{control_image_h})...")
237
+ logging.info(f"Running pipeline with size: {control_image_w}x{control_image_h}")
238
+ intermediate_result_image = None
239
  try:
240
  with torch.inference_mode():
 
241
  intermediate_result_image = pipe(
242
+ prompt="",
243
+ control_image=control_image, # Control image IS the intermediate size
244
  controlnet_conditioning_scale=float(controlnet_conditioning_scale),
245
+ num_inference_steps=int(num_inference_steps),
246
+ guidance_scale=0.0,
247
  height=control_image_h, # Target height for the model
248
  width=control_image_w, # Target width for the model
249
  generator=generator,
 
 
250
  ).images[0]
251
  logging.info(f"Pipeline execution finished. Intermediate image size: {intermediate_result_image.size if intermediate_result_image else 'None'}")
252
 
253
  except torch.cuda.OutOfMemoryError as oom_error:
 
254
  logging.error(f"CUDA Out of Memory during pipeline execution: {oom_error}", exc_info=True)
255
+ gr.Error(f"Ran out of GPU memory trying to generate intermediate {control_image_w}x{control_image_h}.")
256
+ if device == 'cuda': torch.cuda.empty_cache()
257
+ return [[original_input_pil, None], seed, None]
258
  except Exception as e:
 
259
  logging.error(f"Error during pipeline execution: {e}", exc_info=True)
260
+ gr.Error(f"Inference failed: {e}")
261
+ return [[original_input_pil, None], seed, None]
262
+
263
  if not intermediate_result_image:
264
+ logging.error("Intermediate result image is None after pipeline execution.")
265
+ gr.Error("Inference produced no result image.")
266
+ return [[original_input_pil, None], seed, None]
 
267
 
268
+ # --- Final Resizing to User's Desired Scale ---
269
  # Calculate final target dimensions based on ORIGINAL input size and FINAL upscale factor
270
+ # If input was resized, we scale the *processed* input size instead, as original is unknown
271
  if was_input_resized:
272
  # Base final size on the downscaled input that was processed
273
  final_target_w = w_proc * final_upscale_factor
 
279
  final_target_w = w_original * final_upscale_factor
280
  final_target_h = h_original * final_upscale_factor
281
 
 
282
  final_result_image = intermediate_result_image
283
  current_w, current_h = intermediate_result_image.size
284
 
285
+ # Only resize if the intermediate size doesn't match the final desired size
286
  if (current_w, current_h) != (final_target_w, final_target_h):
287
  logging.info(f"Resizing intermediate image from {current_w}x{current_h} to final target size {final_target_w}x{final_target_h} (using {final_upscale_factor}x factor)")
288
  gr.Info(f"Resizing from intermediate {current_w}x{current_h} to final {final_target_w}x{final_target_h}...")
289
+
290
  try:
 
291
  if final_target_w > 0 and final_target_h > 0:
292
+ # Use LANCZOS for downsampling; it's high quality
293
  final_result_image = intermediate_result_image.resize((final_target_w, final_target_h), Image.Resampling.LANCZOS)
294
  else:
 
295
  gr.Warning(f"Invalid final target dimensions ({final_target_w}x{final_target_h}). Skipping final resize.")
296
  final_result_image = intermediate_result_image # Keep intermediate
297
  except Exception as resize_e:
 
298
  logging.error(f"Could not resize intermediate image to final size: {resize_e}")
299
  gr.Warning(f"Failed to resize to final {final_upscale_factor}x. Returning intermediate {INTERNAL_PROCESSING_FACTOR}x result ({current_w}x{current_h}).")
300
  final_result_image = intermediate_result_image # Fallback to intermediate
301
  else:
 
302
  logging.info(f"Intermediate size {current_w}x{current_h} matches final target size. No final resize needed.")
303
 
304
+
305
  logging.info(f"Inference successful. Final output size: {final_result_image.size}")
306
 
307
+ # --- Base64 Encoding (No change needed here) ---
308
  base64_string = None
309
  if final_result_image:
310
  try:
311
  buffered = io.BytesIO()
 
312
  final_result_image.save(buffered, format="WEBP", quality=90)
313
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
 
314
  base64_string = f"data:image/webp;base64,{img_str}"
315
  logging.info(f"Encoded result image to Base64 string (length: {len(base64_string)} chars).")
316
  except Exception as enc_err:
 
317
  logging.error(f"Failed to encode result image to Base64: {enc_err}", exc_info=True)
 
318
 
319
+ # Return original input and the FINAL processed image
320
  return [[original_input_pil, final_result_image], seed, base64_string]
321
 
322
 
323
+ # --- Gradio Interface (Minor Text Updates) ---
324
  with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as demo:
325
  gr.Markdown(
326
  f"""
 
328
  Upscale images using the [Flux.1-dev Upscaler ControlNet](https://huggingface.co/jasperai/Flux.1-dev-Controlnet-Upscaler) model based on [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
329
  Currently running on **{power_device}**. Hardware provided by Hugging Face 🤗.
330
 
331
+ **How it works:** This demo uses an internal processing scale of **{INTERNAL_PROCESSING_FACTOR}x** for potentially higher detail generation,
332
+ then resizes the result to your selected **Final Upscale Factor**. This aims for {INTERNAL_PROCESSING_FACTOR}x quality at your desired output resolution.
333
 
334
+ *Note*: Intermediate processing resolution is limited to approximately **{MAX_PIXEL_BUDGET/1_000_000:.1f} megapixels** ({int(MAX_PIXEL_BUDGET**0.5)}x{int(MAX_PIXEL_BUDGET**0.5)}) due to resource constraints.
335
+ The *diffusion process time* is mainly determined by this intermediate size, not the final output size.
336
  """
337
  )
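Concretely, for a 256x256 input with the final factor set to 2x, the note above plays out as follows (illustrative numbers, not committed code):

```python
# Two-stage sizing described in the markdown above.
w, h = 256, 256                  # input
INTERNAL_PROCESSING_FACTOR = 4   # fixed quality factor
final_factor = 2                 # user-selected output factor

intermediate = (w * INTERNAL_PROCESSING_FACTOR, h * INTERNAL_PROCESSING_FACTOR)  # (1024, 1024)
final = (w * final_factor, h * final_factor)                                     # (512, 512)
# Diffusion cost is paid at `intermediate`; a LANCZOS resize then yields `final`.
```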
338
 
 
345
  sources=["upload", "clipboard"],
346
  )
347
  with gr.Column(scale=1):
348
+ # Renamed slider label for clarity
349
+ upscale_factor_slider = gr.Slider(label="Final Upscale Factor", info=f"Output size relative to input. Internal processing uses {INTERNAL_PROCESSING_FACTOR}x quality.", minimum=1, maximum=INTERNAL_PROCESSING_FACTOR, step=1, value=2) # Default to 2x, max is now INTERNAL_PROCESSING_FACTOR
350
+ num_inference_steps = gr.Slider(label="Inference Steps", minimum=4, maximum=50, step=1, value=15)
351
+ controlnet_conditioning_scale = gr.Slider(label="ControlNet Conditioning Scale", info="Strength of ControlNet guidance", minimum=0.0, maximum=1.5, step=0.05, value=0.6)
352
  with gr.Row():
353
  seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
354
  randomize_seed = gr.Checkbox(label="Random", value=True, scale=0, min_width=80)
 
358
  result_slider = ImageSlider(
359
  label="Input / Output Comparison",
360
  type="pil",
361
+ interactive=False,
362
  show_label=True,
363
+ position=0.5
364
  )
365
 
366
  output_seed = gr.Textbox(label="Seed Used", interactive=False, visible=True, scale=1)
 
367
  api_base64_output = gr.Textbox(label="API Base64 Output", interactive=False, visible=False)
368
 
369
+ # --- Examples (Updated default factor if needed) ---
370
  example_dir = "examples"
371
  example_files = ["image_2.jpg", "image_4.jpg", "low_res_face.png", "low_res_landscape.png"]
372
  example_paths = [os.path.join(example_dir, f) for f in example_files if os.path.exists(os.path.join(example_dir, f))]
373
 
374
  if example_paths:
375
  gr.Examples(
376
+ # Examples now use the new default of 2x for the final factor
377
+ examples=[ [path, 2, 15, 0.6, random.randint(0,MAX_SEED), True] for path in example_paths ],
378
+ # Ensure inputs match the order expected by `infer` now
379
  inputs=[ input_im, upscale_factor_slider, num_inference_steps, controlnet_conditioning_scale, seed, randomize_seed, ],
380
+ outputs=[result_slider, output_seed], # Base64 output ignored by Examples
381
+ fn=infer,
382
+ cache_examples="lazy",
383
+ label="Example Images (Click to Run with 2x Output)",
 
384
  run_on_click=True
385
  )
386
  else:
 
389
  gr.Markdown("---")
390
  gr.Markdown("**Disclaimer:** Demo for illustrative purposes. Users are responsible for generated content.")
391
 
392
+ # Connect button click
393
  run_button.click(
394
  fn=infer,
395
  inputs=[
 
397
  randomize_seed,
398
  input_im,
399
  num_inference_steps,
400
+ upscale_factor_slider, # Use the slider value here
401
  controlnet_conditioning_scale,
402
  ],
 
403
  outputs=[result_slider, output_seed, api_base64_output],
404
+ api_name="upscale"
405
  )
406
 
407
  # Launch the Gradio app
 
 
408
  demo.queue(max_size=10).launch(share=False, show_api=True)
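With `show_api=True` and `api_name="upscale"`, the endpoint can be driven from Python via `gradio_client`. A hedged sketch: the Space id is a placeholder, arguments follow `infer`'s parameter order, and `handle_file` assumes a recent client version (older clients pass a plain file path):

```python
# Sketch of calling the /upscale endpoint; "user/flux-upscaler-demo" is a placeholder.
from gradio_client import Client, handle_file

client = Client("user/flux-upscaler-demo")
images, seed_used, b64 = client.predict(
    42,                          # seed
    True,                        # randomize_seed
    handle_file("low_res.png"),  # input_image
    15,                          # num_inference_steps
    2,                           # final_upscale_factor
    0.6,                         # controlnet_conditioning_scale
    api_name="/upscale",
)
# `images` mirrors the ImageSlider pair; exact return types depend on component versions.
```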