Gemini899 committed on
Commit f6b6080 · verified · 1 Parent(s): 06363d9

Update app.py

Files changed (1)
  1. app.py +141 -281
app.py CHANGED
@@ -1,324 +1,253 @@
 
import logging
import random
import warnings
import os
- import io # Ensure io is imported
- import base64 # Ensure base64 is imported
import gradio as gr
import numpy as np
import spaces
import torch

from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
- from gradio_imageslider import ImageSlider # Ensure this is installed: pip install gradio_imageslider
- from PIL import Image, ImageOps # Import ImageOps for exif transpose
from huggingface_hub import snapshot_download

# --- Setup Logging and Device ---
logging.basicConfig(level=logging.INFO)
warnings.filterwarnings("ignore")

css = """
- #col-container {
- margin: 0 auto;
- max-width: 512px;
- }
- .gradio-container {
- max-width: 900px !important;
- margin: auto !important;
- }
"""

if torch.cuda.is_available():
- power_device = "GPU"
- device = "cuda"
- torch_dtype = torch.bfloat16 # Use bfloat16 for GPU for better performance/memory
else:
- power_device = "CPU"
- device = "cpu"
- torch_dtype = torch.float32 # Use float32 for CPU
-
logging.info(f"Selected device: {device} | Data type: {torch_dtype}")

# --- Authentication and Model Download ---
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
-
- # Define model IDs
flux_model_id = "black-forest-labs/FLUX.1-dev"
controlnet_model_id = "jasperai/Flux.1-dev-Controlnet-Upscaler"
- local_model_dir = flux_model_id.split('/')[-1] # e.g., "FLUX.1-dev"
pipe = None

try:
logging.info(f"Downloading base model: {flux_model_id}")
- model_path = snapshot_download(
- repo_id=flux_model_id,
- repo_type="model",
- ignore_patterns=["*.md", "*.gitattributes"],
- local_dir=local_model_dir,
- token=huggingface_token,
- )
logging.info(f"Base model downloaded/verified in: {model_path}")

logging.info(f"Loading ControlNet model: {controlnet_model_id}")
- controlnet = FluxControlNetModel.from_pretrained(
- controlnet_model_id, torch_dtype=torch_dtype
- ).to(device)
logging.info("ControlNet model loaded.")

logging.info("Loading FluxControlNetPipeline...")
- pipe = FluxControlNetPipeline.from_pretrained(
- model_path,
- controlnet=controlnet,
- torch_dtype=torch_dtype
- )
pipe.to(device)
logging.info("Pipeline loaded and moved to device.")

# --- OPTIMIZATION: Attempt torch.compile (PyTorch 2.0+) ---
if device == "cuda" and hasattr(torch, "compile"):
- logging.info("Attempting to compile the pipeline transformer with torch.compile...")
try:
- # Modes: 'default', 'reduce-overhead', 'max-autotune'
- # 'max-autotune' might give best runtime performance but takes longer to compile initially
- pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
- # You could potentially compile other components too, but start with the transformer
- # pipe.controlnet = torch.compile(pipe.controlnet, mode="max-autotune", fullgraph=True)
logging.info("Pipeline transformer compiled successfully.")
- # Optional: Add a dummy inference run here to trigger compilation during startup
- # This avoids the compilation delay on the *first* user request.
- # try:
- # logging.info("Running dummy inference to finalize compilation...")
- # _ = pipe(prompt="", control_image=Image.new('RGB', (64, 64)), height=64*4, width=64*4, num_inference_steps=1, guidance_scale=0.0, output_type="latent")
- # logging.info("Dummy inference completed.")
- # except Exception as compile_run_e:
- # logging.warning(f"Dummy inference after compile failed (might be ok): {compile_run_e}")
except Exception as e:
- logging.warning(f"torch.compile failed: {e}. Running unoptimized.")
else:
logging.info("torch.compile not available or not on CUDA, skipping compilation.")

- # --- OPTIMIZATION: Consider xformers ---
- # If torch.compile doesn't provide enough speedup or isn't available,
- # you can try installing and enabling xformers.
- # 1. Add `xformers` to your requirements.txt or install it (`pip install xformers`).
- # 2. Uncomment and add this code block *before* the torch.compile block:
- # try:
- # import xformers
- # pipe.enable_xformers_memory_efficient_attention()
- # logging.info("Enabled xformers memory efficient attention.")
- # except ImportError:
- # logging.info("xformers not installed. Skipping.")
- # except Exception as e:
- # logging.warning(f"Could not enable xformers: {e}.")

logging.info("Pipeline ready for inference.")

-
except Exception as e:
logging.error(f"FATAL: Error during model loading or setup: {e}", exc_info=True)
- # Simple error display if Gradio Blocks object isn't ready
print(f"FATAL ERROR DURING MODEL LOAD/SETUP: {e}")
- # You might want to use the Gradio error block structure here if `gr` is available
- # with gr.Blocks() as demo_error: ... etc.
raise SystemExit(f"Model loading/setup failed: {e}")

# --- Constants ---
MAX_SEED = 2**32 - 1
- MAX_PIXEL_BUDGET = 1280 * 1280 # Max pixels for the *intermediate* high-res image
-
- # --- SPEED VS QUALITY ---
- # INTERNAL_PROCESSING_FACTOR: Determines the scale the diffusion model *targets*
- # Higher values (like 4) aim for more detail generation but are slower.
- # Lower values (like 3 or 2) will be faster but might produce less detail enhancement.
- # You were aiming for 4x quality, so we keep it at 4. Reducing this is a direct speedup trade-off.
INTERNAL_PROCESSING_FACTOR = 4

- # --- Image Processing Function (Uses INTERNAL_PROCESSING_FACTOR for budgeting) ---
def process_input(input_image):
- """Processes the input image for the pipeline.
- The pixel budget check uses the fixed INTERNAL_PROCESSING_FACTOR."""
- if input_image is None:
- raise gr.Error("Input image is missing!")
try:
input_image = ImageOps.exif_transpose(input_image)
- if input_image.mode != 'RGB':
- logging.info(f"Converting input image from {input_image.mode} to RGB")
- input_image = input_image.convert('RGB')
w, h = input_image.size
- except AttributeError:
- raise gr.Error("Invalid input image format. Please provide a valid image file.")
- except Exception as img_err:
- raise gr.Error(f"Could not process input image: {img_err}")

w_original, h_original = w, h
- if w == 0 or h == 0:
- raise gr.Error("Input image has zero width or height.")

- # Calculate target based on INTERNAL factor for budget check
target_w_internal = w * INTERNAL_PROCESSING_FACTOR
target_h_internal = h * INTERNAL_PROCESSING_FACTOR
target_pixels_internal = target_w_internal * target_h_internal
-
was_resized = False
input_image_to_process = input_image.copy()

- # Check if the *intermediate* size exceeds the budget
if target_pixels_internal > MAX_PIXEL_BUDGET:
max_input_pixels = MAX_PIXEL_BUDGET / (INTERNAL_PROCESSING_FACTOR**2)
current_input_pixels = w * h
-
if current_input_pixels > max_input_pixels:
input_scale_factor = (max_input_pixels / current_input_pixels) ** 0.5
- input_w_resized = int(w * input_scale_factor)
- input_h_resized = int(h * input_scale_factor)
- input_w_resized = max(8, input_w_resized) # Ensure min size
- input_h_resized = max(8, input_h_resized) # Ensure min size
intermediate_w = input_w_resized * INTERNAL_PROCESSING_FACTOR
intermediate_h = input_h_resized * INTERNAL_PROCESSING_FACTOR
-
- logging.warning(
- f"Requested {INTERNAL_PROCESSING_FACTOR}x intermediate output ({target_w_internal}x{target_h_internal}) exceeds budget. "
- f"Resizing input from {w}x{h} to {input_w_resized}x{input_h_resized}."
- )
- gr.Info(
- f"Intermediate {INTERNAL_PROCESSING_FACTOR}x size exceeds budget. Input resized to {input_w_resized}x{input_h_resized} "
- f"-> model generates ~{int(intermediate_w)}x{int(intermediate_h)}."
- )
input_image_to_process = input_image_to_process.resize((input_w_resized, input_h_resized), Image.Resampling.LANCZOS)
- was_resized = True # Flag that original dimensions are lost for precise final scaling

- # Round processed input dimensions to be multiple of 8
w_proc, h_proc = input_image_to_process.size
w_final_proc = max(8, w_proc - w_proc % 8)
h_final_proc = max(8, h_proc - h_proc % 8)
-
if (w_proc, h_proc) != (w_final_proc, h_final_proc):
- logging.info(f"Rounding processed input dimensions from {w_proc}x{h_proc} to {w_final_proc}x{h_final_proc}")
input_image_to_process = input_image_to_process.resize((w_final_proc, h_final_proc), Image.Resampling.LANCZOS)

return input_image_to_process, w_original, h_original, was_resized

- # --- Inference Function (Runs pipeline at Internal Factor, resizes to Final Factor) ---
- @spaces.GPU(duration=70) # Keep GPU decorator
def infer(
- seed,
- randomize_seed,
- input_image,
- num_inference_steps, # Reducing this is a direct way to speed up (quality trade-off)
- final_upscale_factor, # User's desired final output scale
- controlnet_conditioning_scale,
progress=gr.Progress(track_tqdm=True),
):
global pipe
if pipe is None:
gr.Error("Pipeline not loaded. Cannot perform inference.")
- return [[None, None], 0, None]

- original_input_pil = input_image

if input_image is None:
gr.Warning("Please provide an input image.")
- return [[None, None], seed or 0, None]

- if randomize_seed:
- seed = random.randint(0, MAX_SEED)
seed = int(seed)

- # Ensure factors are integers
final_upscale_factor = int(final_upscale_factor)
- num_inference_steps = int(num_inference_steps) # Ensure steps are int

- # Sanity check: final factor shouldn't exceed internal processing factor in this workflow
if final_upscale_factor > INTERNAL_PROCESSING_FACTOR:
- gr.Warning(f"Selected final upscale factor ({final_upscale_factor}x) is larger than internal processing factor ({INTERNAL_PROCESSING_FACTOR}x). "
- f"Clamping final factor to {INTERNAL_PROCESSING_FACTOR}x.")
final_upscale_factor = INTERNAL_PROCESSING_FACTOR

- logging.info(
- f"Starting inference with seed: {seed}, "
- f"Internal Processing Factor: {INTERNAL_PROCESSING_FACTOR}x, "
- f"Final Output Factor: {final_upscale_factor}x, "
- f"Steps: {num_inference_steps}, CNet Scale: {controlnet_conditioning_scale}"
- )

try:
- processed_input_image, w_original, h_original, was_input_resized = process_input(
- input_image
- )
except Exception as e:
logging.error(f"Error processing input image: {e}", exc_info=True)
gr.Error(f"Error processing input image: {e}")
- return [[original_input_pil, None], seed, None]

w_proc, h_proc = processed_input_image.size
-
- # Calculate control image dimensions using INTERNAL_PROCESSING_FACTOR
control_image_w = w_proc * INTERNAL_PROCESSING_FACTOR
control_image_h = h_proc * INTERNAL_PROCESSING_FACTOR

- # Failsafe clamp if budget is still exceeded (should be rare if process_input works)
- if control_image_w * control_image_h > MAX_PIXEL_BUDGET * 1.05: # Small margin
scale_factor = (MAX_PIXEL_BUDGET / (control_image_w * control_image_h)) ** 0.5
- control_image_w = max(8, int(control_image_w * scale_factor))
- control_image_h = max(8, int(control_image_h * scale_factor))
- control_image_w = max(8, control_image_w - control_image_w % 8)
- control_image_h = max(8, control_image_h - control_image_h % 8)
- logging.warning(f"Control image dimensions clamped post-processing to {control_image_w}x{control_image_h} to fit budget.")
gr.Warning(f"Control image dimensions further clamped to {control_image_w}x{control_image_h}.")

- logging.info(f"Resizing processed input {w_proc}x{h_proc} to control image {control_image_w}x{control_image_h} (using {INTERNAL_PROCESSING_FACTOR}x factor)")
try:
control_image = processed_input_image.resize((control_image_w, control_image_h), Image.Resampling.LANCZOS)
except ValueError as resize_err:
- logging.error(f"Error resizing processed input to control image: {resize_err}")
gr.Error(f"Failed to prepare control image: {resize_err}")
- return [[original_input_pil, None], seed, None]

generator = torch.Generator(device=device).manual_seed(seed)

- # --- Run the Pipeline at INTERNAL_PROCESSING_FACTOR scale ---
- gr.Info(f"Generating intermediate image at {INTERNAL_PROCESSING_FACTOR}x quality ({control_image_w}x{control_image_h}) with {num_inference_steps} steps...")
- logging.info(f"Running pipeline with size: {control_image_w}x{control_image_h}, steps: {num_inference_steps}")
- intermediate_result_image = None
try:
with torch.inference_mode():
intermediate_result_image = pipe(
- prompt="",
- control_image=control_image,
controlnet_conditioning_scale=float(controlnet_conditioning_scale),
- num_inference_steps=num_inference_steps, # Use the integer value
- guidance_scale=0.0,
- height=control_image_h,
- width=control_image_w,
- generator=generator,
- # Add progress callback if desired, requires tqdm
- # callback_on_step_end = ...
).images[0]
- logging.info(f"Pipeline execution finished. Intermediate image size: {intermediate_result_image.size if intermediate_result_image else 'None'}")

except torch.cuda.OutOfMemoryError as oom_error:
- logging.error(f"CUDA Out of Memory during pipeline execution: {oom_error}", exc_info=True)
- gr.Error(f"Ran out of GPU memory trying to generate intermediate {control_image_w}x{control_image_h}. Try reducing the Final Upscale Factor or using a smaller input image.")
if device == 'cuda': torch.cuda.empty_cache()
- return [[original_input_pil, None], seed, None]
- except Exception as e:
- logging.error(f"Error during pipeline execution: {e}", exc_info=True)
- gr.Error(f"Inference failed: {e}")
- return [[original_input_pil, None], seed, None]
-
if not intermediate_result_image:
- logging.error("Intermediate result image is None after pipeline execution.")
- gr.Error("Inference produced no result image.")
- return [[original_input_pil, None], seed, None]

- # --- Final Resizing to User's Desired Scale ---
if was_input_resized:
final_target_w = w_proc * final_upscale_factor
final_target_h = h_proc * final_upscale_factor
- logging.warning(f"Input was downscaled. Final size based on processed input: {w_proc}x{h_proc} * {final_upscale_factor}x -> {final_target_w}x{final_target_h}")
- gr.Info(f"Input was downscaled. Final size target approx {final_target_w}x{final_target_h}.")
else:
final_target_w = w_original * final_upscale_factor
final_target_h = h_original * final_upscale_factor

@@ -327,143 +256,74 @@ def infer(
current_w, current_h = intermediate_result_image.size

if (current_w, current_h) != (final_target_w, final_target_h):
- logging.info(f"Resizing intermediate image from {current_w}x{current_h} to final target size {final_target_w}x{final_target_h} (using {final_upscale_factor}x factor)")
gr.Info(f"Resizing from intermediate {current_w}x{current_h} to final {final_target_w}x{final_target_h}...")
try:
if final_target_w > 0 and final_target_h > 0:
final_result_image = intermediate_result_image.resize((final_target_w, final_target_h), Image.Resampling.LANCZOS)
else:
- gr.Warning(f"Invalid final target dimensions ({final_target_w}x{final_target_h}). Skipping final resize.")
final_result_image = intermediate_result_image
except Exception as resize_e:
- logging.error(f"Could not resize intermediate image to final size: {resize_e}")
- gr.Warning(f"Failed to resize to final {final_upscale_factor}x. Returning intermediate {INTERNAL_PROCESSING_FACTOR}x result ({current_w}x{current_h}).")
final_result_image = intermediate_result_image
else:
- logging.info(f"Intermediate size {current_w}x{current_h} matches final target size. No final resize needed.")

logging.info(f"Inference successful. Final output size: {final_result_image.size}")

- # --- Base64 Encoding ---
base64_string = None
if final_result_image:
try:
- buffered = io.BytesIO()
- final_result_image.save(buffered, format="WEBP", quality=90) # WEBP is usually smaller
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
base64_string = f"data:image/webp;base64,{img_str}"
- logging.info(f"Encoded result image to Base64 string (length: {len(base64_string)} chars).")
except Exception as enc_err:
- logging.error(f"Failed to encode result image to Base64: {enc_err}", exc_info=True)

return [[original_input_pil, final_result_image], seed, base64_string]

-
- # --- Gradio Interface ---
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as demo:
- gr.Markdown(
- f"""
# ⚡ Flux.1-dev Upscaler ControlNet ⚡
- Upscale images using the [Flux.1-dev Upscaler ControlNet](https://huggingface.co/jasperai/Flux.1-dev-Controlnet-Upscaler) model based on [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
- Currently running on **{power_device}**. Hardware provided by Hugging Face 🤗.
-
- **How it works:** This demo uses an internal processing scale of **{INTERNAL_PROCESSING_FACTOR}x** for potentially higher detail generation (slower),
- then resizes the result to your selected **Final Upscale Factor**. This aims for {INTERNAL_PROCESSING_FACTOR}x quality at your desired output resolution.
-
- **To Speed Up:**
- 1. **Reduce `Inference Steps`:** Fewer steps = faster generation (potential quality decrease). Try 10-15 instead of 25.
- 2. **(Code Change Needed):** Reduce `INTERNAL_PROCESSING_FACTOR` in the script (e.g., to 3). This directly reduces computation but may lower detail enhancement.
- 3. `torch.compile` has been enabled (if using PyTorch 2.0+ on GPU) which should provide some speedup after the first run.
-
- *Note*: Intermediate processing resolution is limited to approximately **{MAX_PIXEL_BUDGET/1_000_000:.1f} megapixels** ({int(MAX_PIXEL_BUDGET**0.5)}x{int(MAX_PIXEL_BUDGET**0.5)}) due to resource constraints.
- The *diffusion process time* is mainly determined by this intermediate size and the number of steps.
- """
- )
-
with gr.Row():
- with gr.Column(scale=2):
- input_im = gr.Image(
- label="Input Image",
- type="pil",
- height=350,
- sources=["upload", "clipboard"],
- )
with gr.Column(scale=1):
- # Renamed slider label for clarity
- upscale_factor_slider = gr.Slider(
- label="Final Upscale Factor",
- info=f"Output size relative to input. Internal processing uses {INTERNAL_PROCESSING_FACTOR}x quality.",
- minimum=1,
- maximum=INTERNAL_PROCESSING_FACTOR, # Max limited by internal factor
- step=1,
- value=min(2, INTERNAL_PROCESSING_FACTOR) # Default to 2x or internal factor if smaller
- )
- # --- SPEED OPTIMIZATION: Reduce default steps ---
- num_inference_steps = gr.Slider(
- label="Inference Steps",
- info="Fewer steps = faster, more steps = potentially higher quality. Try 10-15 for speed.",
- minimum=4, maximum=50, step=1, value=15 # Defaulting to 15 instead of 25
- )
- controlnet_conditioning_scale = gr.Slider(
- label="ControlNet Conditioning Scale",
- info="Strength of ControlNet guidance.",
- minimum=0.0, maximum=1.5, step=0.05, value=0.6
- )
with gr.Row():
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
randomize_seed = gr.Checkbox(label="Random", value=True, scale=0, min_width=80)
run_button = gr.Button("⚡ Upscale Image", variant="primary", scale=1)
-
- with gr.Row():
- result_slider = ImageSlider(
- label="Input / Output Comparison",
- type="pil",
- interactive=False,
- show_label=True,
- position=0.5
- )
-
output_seed = gr.Textbox(label="Seed Used", interactive=False, visible=True, scale=1)
api_base64_output = gr.Textbox(label="API Base64 Output", interactive=False, visible=False)

- # --- Examples (Updated default factor if needed) ---
- example_dir = "examples"
- example_files = ["image_2.jpg", "image_4.jpg", "low_res_face.png", "low_res_landscape.png"]
example_paths = [os.path.join(example_dir, f) for f in example_files if os.path.exists(os.path.join(example_dir, f))]
-
if example_paths:
gr.Examples(
- # Examples use the new defaults: final factor 2x, steps 15
examples=[ [path, min(2, INTERNAL_PROCESSING_FACTOR), 15, 0.6, random.randint(0,MAX_SEED), True] for path in example_paths ],
- # Ensure inputs match the order expected by `infer` now
inputs=[ input_im, upscale_factor_slider, num_inference_steps, controlnet_conditioning_scale, seed, randomize_seed, ],
- outputs=[result_slider, output_seed], # Base64 output ignored by Examples
- fn=infer,
- cache_examples="lazy",
- label=f"Example Images (Click to Run with {min(2, INTERNAL_PROCESSING_FACTOR)}x Output, 15 Steps)",
- run_on_click=True
- )
- else:
- gr.Markdown(f"*No example images found in '{example_dir}' directory.*")
-
- gr.Markdown("---")
- gr.Markdown("**Disclaimer:** Demo for illustrative purposes. Users are responsible for generated content.")
-
- # Connect button click
- run_button.click(
- fn=infer,
- inputs=[
- seed,
- randomize_seed,
- input_im,
- num_inference_steps,
- upscale_factor_slider, # Use the slider value here
- controlnet_conditioning_scale,
- ],
- outputs=[result_slider, output_seed, api_base64_output],
- api_name="upscale" # Keep API name
- )
-
- # Launch the Gradio app
- # Consider increasing queue timeout if compilation adds significant startup time
demo.queue(max_size=10).launch(share=False, show_api=True)
 
+ # ---- Imports ----
import logging
import random
import warnings
import os
+ import io
+ import base64
import gradio as gr
import numpy as np
import spaces
import torch
+ # --- Add this if you want to try Solution 2 later ---
+ # import torch._dynamo
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
+ from gradio_imageslider import ImageSlider
+ from PIL import Image, ImageOps
from huggingface_hub import snapshot_download

# --- Setup Logging and Device ---
+ # ... (rest of setup code remains the same) ...
logging.basicConfig(level=logging.INFO)
warnings.filterwarnings("ignore")

css = """
+ #col-container { margin: 0 auto; max-width: 512px; }
+ .gradio-container { max-width: 900px !important; margin: auto !important; }
"""

if torch.cuda.is_available():
+ power_device = "GPU"; device = "cuda"; torch_dtype = torch.bfloat16
else:
+ power_device = "CPU"; device = "cpu"; torch_dtype = torch.float32

logging.info(f"Selected device: {device} | Data type: {torch_dtype}")

# --- Authentication and Model Download ---
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

flux_model_id = "black-forest-labs/FLUX.1-dev"
controlnet_model_id = "jasperai/Flux.1-dev-Controlnet-Upscaler"
+ local_model_dir = flux_model_id.split('/')[-1]
pipe = None

try:
+ # ... (model download code remains the same) ...
logging.info(f"Downloading base model: {flux_model_id}")
+ model_path = snapshot_download(repo_id=flux_model_id, repo_type="model", ignore_patterns=["*.md", "*.gitattributes"], local_dir=local_model_dir, token=huggingface_token)
logging.info(f"Base model downloaded/verified in: {model_path}")

logging.info(f"Loading ControlNet model: {controlnet_model_id}")
+ controlnet = FluxControlNetModel.from_pretrained(controlnet_model_id, torch_dtype=torch_dtype).to(device)
logging.info("ControlNet model loaded.")

logging.info("Loading FluxControlNetPipeline...")
+ pipe = FluxControlNetPipeline.from_pretrained(model_path, controlnet=controlnet, torch_dtype=torch_dtype)
pipe.to(device)
logging.info("Pipeline loaded and moved to device.")

# --- OPTIMIZATION: Attempt torch.compile (PyTorch 2.0+) ---
if device == "cuda" and hasattr(torch, "compile"):
+ # --- TRY THIS FIRST: Change mode to 'default' ---
+ compile_mode = "default"
+ # --- Alternative (Solution 2): Uncomment these lines ---
+ # import torch._dynamo
+ # torch._dynamo.config.suppress_errors = True
+ # compile_mode = "max-autotune" # or "default" even with suppress_errors
+ # --- End Alternative ---
+
+ logging.info(f"Attempting to compile the pipeline transformer with torch.compile (mode='{compile_mode}')...")
try:
+ pipe.transformer = torch.compile(pipe.transformer, mode=compile_mode, fullgraph=True)
logging.info("Pipeline transformer compiled successfully.")
+ # Optional dummy inference run can go here
except Exception as e:
+ logging.warning(f"torch.compile failed (mode='{compile_mode}'): {e}. Running unoptimized.")
+ # --- Solution 3: If compilation fails consistently, comment out the compile line above ---
+ # pipe.transformer = torch.compile(pipe.transformer, mode=compile_mode, fullgraph=True) # <-- Comment this out
else:
logging.info("torch.compile not available or not on CUDA, skipping compilation.")

+ # --- (Optional xformers code would go here if used) ---

logging.info("Pipeline ready for inference.")

except Exception as e:
logging.error(f"FATAL: Error during model loading or setup: {e}", exc_info=True)
print(f"FATAL ERROR DURING MODEL LOAD/SETUP: {e}")
raise SystemExit(f"Model loading/setup failed: {e}")

# --- Constants ---
MAX_SEED = 2**32 - 1
+ MAX_PIXEL_BUDGET = 1280 * 1280
INTERNAL_PROCESSING_FACTOR = 4

+ # --- Image Processing Function (process_input) ---
+ # ... (process_input function remains the same) ...
def process_input(input_image):
+ if input_image is None: raise gr.Error("Input image is missing!")
try:
input_image = ImageOps.exif_transpose(input_image)
+ if input_image.mode != 'RGB': input_image = input_image.convert('RGB')
w, h = input_image.size
+ except AttributeError: raise gr.Error("Invalid input image format.")
+ except Exception as img_err: raise gr.Error(f"Could not process input image: {img_err}")

w_original, h_original = w, h
+ if w == 0 or h == 0: raise gr.Error("Input image has zero dimensions.")

target_w_internal = w * INTERNAL_PROCESSING_FACTOR
target_h_internal = h * INTERNAL_PROCESSING_FACTOR
target_pixels_internal = target_w_internal * target_h_internal
was_resized = False
input_image_to_process = input_image.copy()

if target_pixels_internal > MAX_PIXEL_BUDGET:
max_input_pixels = MAX_PIXEL_BUDGET / (INTERNAL_PROCESSING_FACTOR**2)
current_input_pixels = w * h
if current_input_pixels > max_input_pixels:
input_scale_factor = (max_input_pixels / current_input_pixels) ** 0.5
+ input_w_resized = max(8, int(w * input_scale_factor))
+ input_h_resized = max(8, int(h * input_scale_factor))
intermediate_w = input_w_resized * INTERNAL_PROCESSING_FACTOR
intermediate_h = input_h_resized * INTERNAL_PROCESSING_FACTOR
+ logging.warning(f"Requested {INTERNAL_PROCESSING_FACTOR}x intermediate exceeds budget. Resizing input {w}x{h} -> {input_w_resized}x{input_h_resized}.")
+ gr.Info(f"Intermediate {INTERNAL_PROCESSING_FACTOR}x size exceeds budget. Input resized to {input_w_resized}x{input_h_resized} -> model generates ~{int(intermediate_w)}x{int(intermediate_h)}.")
input_image_to_process = input_image_to_process.resize((input_w_resized, input_h_resized), Image.Resampling.LANCZOS)
+ was_resized = True

w_proc, h_proc = input_image_to_process.size
w_final_proc = max(8, w_proc - w_proc % 8)
h_final_proc = max(8, h_proc - h_proc % 8)
if (w_proc, h_proc) != (w_final_proc, h_final_proc):
+ logging.info(f"Rounding processed input dims {w_proc}x{h_proc} -> {w_final_proc}x{h_final_proc}")
input_image_to_process = input_image_to_process.resize((w_final_proc, h_final_proc), Image.Resampling.LANCZOS)

return input_image_to_process, w_original, h_original, was_resized
+
+ # --- Inference Function (infer) ---
+ @spaces.GPU(duration=180)
def infer(
+ seed, randomize_seed, input_image, num_inference_steps,
+ final_upscale_factor, controlnet_conditioning_scale,
progress=gr.Progress(track_tqdm=True),
):
global pipe
+ # --- IMPROVED ERROR HANDLING: Define default return early ---
+ default_return = [[input_image, None], int(seed) if seed is not None else 0, None]
+
if pipe is None:
gr.Error("Pipeline not loaded. Cannot perform inference.")
+ return default_return # Use default return

+ original_input_pil = input_image # Keep ref even if None initially

if input_image is None:
gr.Warning("Please provide an input image.")
+ # Update seed in default return if randomized
+ if randomize_seed: seed = random.randint(0, MAX_SEED)
+ else: seed = int(seed) if seed is not None else 0
+ default_return[1] = seed
+ return default_return # Use default return

+ if randomize_seed: seed = random.randint(0, MAX_SEED)
seed = int(seed)
+ # --- UPDATE DEFAULT RETURN SEED ---
+ default_return[1] = seed
+ # --- Ensure original image is in the default return ---
+ default_return[0][0] = original_input_pil
+

final_upscale_factor = int(final_upscale_factor)
+ num_inference_steps = int(num_inference_steps)

if final_upscale_factor > INTERNAL_PROCESSING_FACTOR:
+ gr.Warning(f"Clamping final upscale factor {final_upscale_factor}x to internal {INTERNAL_PROCESSING_FACTOR}x.")
final_upscale_factor = INTERNAL_PROCESSING_FACTOR

+ logging.info(f"Starting inference: seed={seed}, internal={INTERNAL_PROCESSING_FACTOR}x, final={final_upscale_factor}x, steps={num_inference_steps}, cnet_scale={controlnet_conditioning_scale}")

try:
+ processed_input_image, w_original, h_original, was_input_resized = process_input(input_image)
except Exception as e:
logging.error(f"Error processing input image: {e}", exc_info=True)
gr.Error(f"Error processing input image: {e}")
+ return default_return # Use default return with correct seed

w_proc, h_proc = processed_input_image.size
control_image_w = w_proc * INTERNAL_PROCESSING_FACTOR
control_image_h = h_proc * INTERNAL_PROCESSING_FACTOR

+ # Failsafe clamp (remains the same)
+ if control_image_w * control_image_h > MAX_PIXEL_BUDGET * 1.05:
scale_factor = (MAX_PIXEL_BUDGET / (control_image_w * control_image_h)) ** 0.5
+ control_image_w = max(8, int(control_image_w * scale_factor)); control_image_w -= control_image_w % 8
+ control_image_h = max(8, int(control_image_h * scale_factor)); control_image_h -= control_image_h % 8
+ logging.warning(f"Control image dims clamped post-processing: {control_image_w}x{control_image_h}.")
gr.Warning(f"Control image dimensions further clamped to {control_image_w}x{control_image_h}.")

+ logging.info(f"Resizing processed input {w_proc}x{h_proc} to control image {control_image_w}x{control_image_h}")
try:
control_image = processed_input_image.resize((control_image_w, control_image_h), Image.Resampling.LANCZOS)
except ValueError as resize_err:
+ logging.error(f"Error resizing to control image: {resize_err}")
gr.Error(f"Failed to prepare control image: {resize_err}")
+ return default_return # Use default return

generator = torch.Generator(device=device).manual_seed(seed)

+ gr.Info(f"Generating intermediate image ({control_image_w}x{control_image_h}, {num_inference_steps} steps)...")
+ logging.info(f"Running pipeline: size={control_image_w}x{control_image_h}, steps={num_inference_steps}")
+ intermediate_result_image = None # Initialize
try:
with torch.inference_mode():
+ # Progress bar integration can be added here if needed
intermediate_result_image = pipe(
+ prompt="", control_image=control_image,
controlnet_conditioning_scale=float(controlnet_conditioning_scale),
+ num_inference_steps=num_inference_steps, guidance_scale=0.0,
+ height=control_image_h, width=control_image_w, generator=generator,
).images[0]
+ logging.info(f"Pipeline finished. Intermediate size: {intermediate_result_image.size if intermediate_result_image else 'None'}")

+ # --- Catch specific errors if needed, otherwise general Exception ---
except torch.cuda.OutOfMemoryError as oom_error:
+ logging.error(f"OOM during pipeline: {oom_error}", exc_info=True)
+ gr.Error(f"OOM generating {control_image_w}x{control_image_h}. Try smaller input/factor.")
if device == 'cuda': torch.cuda.empty_cache()
+ return default_return # Use default return
+ except Exception as e: # Catches the torch.compile error too
+ logging.error(f"Error during pipeline execution: {e}", exc_info=True) # Log full traceback
+ # Provide a more specific error message if it's the known compile issue
+ if "dynamic shape operator" in str(e) or "Unsupported" in str(e.__class__):
+ gr.Error(f"Inference failed: torch.compile issue encountered ({type(e).__name__}). Try restarting the Space or disabling compilation if persistent.")
+ else:
+ gr.Error(f"Inference failed: {e}")
+ return default_return # Use default return
+
+ # --- Check if intermediate image was actually created ---
if not intermediate_result_image:
+ logging.error("Intermediate result is None after pipeline (but no exception caught).")
+ gr.Error("Inference produced no result image unexpectedly.")
+ return default_return # Use default return

+ # --- Final Resizing (remains the same logic) ---
if was_input_resized:
final_target_w = w_proc * final_upscale_factor
final_target_h = h_proc * final_upscale_factor
+ logging.warning(f"Input downscaled. Final size based on processed: {w_proc}x{h_proc}*{final_upscale_factor}x -> {final_target_w}x{final_target_h}")
+ gr.Info(f"Input downscaled. Final size approx {final_target_w}x{final_target_h}.")
else:
final_target_w = w_original * final_upscale_factor
final_target_h = h_original * final_upscale_factor

current_w, current_h = intermediate_result_image.size

if (current_w, current_h) != (final_target_w, final_target_h):
+ logging.info(f"Resizing intermediate {current_w}x{current_h} to final {final_target_w}x{final_target_h}")
gr.Info(f"Resizing from intermediate {current_w}x{current_h} to final {final_target_w}x{final_target_h}...")
try:
if final_target_w > 0 and final_target_h > 0:
final_result_image = intermediate_result_image.resize((final_target_w, final_target_h), Image.Resampling.LANCZOS)
else:
+ gr.Warning(f"Invalid final target ({final_target_w}x{final_target_h}). Skipping resize.")
final_result_image = intermediate_result_image
except Exception as resize_e:
+ logging.error(f"Could not resize final image: {resize_e}")
+ gr.Warning(f"Failed final resize. Returning intermediate {current_w}x{current_h}.")
final_result_image = intermediate_result_image
else:
+ logging.info("Intermediate size matches final target. No final resize needed.")

logging.info(f"Inference successful. Final output size: {final_result_image.size}")

+ # --- Base64 Encoding (remains the same) ---
base64_string = None
if final_result_image:
try:
+ buffered = io.BytesIO(); final_result_image.save(buffered, format="WEBP", quality=90)
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
base64_string = f"data:image/webp;base64,{img_str}"
+ logging.info(f"Encoded result to Base64 (len: {len(base64_string)}).")
except Exception as enc_err:
+ logging.error(f"Failed Base64 encoding: {enc_err}", exc_info=True)

+ # --- SUCCESS RETURN ---
return [[original_input_pil, final_result_image], seed, base64_string]

+ # --- Gradio Interface (Gradio UI Definition) ---
+ # ... (Gradio definition remains the same, ensure inputs/outputs match infer) ...
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as demo:
+ gr.Markdown(f"""
# ⚡ Flux.1-dev Upscaler ControlNet ⚡
+ Upscale images using Flux.1-dev Upscaler ControlNet on **{power_device}**.
+ Internal processing at **{INTERNAL_PROCESSING_FACTOR}x** quality, resized to **Final Upscale Factor**.
+ **Speed Up:** Reduce `Inference Steps` (try 10-15). `torch.compile` enabled (may fail, see logs).
+ *Limit*: ~**{MAX_PIXEL_BUDGET/1_000_000:.1f} megapixels** intermediate size.
+ """)

with gr.Row():
+ with gr.Column(scale=2): input_im = gr.Image(label="Input Image", type="pil", height=350, sources=["upload", "clipboard"])
with gr.Column(scale=1):
+ upscale_factor_slider = gr.Slider(label="Final Upscale Factor", info=f"Output size. Internal uses {INTERNAL_PROCESSING_FACTOR}x quality.", minimum=1, maximum=INTERNAL_PROCESSING_FACTOR, step=1, value=min(2, INTERNAL_PROCESSING_FACTOR))
+ num_inference_steps = gr.Slider(label="Inference Steps", info="Fewer=faster (try 10-15).", minimum=4, maximum=50, step=1, value=15)
+ controlnet_conditioning_scale = gr.Slider(label="ControlNet Scale", info="Guidance strength.", minimum=0.0, maximum=1.5, step=0.05, value=0.6)
with gr.Row():
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
randomize_seed = gr.Checkbox(label="Random", value=True, scale=0, min_width=80)
run_button = gr.Button("⚡ Upscale Image", variant="primary", scale=1)
+ with gr.Row(): result_slider = ImageSlider(label="Input / Output Comparison", type="pil", interactive=False, show_label=True, position=0.5)
output_seed = gr.Textbox(label="Seed Used", interactive=False, visible=True, scale=1)
api_base64_output = gr.Textbox(label="API Base64 Output", interactive=False, visible=False)

+ example_dir = "examples"; example_files = ["image_2.jpg", "image_4.jpg", "low_res_face.png", "low_res_landscape.png"]
example_paths = [os.path.join(example_dir, f) for f in example_files if os.path.exists(os.path.join(example_dir, f))]
if example_paths:
gr.Examples(
examples=[ [path, min(2, INTERNAL_PROCESSING_FACTOR), 15, 0.6, random.randint(0,MAX_SEED), True] for path in example_paths ],
inputs=[ input_im, upscale_factor_slider, num_inference_steps, controlnet_conditioning_scale, seed, randomize_seed, ],
+ outputs=[result_slider, output_seed], fn=infer, cache_examples="lazy",
+ label=f"Examples (Click: {min(2, INTERNAL_PROCESSING_FACTOR)}x Output, 15 Steps)", run_on_click=True)
+ else: gr.Markdown(f"*No example images found in '{example_dir}'.*")
+ gr.Markdown("---"); gr.Markdown("**Disclaimer:** For illustrative purposes.")
+
+ run_button.click(fn=infer,
+ inputs=[seed, randomize_seed, input_im, num_inference_steps, upscale_factor_slider, controlnet_conditioning_scale],
+ outputs=[result_slider, output_seed, api_base64_output], api_name="upscale")
+
demo.queue(max_size=10).launch(share=False, show_api=True)
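
For reference, a minimal client-side sketch (not part of this commit) of how the `upscale` endpoint kept by `run_button.click(..., api_name="upscale")` could be called with gradio_client and its Base64 output decoded. The Space id, file names, and the gradio_client version (assumed >= 1.0 for handle_file) are assumptions, not something the diff specifies.

# Hypothetical client sketch -- not part of app.py or this commit.
# Assumes gradio_client >= 1.0 (for handle_file); Space id and file names are placeholders.
import base64
from gradio_client import Client, handle_file

client = Client("Gemini899/flux-upscaler-demo")  # placeholder Space id

# Input order mirrors run_button.click(...) in app.py:
# seed, randomize_seed, input image, inference steps, final upscale factor, controlnet scale
result = client.predict(
    42,                               # seed
    True,                             # randomize_seed
    handle_file("low_res_face.png"),  # input image (local path or URL)
    15,                               # num_inference_steps
    2,                                # final_upscale_factor
    0.6,                              # controlnet_conditioning_scale
    api_name="/upscale",
)

# Expected shape per infer(): [input/output image pair, seed used, "data:image/webp;base64,..." string]
_, seed_used, data_uri = result
payload = data_uri.split(",", 1)[1]  # drop the "data:image/webp;base64," prefix
with open("upscaled.webp", "wb") as f:
    f.write(base64.b64decode(payload))
print(f"Saved upscaled.webp (seed {seed_used})")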