import logging
import random
import warnings
import os
import io
import base64
import gradio as gr
import spaces
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from gradio_imageslider import ImageSlider # Ensure this is installed: pip install gradio_imageslider
from PIL import Image, ImageOps # Import ImageOps for exif transpose
from huggingface_hub import snapshot_download
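
# Assumed dependencies (sketch, inferred from the imports above): pip install gradio
# gradio_imageslider spaces torch diffusers huggingface_hub pillow -- plus
# transformers/accelerate/sentencepiece, which FLUX pipelines typically pull in
# (assumption based on diffusers' FLUX documentation).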
# --- Logging Setup ---
logging.basicConfig(level=logging.INFO)
warnings.filterwarnings("ignore")
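# --- Gradio CSS ---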
css = """
#col-container {
margin: 0 auto;
max-width: 512px; /* Increased max-width slightly for better layout */
}
.gradio-container {
max-width: 900px !important; /* Control overall container width */
margin: auto !important;
}
"""
if torch.cuda.is_available():
power_device = "GPU"
device = "cuda"
torch_dtype = torch.bfloat16 # Use bfloat16 for GPU for better performance/memory
else:
power_device = "CPU"
device = "cpu"
torch_dtype = torch.float32 # Use float32 for CPU
logging.info(f"Selected device: {device} | Data type: {torch_dtype}")
# --- Authentication and Model Download ---
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
# Define model IDs
flux_model_id = "black-forest-labs/FLUX.1-dev"
controlnet_model_id = "jasperai/Flux.1-dev-Controlnet-Upscaler"
local_model_dir = flux_model_id.split('/')[-1] # e.g., "FLUX.1-dev"
pipe = None
try:
logging.info(f"Downloading base model: {flux_model_id}")
model_path = snapshot_download(
repo_id=flux_model_id,
repo_type="model",
ignore_patterns=["*.md", "*.gitattributes"],
local_dir=local_model_dir,
token=huggingface_token,
)
logging.info(f"Base model downloaded/verified in: {model_path}")
logging.info(f"Loading ControlNet model: {controlnet_model_id}")
controlnet = FluxControlNetModel.from_pretrained(
controlnet_model_id, torch_dtype=torch_dtype
).to(device)
logging.info("ControlNet model loaded.")
logging.info("Loading FluxControlNetPipeline...")
pipe = FluxControlNetPipeline.from_pretrained(
model_path,
controlnet=controlnet,
torch_dtype=torch_dtype
)
pipe.to(device)
logging.info("Pipeline loaded and moved to device.")
except Exception as e:
    logging.error(f"FATAL: Error during model loading: {e}", exc_info=True)
    raise SystemExit(f"Model loading failed: {e}") from e
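
# Optional memory-saving sketch (assumption, not used here): on smaller GPUs,
# diffusers pipelines can offload submodules to CPU between forward passes.
# Instead of `pipe.to(device)` above, one could do:
#   pipe.enable_model_cpu_offload()  # requires `accelerate`; trades speed for VRAM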
# --- Constants ---
MAX_SEED = 2**32 - 1
MAX_PIXEL_BUDGET = 1280 * 1280
# --- Internal processing factor (fixed; generation always runs at this scale for quality) ---
INTERNAL_PROCESSING_FACTOR = 4
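
# A minimal sketch (illustrative helper, not called by the app) of the budget math used
# in process_input below: generation happens at INTERNAL_PROCESSING_FACTOR times the
# input size, so the largest input area that fits is MAX_PIXEL_BUDGET / factor**2.
# With the defaults above, a square input larger than 320x320 gets downscaled first.
def _max_square_input_side(budget: int = MAX_PIXEL_BUDGET,
                           factor: int = INTERNAL_PROCESSING_FACTOR) -> int:
    return int((budget / factor**2) ** 0.5)  # e.g. sqrt(1280*1280 / 16) == 320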
# --- Image Processing Function ---
def process_input(input_image):
"""Processes the input image for the pipeline.
The pixel budget check uses the fixed INTERNAL_PROCESSING_FACTOR."""
if input_image is None:
raise gr.Error("Input image is missing!")
try:
input_image = ImageOps.exif_transpose(input_image)
if input_image.mode != 'RGB':
logging.info(f"Converting input image from {input_image.mode} to RGB")
input_image = input_image.convert('RGB')
w, h = input_image.size
except AttributeError:
raise gr.Error("Invalid input image format. Please provide a valid image file.")
except Exception as img_err:
raise gr.Error(f"Could not process input image: {img_err}")
w_original, h_original = w, h
if w == 0 or h == 0:
raise gr.Error("Input image has zero width or height.")
# Calculate target based on INTERNAL factor for budget check
target_w_internal = w * INTERNAL_PROCESSING_FACTOR
target_h_internal = h * INTERNAL_PROCESSING_FACTOR
target_pixels_internal = target_w_internal * target_h_internal
was_resized = False
input_image_to_process = input_image.copy()
# Check if the *intermediate* size exceeds the budget
if target_pixels_internal > MAX_PIXEL_BUDGET:
max_input_pixels = MAX_PIXEL_BUDGET / (INTERNAL_PROCESSING_FACTOR**2)
current_input_pixels = w * h
if current_input_pixels > max_input_pixels:
input_scale_factor = (max_input_pixels / current_input_pixels) ** 0.5
input_w_resized = int(w * input_scale_factor)
input_h_resized = int(h * input_scale_factor)
# Ensure minimum size of 8x8
input_w_resized = max(8, input_w_resized)
input_h_resized = max(8, input_h_resized)
intermediate_w = input_w_resized * INTERNAL_PROCESSING_FACTOR
intermediate_h = input_h_resized * INTERNAL_PROCESSING_FACTOR
logging.warning(
f"Requested {INTERNAL_PROCESSING_FACTOR}x intermediate output ({target_w_internal}x{target_h_internal}) exceeds budget. "
f"Resizing input from {w}x{h} to {input_w_resized}x{input_h_resized}."
)
gr.Info(
f"Intermediate {INTERNAL_PROCESSING_FACTOR}x size exceeds budget. Input resized to {input_w_resized}x{input_h_resized} "
f"-> model generates ~{int(intermediate_w)}x{int(intermediate_h)}."
)
input_image_to_process = input_image_to_process.resize((input_w_resized, input_h_resized), Image.Resampling.LANCZOS)
was_resized = True # Flag that original dimensions are lost for precise final scaling
# Round processed input dimensions to be multiple of 8
w_proc, h_proc = input_image_to_process.size
w_final_proc = max(8, w_proc - w_proc % 8)
h_final_proc = max(8, h_proc - h_proc % 8)
if (w_proc, h_proc) != (w_final_proc, h_final_proc):
logging.info(f"Rounding processed input dimensions from {w_proc}x{h_proc} to {w_final_proc}x{h_final_proc}")
input_image_to_process = input_image_to_process.resize((w_final_proc, h_final_proc), Image.Resampling.LANCZOS)
return input_image_to_process, w_original, h_original, was_resized
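
# A minimal usage sketch for process_input (illustrative, not called by the app):
# dimensions are rounded down to multiples of 8 because the pipeline expects
# sizes divisible by 8, so a 100x80 input comes back as 96x80 with the
# original dimensions reported alongside it.
def _process_input_example():
    img = Image.new("RGB", (100, 80))
    processed, w0, h0, resized = process_input(img)
    assert processed.size == (96, 80) and (w0, h0) == (100, 80) and not resized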
# --- Inference Function ---
@spaces.GPU(duration=75)
def infer(
seed,
randomize_seed,
input_image,
num_inference_steps,
    final_upscale_factor,  # user-selected output scale
controlnet_conditioning_scale,
progress=gr.Progress(track_tqdm=True),
):
global pipe
    if pipe is None:
        # gr.Error must be *raised* to show in the UI; there is no useful fallback here
        raise gr.Error("Pipeline not loaded. Cannot perform inference.")
original_input_pil = input_image # Keep ref for slider
if input_image is None:
gr.Warning("Please provide an input image.")
return [[None, None], seed or 0, None]
if randomize_seed:
seed = random.randint(0, MAX_SEED)
seed = int(seed)
# Ensure final_upscale_factor is an integer
final_upscale_factor = int(final_upscale_factor)
if final_upscale_factor > INTERNAL_PROCESSING_FACTOR:
gr.Warning(f"Selected upscale factor ({final_upscale_factor}x) is larger than internal processing factor ({INTERNAL_PROCESSING_FACTOR}x). "
f"Results might not be optimal. Clamping final factor to {INTERNAL_PROCESSING_FACTOR}x for this run.")
final_upscale_factor = INTERNAL_PROCESSING_FACTOR # Prevent upscaling *beyond* internal processing
logging.info(
f"Starting inference with seed: {seed}, "
f"Internal Processing Factor: {INTERNAL_PROCESSING_FACTOR}x, "
f"Final Output Factor: {final_upscale_factor}x, "
f"Steps: {num_inference_steps}, CNet Scale: {controlnet_conditioning_scale}"
)
    try:
        # process_input uses INTERNAL_PROCESSING_FACTOR for its pixel-budget check
        processed_input_image, w_original, h_original, was_input_resized = process_input(
            input_image
        )
    except Exception as e:
        logging.error(f"Error processing input image: {e}", exc_info=True)
        # gr.Error only displays when raised; gr.Warning lets us still return fallback outputs
        gr.Warning(f"Error processing input image: {e}")
        return [[original_input_pil, None], seed, None]
w_proc, h_proc = processed_input_image.size
# Calculate control image dimensions using INTERNAL_PROCESSING_FACTOR
control_image_w = w_proc * INTERNAL_PROCESSING_FACTOR
control_image_h = h_proc * INTERNAL_PROCESSING_FACTOR
    # Clamp control image size if it *still* exceeds the budget (e.g., due to rounding or tiny inputs).
    # This should be redundant if process_input worked correctly, but it is a cheap failsafe.
if control_image_w * control_image_h > MAX_PIXEL_BUDGET * 1.05: # Add a small margin
scale_factor = (MAX_PIXEL_BUDGET / (control_image_w * control_image_h)) ** 0.5
control_image_w = max(8, int(control_image_w * scale_factor))
control_image_h = max(8, int(control_image_h * scale_factor))
control_image_w = max(8, control_image_w - control_image_w % 8)
control_image_h = max(8, control_image_h - control_image_h % 8)
logging.warning(f"Control image dimensions clamped to {control_image_w}x{control_image_h} post-processing to fit budget.")
gr.Warning(f"Control image dimensions further clamped to {control_image_w}x{control_image_h}.")
logging.info(f"Resizing processed input {w_proc}x{h_proc} to control image {control_image_w}x{control_image_h} (using {INTERNAL_PROCESSING_FACTOR}x factor)")
try:
# Use the processed input image for control, resized to the intermediate size
control_image = processed_input_image.resize((control_image_w, control_image_h), Image.Resampling.LANCZOS)
    except ValueError as resize_err:
        logging.error(f"Error resizing processed input to control image: {resize_err}")
        gr.Warning(f"Failed to prepare control image: {resize_err}")
        return [[original_input_pil, None], seed, None]
generator = torch.Generator(device=device).manual_seed(seed)
# --- Run the Pipeline at INTERNAL_PROCESSING_FACTOR scale ---
gr.Info(f"Generating intermediate image at {INTERNAL_PROCESSING_FACTOR}x quality ({control_image_w}x{control_image_h})...")
logging.info(f"Running pipeline with size: {control_image_w}x{control_image_h}")
intermediate_result_image = None
try:
with torch.inference_mode():
intermediate_result_image = pipe(
prompt="",
control_image=control_image, # Control image IS the intermediate size
controlnet_conditioning_scale=float(controlnet_conditioning_scale),
num_inference_steps=int(num_inference_steps),
guidance_scale=0.0,
height=control_image_h, # Target height for the model
width=control_image_w, # Target width for the model
generator=generator,
).images[0]
logging.info(f"Pipeline execution finished. Intermediate image size: {intermediate_result_image.size if intermediate_result_image else 'None'}")
    except torch.cuda.OutOfMemoryError as oom_error:
        logging.error(f"CUDA out of memory during pipeline execution: {oom_error}", exc_info=True)
        gr.Warning(f"Ran out of GPU memory generating the intermediate {control_image_w}x{control_image_h} image.")
        if device == "cuda":
            torch.cuda.empty_cache()
        return [[original_input_pil, None], seed, None]
    except Exception as e:
        logging.error(f"Error during pipeline execution: {e}", exc_info=True)
        gr.Warning(f"Inference failed: {e}")
        return [[original_input_pil, None], seed, None]
    if intermediate_result_image is None:
        logging.error("Intermediate result image is None after pipeline execution.")
        gr.Warning("Inference produced no result image.")
        return [[original_input_pil, None], seed, None]
    # --- Final Resizing to User's Desired Scale ---
    # Compute the final target from the ORIGINAL input size and the FINAL upscale factor.
    # If the input had to be downscaled, base the target on the processed size instead,
    # since the generated detail corresponds to that smaller input.
if was_input_resized:
# Base final size on the downscaled input that was processed
final_target_w = w_proc * final_upscale_factor
final_target_h = h_proc * final_upscale_factor
logging.warning(f"Input was downscaled. Final size based on processed input: {w_proc}x{h_proc} * {final_upscale_factor}x -> {final_target_w}x{final_target_h}")
gr.Info(f"Input was downscaled. Final size target approx {final_target_w}x{final_target_h}.")
else:
# Base final size on the original input size
final_target_w = w_original * final_upscale_factor
final_target_h = h_original * final_upscale_factor
final_result_image = intermediate_result_image
current_w, current_h = intermediate_result_image.size
# Only resize if the intermediate size doesn't match the final desired size
if (current_w, current_h) != (final_target_w, final_target_h):
logging.info(f"Resizing intermediate image from {current_w}x{current_h} to final target size {final_target_w}x{final_target_h} (using {final_upscale_factor}x factor)")
gr.Info(f"Resizing from intermediate {current_w}x{current_h} to final {final_target_w}x{final_target_h}...")
try:
if final_target_w > 0 and final_target_h > 0:
# Use LANCZOS for downsampling, it's high quality
final_result_image = intermediate_result_image.resize((final_target_w, final_target_h), Image.Resampling.LANCZOS)
else:
gr.Warning(f"Invalid final target dimensions ({final_target_w}x{final_target_h}). Skipping final resize.")
final_result_image = intermediate_result_image # Keep intermediate
except Exception as resize_e:
logging.error(f"Could not resize intermediate image to final size: {resize_e}")
gr.Warning(f"Failed to resize to final {final_upscale_factor}x. Returning intermediate {INTERNAL_PROCESSING_FACTOR}x result ({current_w}x{current_h}).")
final_result_image = intermediate_result_image # Fallback to intermediate
else:
logging.info(f"Intermediate size {current_w}x{current_h} matches final target size. No final resize needed.")
logging.info(f"Inference successful. Final output size: {final_result_image.size}")
    # --- Base64 Encoding ---
base64_string = None
if final_result_image:
try:
buffered = io.BytesIO()
final_result_image.save(buffered, format="WEBP", quality=90)
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
base64_string = f"data:image/webp;base64,{img_str}"
logging.info(f"Encoded result image to Base64 string (length: {len(base64_string)} chars).")
except Exception as enc_err:
logging.error(f"Failed to encode result image to Base64: {enc_err}", exc_info=True)
# Return original input and the FINAL processed image
return [[original_input_pil, final_result_image], seed, base64_string]
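
# A minimal client-side sketch (hypothetical helper, not used by the app) showing how
# the "data:image/webp;base64,..." string returned by `infer` decodes back to a PIL image:
def _decode_base64_result(data_url: str) -> Image.Image:
    _header, _, payload = data_url.partition(",")  # strip the "data:image/webp;base64" prefix
    return Image.open(io.BytesIO(base64.b64decode(payload)))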
# --- Gradio Interface ---
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as demo:
gr.Markdown(
f"""
# ⚡ Flux.1-dev Upscaler ControlNet ⚡
Upscale images using the [Flux.1-dev Upscaler ControlNet](https://huggingface.co/jasperai/Flux.1-dev-Controlnet-Upscaler) model based on [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
Currently running on **{power_device}**. Hardware provided by Hugging Face 🤗.
**How it works:** This demo uses an internal processing scale of **{INTERNAL_PROCESSING_FACTOR}x** for potentially higher detail generation,
then resizes the result to your selected **Final Upscale Factor**. This aims for {INTERNAL_PROCESSING_FACTOR}x quality at your desired output resolution.
*Note*: Intermediate processing resolution is limited to approximately **{MAX_PIXEL_BUDGET/1_000_000:.1f} megapixels** ({int(MAX_PIXEL_BUDGET**0.5)}x{int(MAX_PIXEL_BUDGET**0.5)}) due to resource constraints.
The *diffusion process time* is mainly determined by this intermediate size, not the final output size.
"""
)
with gr.Row():
with gr.Column(scale=2):
input_im = gr.Image(
label="Input Image",
type="pil",
height=350,
sources=["upload", "clipboard"],
)
with gr.Column(scale=1):
            upscale_factor_slider = gr.Slider(
                label="Final Upscale Factor",
                info=f"Output size relative to the input; internal processing always runs at {INTERNAL_PROCESSING_FACTOR}x quality.",
                minimum=1,
                maximum=INTERNAL_PROCESSING_FACTOR,  # final factor cannot exceed the internal factor
                step=1,
                value=2,
            )
num_inference_steps = gr.Slider(label="Inference Steps", minimum=4, maximum=50, step=1, value=15)
controlnet_conditioning_scale = gr.Slider(label="ControlNet Conditioning Scale", info="Strength of ControlNet guidance", minimum=0.0, maximum=1.5, step=0.05, value=0.6)
with gr.Row():
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
randomize_seed = gr.Checkbox(label="Random", value=True, scale=0, min_width=80)
run_button = gr.Button("⚡ Upscale Image", variant="primary", scale=1)
with gr.Row():
result_slider = ImageSlider(
label="Input / Output Comparison",
type="pil",
interactive=False,
show_label=True,
position=0.5
)
output_seed = gr.Textbox(label="Seed Used", interactive=False, visible=True, scale=1)
api_base64_output = gr.Textbox(label="API Base64 Output", interactive=False, visible=False)
    # --- Examples ---
example_dir = "examples"
example_files = ["image_2.jpg", "image_4.jpg", "low_res_face.png", "low_res_landscape.png"]
example_paths = [os.path.join(example_dir, f) for f in example_files if os.path.exists(os.path.join(example_dir, f))]
if example_paths:
        gr.Examples(
            # Example rows must follow `infer`'s parameter order:
            # seed, randomize_seed, image, steps, final factor, ControlNet scale
            examples=[[random.randint(0, MAX_SEED), True, path, 15, 2, 0.6] for path in example_paths],
            inputs=[seed, randomize_seed, input_im, num_inference_steps, upscale_factor_slider, controlnet_conditioning_scale],
            outputs=[result_slider, output_seed, api_base64_output],  # must match infer's three return values
            fn=infer,
            cache_examples="lazy",
            label="Example Images (Click to Run with 2x Output)",
            run_on_click=True,
        )
else:
gr.Markdown(f"*No example images found in '{example_dir}' directory.*")
gr.Markdown("---")
gr.Markdown("**Disclaimer:** Demo for illustrative purposes. Users are responsible for generated content.")
# Connect button click
run_button.click(
fn=infer,
inputs=[
seed,
randomize_seed,
input_im,
num_inference_steps,
upscale_factor_slider, # Use the slider value here
controlnet_conditioning_scale,
],
outputs=[result_slider, output_seed, api_base64_output],
api_name="upscale"
)
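
# A minimal API sketch (hypothetical Space id and file path; assumes the gradio_client
# package) for the endpoint exposed via api_name="upscale". Positional arguments
# follow the `inputs` order of run_button.click above, and the three outputs come
# back roughly as (slider images, seed used, base64 string):
#   from gradio_client import Client, handle_file
#   client = Client("<user>/<space-name>")
#   images, used_seed, b64 = client.predict(
#       42, True, handle_file("low_res.png"),  # seed, randomize_seed, input image
#       15, 2, 0.6,                            # steps, final factor, CNet scale
#       api_name="/upscale",
#   )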
# Launch the Gradio app
demo.queue(max_size=10).launch(share=False, show_api=True)