import logging
import random
import warnings
import os
import io
import base64

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from gradio_imageslider import ImageSlider  # Ensure this is installed: pip install gradio_imageslider
from PIL import Image, ImageOps  # ImageOps for EXIF transpose
from huggingface_hub import snapshot_download

# --- Setup Logging and Device ---
logging.basicConfig(level=logging.INFO)
warnings.filterwarnings("ignore")

css = """
#col-container {
    margin: 0 auto;
    max-width: 512px; /* Increased max-width slightly for better layout */
}
.gradio-container {
    max-width: 900px !important; /* Control overall container width */
    margin: auto !important;
}
"""

if torch.cuda.is_available():
    power_device = "GPU"
    device = "cuda"
    torch_dtype = torch.bfloat16  # bfloat16 on GPU for better performance/memory
else:
    power_device = "CPU"
    device = "cpu"
    torch_dtype = torch.float32  # float32 on CPU

logging.info(f"Selected device: {device} | Data type: {torch_dtype}")

# --- Authentication and Model Download ---
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Model IDs
flux_model_id = "black-forest-labs/FLUX.1-dev"
controlnet_model_id = "jasperai/Flux.1-dev-Controlnet-Upscaler"
local_model_dir = flux_model_id.split("/")[-1]  # e.g. "FLUX.1-dev"

pipe = None

try:
    logging.info(f"Downloading base model: {flux_model_id}")
    model_path = snapshot_download(
        repo_id=flux_model_id,
        repo_type="model",
        ignore_patterns=["*.md", "*.gitattributes"],
        local_dir=local_model_dir,
        token=huggingface_token,
    )
    logging.info(f"Base model downloaded/verified in: {model_path}")

    logging.info(f"Loading ControlNet model: {controlnet_model_id}")
    controlnet = FluxControlNetModel.from_pretrained(
        controlnet_model_id, torch_dtype=torch_dtype
    ).to(device)
    logging.info("ControlNet model loaded.")

    logging.info("Loading FluxControlNetPipeline...")
    pipe = FluxControlNetPipeline.from_pretrained(
        model_path, controlnet=controlnet, torch_dtype=torch_dtype
    )
    pipe.to(device)
    logging.info("Pipeline loaded and moved to device.")
except Exception as e:
    # --- Simplified error handling for brevity ---
    logging.error(f"FATAL: Error during model loading: {e}", exc_info=True)
    print(f"FATAL ERROR DURING MODEL LOAD: {e}")
    raise SystemExit(f"Model loading failed: {e}")

# --- Constants ---
MAX_SEED = 2**32 - 1
MAX_PIXEL_BUDGET = 1280 * 1280
# Fixed internal factor used for the diffusion pass; the user-facing
# "Final Upscale Factor" only controls the final resize.
INTERNAL_PROCESSING_FACTOR = 4
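# Budget arithmetic for reference: MAX_PIXEL_BUDGET is 1280 * 1280 = 1,638,400 px,
# so with the 4x internal factor the largest input that avoids pre-downscaling is
# 1,638,400 / 4**2 = 102,400 px (about 320x320). Anything larger is shrunk by
# process_input() below before the diffusion pass.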
# --- Image Processing Function ---
def process_input(input_image):
    """Prepare the input image for the pipeline.

    The pixel-budget check uses the fixed INTERNAL_PROCESSING_FACTOR.
    """
    if input_image is None:
        raise gr.Error("Input image is missing!")
    try:
        input_image = ImageOps.exif_transpose(input_image)
        if input_image.mode != "RGB":
            logging.info(f"Converting input image from {input_image.mode} to RGB")
            input_image = input_image.convert("RGB")
        w, h = input_image.size
    except AttributeError:
        raise gr.Error("Invalid input image format. Please provide a valid image file.")
    except Exception as img_err:
        raise gr.Error(f"Could not process input image: {img_err}")

    w_original, h_original = w, h
    if w == 0 or h == 0:
        raise gr.Error("Input image has zero width or height.")

    # Compute the intermediate target size from the INTERNAL factor for the budget check.
    target_w_internal = w * INTERNAL_PROCESSING_FACTOR
    target_h_internal = h * INTERNAL_PROCESSING_FACTOR
    target_pixels_internal = target_w_internal * target_h_internal

    was_resized = False
    input_image_to_process = input_image.copy()

    # If the *intermediate* size exceeds the budget, downscale the input first.
    if target_pixels_internal > MAX_PIXEL_BUDGET:
        max_input_pixels = MAX_PIXEL_BUDGET / (INTERNAL_PROCESSING_FACTOR**2)
        current_input_pixels = w * h
        if current_input_pixels > max_input_pixels:
            input_scale_factor = (max_input_pixels / current_input_pixels) ** 0.5
            input_w_resized = int(w * input_scale_factor)
            input_h_resized = int(h * input_scale_factor)
            # Enforce a minimum size of 8x8
            input_w_resized = max(8, input_w_resized)
            input_h_resized = max(8, input_h_resized)
            intermediate_w = input_w_resized * INTERNAL_PROCESSING_FACTOR
            intermediate_h = input_h_resized * INTERNAL_PROCESSING_FACTOR
            logging.warning(
                f"Requested {INTERNAL_PROCESSING_FACTOR}x intermediate output "
                f"({target_w_internal}x{target_h_internal}) exceeds budget. "
                f"Resizing input from {w}x{h} to {input_w_resized}x{input_h_resized}."
            )
            gr.Info(
                f"Intermediate {INTERNAL_PROCESSING_FACTOR}x size exceeds budget. "
                f"Input resized to {input_w_resized}x{input_h_resized} "
                f"-> model generates ~{intermediate_w}x{intermediate_h}."
            )
            input_image_to_process = input_image_to_process.resize(
                (input_w_resized, input_h_resized), Image.Resampling.LANCZOS
            )
            was_resized = True  # Original dimensions are lost for precise final scaling

    # Round the processed input dimensions down to multiples of 8.
    w_proc, h_proc = input_image_to_process.size
    w_final_proc = max(8, w_proc - w_proc % 8)
    h_final_proc = max(8, h_proc - h_proc % 8)
    if (w_proc, h_proc) != (w_final_proc, h_final_proc):
        logging.info(
            f"Rounding processed input dimensions from {w_proc}x{h_proc} "
            f"to {w_final_proc}x{h_final_proc}"
        )
        input_image_to_process = input_image_to_process.resize(
            (w_final_proc, h_final_proc), Image.Resampling.LANCZOS
        )

    return input_image_to_process, w_original, h_original, was_resized
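# Worked example of the budget path above (arithmetic only, for reference):
# a 1024x768 input at the 4x internal factor would need 4096x3072 px (~12.6 MP),
# far over MAX_PIXEL_BUDGET. The input is therefore scaled by
# sqrt(102400 / 786432) ~= 0.361 to 369x277, rounded down to 368x272, and the
# model generates at 1472x1088 (~1.6 MP), just inside the budget.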
# --- Inference Function ---
@spaces.GPU(duration=75)
def infer(
    seed,
    randomize_seed,
    input_image,
    num_inference_steps,
    final_upscale_factor,
    controlnet_conditioning_scale,
    progress=gr.Progress(track_tqdm=True),
):
    global pipe
    if pipe is None:
        # Note: gr.Error must be *raised* to display; use gr.Warning for a
        # toast before returning gracefully with placeholder outputs.
        gr.Warning("Pipeline not loaded. Cannot perform inference.")
        return [[None, None], 0, None]

    original_input_pil = input_image  # Keep a reference for the comparison slider

    if input_image is None:
        gr.Warning("Please provide an input image.")
        return [[None, None], seed or 0, None]

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    final_upscale_factor = int(final_upscale_factor)
    if final_upscale_factor > INTERNAL_PROCESSING_FACTOR:
        gr.Warning(
            f"Selected upscale factor ({final_upscale_factor}x) is larger than the "
            f"internal processing factor ({INTERNAL_PROCESSING_FACTOR}x). "
            f"Clamping the final factor to {INTERNAL_PROCESSING_FACTOR}x for this run."
        )
        final_upscale_factor = INTERNAL_PROCESSING_FACTOR  # Never upscale beyond the internal factor

    logging.info(
        f"Starting inference with seed: {seed}, "
        f"Internal Processing Factor: {INTERNAL_PROCESSING_FACTOR}x, "
        f"Final Output Factor: {final_upscale_factor}x, "
        f"Steps: {num_inference_steps}, CNet Scale: {controlnet_conditioning_scale}"
    )

    try:
        # process_input uses INTERNAL_PROCESSING_FACTOR for its budget check.
        processed_input_image, w_original, h_original, was_input_resized = process_input(
            input_image
        )
    except Exception as e:
        logging.error(f"Error processing input image: {e}", exc_info=True)
        gr.Warning(f"Error processing input image: {e}")
        return [[original_input_pil, None], seed, None]

    w_proc, h_proc = processed_input_image.size

    # Control image dimensions use INTERNAL_PROCESSING_FACTOR.
    control_image_w = w_proc * INTERNAL_PROCESSING_FACTOR
    control_image_h = h_proc * INTERNAL_PROCESSING_FACTOR

    # Failsafe: clamp the control image if it *still* exceeds the budget
    # (e.g. due to rounding on small inputs). Redundant if process_input
    # worked correctly, but cheap insurance.
    if control_image_w * control_image_h > MAX_PIXEL_BUDGET * 1.05:  # small margin
        scale_factor = (MAX_PIXEL_BUDGET / (control_image_w * control_image_h)) ** 0.5
        control_image_w = max(8, int(control_image_w * scale_factor))
        control_image_h = max(8, int(control_image_h * scale_factor))
        control_image_w = max(8, control_image_w - control_image_w % 8)
        control_image_h = max(8, control_image_h - control_image_h % 8)
        logging.warning(
            f"Control image dimensions clamped to {control_image_w}x{control_image_h} "
            f"post-processing to fit budget."
        )
        gr.Warning(f"Control image dimensions further clamped to {control_image_w}x{control_image_h}.")

    logging.info(
        f"Resizing processed input {w_proc}x{h_proc} to control image "
        f"{control_image_w}x{control_image_h} (using {INTERNAL_PROCESSING_FACTOR}x factor)"
    )
    try:
        # The processed input, resized to the intermediate size, IS the control image.
        control_image = processed_input_image.resize(
            (control_image_w, control_image_h), Image.Resampling.LANCZOS
        )
    except ValueError as resize_err:
        logging.error(f"Error resizing processed input to control image: {resize_err}")
        gr.Warning(f"Failed to prepare control image: {resize_err}")
        return [[original_input_pil, None], seed, None]

    generator = torch.Generator(device=device).manual_seed(seed)

    # --- Run the pipeline at the INTERNAL_PROCESSING_FACTOR scale ---
    gr.Info(
        f"Generating intermediate image at {INTERNAL_PROCESSING_FACTOR}x quality "
        f"({control_image_w}x{control_image_h})..."
    )
    logging.info(f"Running pipeline with size: {control_image_w}x{control_image_h}")

    intermediate_result_image = None
    try:
        with torch.inference_mode():
            intermediate_result_image = pipe(
                prompt="",
                control_image=control_image,  # The control image IS the intermediate size
                controlnet_conditioning_scale=float(controlnet_conditioning_scale),
                num_inference_steps=int(num_inference_steps),
                guidance_scale=0.0,
                height=control_image_h,  # Target height for the model
                width=control_image_w,   # Target width for the model
                generator=generator,
            ).images[0]
        logging.info(
            "Pipeline execution finished. Intermediate image size: "
            f"{intermediate_result_image.size if intermediate_result_image else 'None'}"
        )
    except torch.cuda.OutOfMemoryError as oom_error:
        logging.error(f"CUDA Out of Memory during pipeline execution: {oom_error}", exc_info=True)
        gr.Warning(f"Ran out of GPU memory trying to generate intermediate {control_image_w}x{control_image_h}.")
        if device == "cuda":
            torch.cuda.empty_cache()
        return [[original_input_pil, None], seed, None]
    except Exception as e:
        logging.error(f"Error during pipeline execution: {e}", exc_info=True)
        gr.Warning(f"Inference failed: {e}")
        return [[original_input_pil, None], seed, None]

    if intermediate_result_image is None:
        logging.error("Intermediate result image is None after pipeline execution.")
        gr.Warning("Inference produced no result image.")
        return [[original_input_pil, None], seed, None]
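    # Size bookkeeping example (for reference): a 300x300 input is rounded to
    # 296x296, giving a 1184x1184 control/intermediate image (within budget).
    # With a final factor of 2x, the final target below is computed from the
    # ORIGINAL 300x300 size, i.e. 600x600, downsampled from 1184x1184.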
    # --- Final resize to the user's requested scale ---
    # The final target is based on the ORIGINAL input size and the FINAL upscale
    # factor. If the input had to be downscaled to fit the budget, the target is
    # based on the processed input size instead, since the original dimensions
    # can no longer be honored exactly.
    if was_input_resized:
        final_target_w = w_proc * final_upscale_factor
        final_target_h = h_proc * final_upscale_factor
        logging.warning(
            f"Input was downscaled. Final size based on processed input: "
            f"{w_proc}x{h_proc} * {final_upscale_factor}x -> {final_target_w}x{final_target_h}"
        )
        gr.Info(f"Input was downscaled. Final size target approx {final_target_w}x{final_target_h}.")
    else:
        final_target_w = w_original * final_upscale_factor
        final_target_h = h_original * final_upscale_factor

    final_result_image = intermediate_result_image
    current_w, current_h = intermediate_result_image.size

    # Only resize if the intermediate size differs from the final target size.
    if (current_w, current_h) != (final_target_w, final_target_h):
        logging.info(
            f"Resizing intermediate image from {current_w}x{current_h} to final target "
            f"{final_target_w}x{final_target_h} (using {final_upscale_factor}x factor)"
        )
        gr.Info(f"Resizing from intermediate {current_w}x{current_h} to final {final_target_w}x{final_target_h}...")
        try:
            if final_target_w > 0 and final_target_h > 0:
                # LANCZOS is high quality for downsampling.
                final_result_image = intermediate_result_image.resize(
                    (final_target_w, final_target_h), Image.Resampling.LANCZOS
                )
            else:
                gr.Warning(f"Invalid final target dimensions ({final_target_w}x{final_target_h}). Skipping final resize.")
        except Exception as resize_e:
            logging.error(f"Could not resize intermediate image to final size: {resize_e}")
            gr.Warning(
                f"Failed to resize to final {final_upscale_factor}x. Returning intermediate "
                f"{INTERNAL_PROCESSING_FACTOR}x result ({current_w}x{current_h})."
            )
            final_result_image = intermediate_result_image  # Fall back to the intermediate image
    else:
        logging.info(f"Intermediate size {current_w}x{current_h} matches final target size. No final resize needed.")

    logging.info(f"Inference successful. Final output size: {final_result_image.size}")

    # --- Base64 encoding of the final image (for the API output) ---
    base64_string = None
    if final_result_image:
        try:
            buffered = io.BytesIO()
            final_result_image.save(buffered, format="WEBP", quality=90)
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            base64_string = f"data:image/webp;base64,{img_str}"
            logging.info(f"Encoded result image to Base64 string (length: {len(base64_string)} chars).")
        except Exception as enc_err:
            logging.error(f"Failed to encode result image to Base64: {enc_err}", exc_info=True)

    # Return the original input and the FINAL processed image for the slider.
    return [[original_input_pil, final_result_image], seed, base64_string]
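# A minimal client-side sketch for the named API endpoint wired up below.
# This assumes `pip install gradio_client` (a recent version providing
# `handle_file`) and a locally running app; the URL and file name are
# illustrative, not part of this repo:
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860/")
#   comparison, seed_used, b64 = client.predict(
#       42,                          # seed
#       False,                       # randomize_seed
#       handle_file("low_res.png"),  # input_image
#       15,                          # num_inference_steps
#       2,                           # final_upscale_factor
#       0.6,                         # controlnet_conditioning_scale
#       api_name="/upscale",
#   )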
# --- Gradio Interface ---
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as demo:
    gr.Markdown(
        f"""
        # ⚡ Flux.1-dev Upscaler ControlNet ⚡
        Upscale images using the [Flux.1-dev Upscaler ControlNet](https://huggingface.co/jasperai/Flux.1-dev-Controlnet-Upscaler)
        model, built on [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
        Currently running on **{power_device}**. Hardware provided by Hugging Face 🤗.

        **How it works:** The diffusion pass always runs at an internal scale of
        **{INTERNAL_PROCESSING_FACTOR}x** to generate maximum detail; the result is then
        resized to your selected **Final Upscale Factor**. This aims to deliver
        {INTERNAL_PROCESSING_FACTOR}x quality at your chosen output resolution.

        *Note*: The intermediate processing resolution is capped at roughly
        **{MAX_PIXEL_BUDGET/1_000_000:.1f} megapixels** ({int(MAX_PIXEL_BUDGET**0.5)}x{int(MAX_PIXEL_BUDGET**0.5)})
        due to resource constraints. Diffusion time is governed by this intermediate size,
        not by the final output size.
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            input_im = gr.Image(
                label="Input Image",
                type="pil",
                height=350,
                sources=["upload", "clipboard"],
            )
        with gr.Column(scale=1):
            upscale_factor_slider = gr.Slider(
                label="Final Upscale Factor",
                info=f"Output size relative to input. Internal processing uses {INTERNAL_PROCESSING_FACTOR}x quality.",
                minimum=1,
                maximum=INTERNAL_PROCESSING_FACTOR,  # Capped at the internal factor
                step=1,
                value=2,  # Default to 2x output
            )
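            # Example: a 256x256 input always diffuses at 4x (1024x1024,
            # within budget); a slider value of 2 downsamples that result to
            # 512x512, while 4 returns the 1024x1024 intermediate as-is.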
            num_inference_steps = gr.Slider(
                label="Inference Steps", minimum=4, maximum=50, step=1, value=15
            )
            controlnet_conditioning_scale = gr.Slider(
                label="ControlNet Conditioning Scale",
                info="Strength of ControlNet guidance",
                minimum=0.0,
                maximum=1.5,
                step=0.05,
                value=0.6,
            )
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                randomize_seed = gr.Checkbox(label="Random", value=True, scale=0, min_width=80)
            run_button = gr.Button("⚡ Upscale Image", variant="primary", scale=1)

    with gr.Row():
        result_slider = ImageSlider(
            label="Input / Output Comparison",
            type="pil",
            interactive=False,
            show_label=True,
            position=0.5,
        )

    output_seed = gr.Textbox(label="Seed Used", interactive=False, visible=True, scale=1)
    api_base64_output = gr.Textbox(label="API Base64 Output", interactive=False, visible=False)

    # --- Examples ---
    example_dir = "examples"
    example_files = ["image_2.jpg", "image_4.jpg", "low_res_face.png", "low_res_landscape.png"]
    example_paths = [
        os.path.join(example_dir, f)
        for f in example_files
        if os.path.exists(os.path.join(example_dir, f))
    ]

    if example_paths:
        gr.Examples(
            # Example rows and `inputs` must follow infer's signature order:
            # (seed, randomize_seed, input_image, steps, final factor, cnet scale).
            examples=[
                [random.randint(0, MAX_SEED), True, path, 15, 2, 0.6]
                for path in example_paths
            ],
            inputs=[
                seed,
                randomize_seed,
                input_im,
                num_inference_steps,
                upscale_factor_slider,
                controlnet_conditioning_scale,
            ],
            outputs=[result_slider, output_seed, api_base64_output],
            fn=infer,
            cache_examples="lazy",
            label="Example Images (Click to Run with 2x Output)",
            run_on_click=True,
        )
    else:
        gr.Markdown(f"*No example images found in '{example_dir}' directory.*")

    gr.Markdown("---")
    gr.Markdown("**Disclaimer:** Demo for illustrative purposes. Users are responsible for generated content.")

    # Wire the button to the inference function
    run_button.click(
        fn=infer,
        inputs=[
            seed,
            randomize_seed,
            input_im,
            num_inference_steps,
            upscale_factor_slider,
            controlnet_conditioning_scale,
        ],
        outputs=[result_slider, output_seed, api_base64_output],
        api_name="upscale",
    )

# Launch the Gradio app
demo.queue(max_size=10).launch(share=False, show_api=True)
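# Decoding the Base64 textbox returned by the /upscale endpoint (a sketch;
# the data-URL prefix matches what infer() produces above):
#
#   import base64, io
#   from PIL import Image
#
#   _, payload = base64_string.split(",", 1)  # strip "data:image/webp;base64,"
#   image = Image.open(io.BytesIO(base64.b64decode(payload)))
#   image.save("upscaled.webp")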