import logging
import random
import warnings
import os
import io
import base64

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from gradio_imageslider import ImageSlider  # Ensure this is installed: pip install gradio_imageslider
from PIL import Image, ImageOps  # ImageOps is used for EXIF transposition
from huggingface_hub import snapshot_download

# --- Setup Logging and Device ---
logging.basicConfig(level=logging.INFO)
warnings.filterwarnings("ignore")
css = """ | |
#col-container { | |
margin: 0 auto; | |
max-width: 512px; /* Increased max-width slightly for better layout */ | |
} | |
.gradio-container { | |
max-width: 900px !important; /* Control overall container width */ | |
margin: auto !important; | |
} | |
""" | |
if torch.cuda.is_available():
    power_device = "GPU"
    device = "cuda"
    torch_dtype = torch.bfloat16  # Use bfloat16 on GPU for better performance/memory
else:
    power_device = "CPU"
    device = "cpu"
    torch_dtype = torch.float32  # Use float32 on CPU

logging.info(f"Selected device: {device} | Data type: {torch_dtype}")
# --- Authentication and Model Download ---
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Define model IDs
flux_model_id = "black-forest-labs/FLUX.1-dev"
controlnet_model_id = "jasperai/Flux.1-dev-Controlnet-Upscaler"
local_model_dir = flux_model_id.split('/')[-1]  # e.g., "FLUX.1-dev"
pipe = None
try:
    logging.info(f"Downloading base model: {flux_model_id}")
    model_path = snapshot_download(
        repo_id=flux_model_id,
        repo_type="model",
        ignore_patterns=["*.md", "*.gitattributes"],
        local_dir=local_model_dir,
        token=huggingface_token,
    )
    logging.info(f"Base model downloaded/verified in: {model_path}")

    logging.info(f"Loading ControlNet model: {controlnet_model_id}")
    controlnet = FluxControlNetModel.from_pretrained(
        controlnet_model_id, torch_dtype=torch_dtype
    ).to(device)
    logging.info("ControlNet model loaded.")

    logging.info("Loading FluxControlNetPipeline...")
    pipe = FluxControlNetPipeline.from_pretrained(
        model_path,
        controlnet=controlnet,
        torch_dtype=torch_dtype,
    )
    pipe.to(device)
    logging.info("Pipeline loaded and moved to device.")
except Exception as e:
    logging.error(f"FATAL: Error during model loading: {e}", exc_info=True)
    # Simplified error handling: abort startup if the models cannot be loaded
    print(f"FATAL ERROR DURING MODEL LOAD: {e}")
    raise SystemExit(f"Model loading failed: {e}")
# --- Constants ---
MAX_SEED = 2**32 - 1
MAX_PIXEL_BUDGET = 1280 * 1280
# --- NEW: Define the internal factor for quality ---
INTERNAL_PROCESSING_FACTOR = 4
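# Note: with these defaults, the largest input that still fits the pixel budget at the
# internal factor is MAX_PIXEL_BUDGET / INTERNAL_PROCESSING_FACTOR**2
# = 1,638,400 / 16 = 102,400 pixels (roughly 320x320); larger inputs are downscaled
# in process_input() before the upscaling pass.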
# --- Image Processing Function (Modified) ---
def process_input(input_image):
    """Processes the input image for the pipeline.

    The pixel budget check uses the fixed INTERNAL_PROCESSING_FACTOR."""
    if input_image is None:
        raise gr.Error("Input image is missing!")

    try:
        input_image = ImageOps.exif_transpose(input_image)
        if input_image.mode != 'RGB':
            logging.info(f"Converting input image from {input_image.mode} to RGB")
            input_image = input_image.convert('RGB')
        w, h = input_image.size
    except AttributeError:
        raise gr.Error("Invalid input image format. Please provide a valid image file.")
    except Exception as img_err:
        raise gr.Error(f"Could not process input image: {img_err}")

    w_original, h_original = w, h
    if w == 0 or h == 0:
        raise gr.Error("Input image has zero width or height.")

    # Calculate the intermediate target size based on the INTERNAL factor for the budget check
    target_w_internal = w * INTERNAL_PROCESSING_FACTOR
    target_h_internal = h * INTERNAL_PROCESSING_FACTOR
    target_pixels_internal = target_w_internal * target_h_internal

    was_resized = False
    input_image_to_process = input_image.copy()

    # Check whether the *intermediate* size exceeds the budget
    if target_pixels_internal > MAX_PIXEL_BUDGET:
        max_input_pixels = MAX_PIXEL_BUDGET / (INTERNAL_PROCESSING_FACTOR**2)
        current_input_pixels = w * h
        if current_input_pixels > max_input_pixels:
            input_scale_factor = (max_input_pixels / current_input_pixels) ** 0.5
            input_w_resized = int(w * input_scale_factor)
            input_h_resized = int(h * input_scale_factor)
            # Ensure a minimum size of 8x8
            input_w_resized = max(8, input_w_resized)
            input_h_resized = max(8, input_h_resized)
            intermediate_w = input_w_resized * INTERNAL_PROCESSING_FACTOR
            intermediate_h = input_h_resized * INTERNAL_PROCESSING_FACTOR
            logging.warning(
                f"Requested {INTERNAL_PROCESSING_FACTOR}x intermediate output ({target_w_internal}x{target_h_internal}) exceeds budget. "
                f"Resizing input from {w}x{h} to {input_w_resized}x{input_h_resized}."
            )
            gr.Info(
                f"Intermediate {INTERNAL_PROCESSING_FACTOR}x size exceeds budget. Input resized to {input_w_resized}x{input_h_resized} "
                f"-> model generates ~{int(intermediate_w)}x{int(intermediate_h)}."
            )
            input_image_to_process = input_image_to_process.resize(
                (input_w_resized, input_h_resized), Image.Resampling.LANCZOS
            )
            was_resized = True  # Original dimensions can no longer be used for precise final scaling

    # Round processed input dimensions down to multiples of 8
    w_proc, h_proc = input_image_to_process.size
    w_final_proc = max(8, w_proc - w_proc % 8)
    h_final_proc = max(8, h_proc - h_proc % 8)
    if (w_proc, h_proc) != (w_final_proc, h_final_proc):
        logging.info(f"Rounding processed input dimensions from {w_proc}x{h_proc} to {w_final_proc}x{h_final_proc}")
        input_image_to_process = input_image_to_process.resize(
            (w_final_proc, h_final_proc), Image.Resampling.LANCZOS
        )

    return input_image_to_process, w_original, h_original, was_resized
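# Worked example with the defaults above: a 200x200 input fits the budget unchanged
# (4x intermediate = 800x800 = 640,000 px < 1,638,400), while a 400x400 input
# (4x intermediate = 1600x1600 = 2,560,000 px) is first downscaled by
# sqrt(102,400 / 160,000) = 0.8 to 320x320 before the 4x pass.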
# --- Inference Function (Modified) ---
@spaces.GPU  # Request GPU time on ZeroGPU Spaces (harmless elsewhere); this is why `spaces` is imported
def infer(
    seed,
    randomize_seed,
    input_image,
    num_inference_steps,
    final_upscale_factor,  # Final output scale selected by the user
    controlnet_conditioning_scale,
    progress=gr.Progress(track_tqdm=True),
):
    global pipe
    if pipe is None:
        # gr.Error only displays when raised; use a warning toast so fallback outputs can still be returned
        gr.Warning("Pipeline not loaded. Cannot perform inference.")
        return [[None, None], 0, None]

    original_input_pil = input_image  # Keep a reference for the comparison slider

    if input_image is None:
        gr.Warning("Please provide an input image.")
        return [[None, None], seed or 0, None]

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    # Ensure final_upscale_factor is an integer
    final_upscale_factor = int(final_upscale_factor)
    if final_upscale_factor > INTERNAL_PROCESSING_FACTOR:
        gr.Warning(
            f"Selected upscale factor ({final_upscale_factor}x) is larger than the internal processing factor ({INTERNAL_PROCESSING_FACTOR}x). "
            f"Results might not be optimal. Clamping final factor to {INTERNAL_PROCESSING_FACTOR}x for this run."
        )
        final_upscale_factor = INTERNAL_PROCESSING_FACTOR  # Prevent upscaling *beyond* the internal processing factor
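        # Note: the UI slider is capped at INTERNAL_PROCESSING_FACTOR, so this clamp
        # mainly matters for API callers that pass a larger factor directly.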
    logging.info(
        f"Starting inference with seed: {seed}, "
        f"Internal Processing Factor: {INTERNAL_PROCESSING_FACTOR}x, "
        f"Final Output Factor: {final_upscale_factor}x, "
        f"Steps: {num_inference_steps}, CNet Scale: {controlnet_conditioning_scale}"
    )

    try:
        # process_input uses INTERNAL_PROCESSING_FACTOR for its budget checks
        processed_input_image, w_original, h_original, was_input_resized = process_input(
            input_image
        )
    except Exception as e:
        logging.error(f"Error processing input image: {e}", exc_info=True)
        gr.Warning(f"Error processing input image: {e}")
        return [[original_input_pil, None], seed, None]

    w_proc, h_proc = processed_input_image.size

    # Calculate control image dimensions using INTERNAL_PROCESSING_FACTOR
    control_image_w = w_proc * INTERNAL_PROCESSING_FACTOR
    control_image_h = h_proc * INTERNAL_PROCESSING_FACTOR

    # Clamp control image size if it *still* exceeds the budget (e.g., due to rounding or small inputs).
    # This should be redundant if process_input worked correctly, but it is a useful failsafe.
    if control_image_w * control_image_h > MAX_PIXEL_BUDGET * 1.05:  # Allow a small margin
        scale_factor = (MAX_PIXEL_BUDGET / (control_image_w * control_image_h)) ** 0.5
        control_image_w = max(8, int(control_image_w * scale_factor))
        control_image_h = max(8, int(control_image_h * scale_factor))
        control_image_w = max(8, control_image_w - control_image_w % 8)
        control_image_h = max(8, control_image_h - control_image_h % 8)
        logging.warning(f"Control image dimensions clamped to {control_image_w}x{control_image_h} post-processing to fit budget.")
        gr.Warning(f"Control image dimensions further clamped to {control_image_w}x{control_image_h}.")

    logging.info(f"Resizing processed input {w_proc}x{h_proc} to control image {control_image_w}x{control_image_h} (using {INTERNAL_PROCESSING_FACTOR}x factor)")
    try:
        # Use the processed input image as the control image, resized to the intermediate size
        control_image = processed_input_image.resize((control_image_w, control_image_h), Image.Resampling.LANCZOS)
    except ValueError as resize_err:
        logging.error(f"Error resizing processed input to control image: {resize_err}")
        gr.Warning(f"Failed to prepare control image: {resize_err}")
        return [[original_input_pil, None], seed, None]

    generator = torch.Generator(device=device).manual_seed(seed)
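    # A dedicated, explicitly seeded torch.Generator keeps results reproducible
    # for a given seed when run with the same device and settings.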
    # --- Run the pipeline at the INTERNAL_PROCESSING_FACTOR scale ---
    gr.Info(f"Generating intermediate image at {INTERNAL_PROCESSING_FACTOR}x quality ({control_image_w}x{control_image_h})...")
    logging.info(f"Running pipeline with size: {control_image_w}x{control_image_h}")

    intermediate_result_image = None
    try:
        with torch.inference_mode():
            intermediate_result_image = pipe(
                prompt="",
                control_image=control_image,  # The control image already has the intermediate size
                controlnet_conditioning_scale=float(controlnet_conditioning_scale),
                num_inference_steps=int(num_inference_steps),
                guidance_scale=0.0,
                height=control_image_h,  # Target height for the model
                width=control_image_w,  # Target width for the model
                generator=generator,
            ).images[0]
        logging.info(f"Pipeline execution finished. Intermediate image size: {intermediate_result_image.size if intermediate_result_image else 'None'}")
    except torch.cuda.OutOfMemoryError as oom_error:
        logging.error(f"CUDA Out of Memory during pipeline execution: {oom_error}", exc_info=True)
        gr.Warning(f"Ran out of GPU memory trying to generate intermediate {control_image_w}x{control_image_h}.")
        if device == "cuda":
            torch.cuda.empty_cache()
        return [[original_input_pil, None], seed, None]
    except Exception as e:
        logging.error(f"Error during pipeline execution: {e}", exc_info=True)
        gr.Warning(f"Inference failed: {e}")
        return [[original_input_pil, None], seed, None]

    if intermediate_result_image is None:
        logging.error("Intermediate result image is None after pipeline execution.")
        gr.Warning("Inference produced no result image.")
        return [[original_input_pil, None], seed, None]
    # --- Final resizing to the user's selected scale ---
    # Calculate final target dimensions based on the ORIGINAL input size and the FINAL upscale factor.
    # If the input was downscaled, scale the *processed* input size instead, since the original mapping is lost.
    if was_input_resized:
        # Base the final size on the downscaled input that was actually processed
        final_target_w = w_proc * final_upscale_factor
        final_target_h = h_proc * final_upscale_factor
        logging.warning(f"Input was downscaled. Final size based on processed input: {w_proc}x{h_proc} * {final_upscale_factor}x -> {final_target_w}x{final_target_h}")
        gr.Info(f"Input was downscaled. Final size target approx {final_target_w}x{final_target_h}.")
    else:
        # Base the final size on the original input size
        final_target_w = w_original * final_upscale_factor
        final_target_h = h_original * final_upscale_factor

    final_result_image = intermediate_result_image
    current_w, current_h = intermediate_result_image.size

    # Only resize if the intermediate size differs from the final desired size
    if (current_w, current_h) != (final_target_w, final_target_h):
        logging.info(f"Resizing intermediate image from {current_w}x{current_h} to final target size {final_target_w}x{final_target_h} (using {final_upscale_factor}x factor)")
        gr.Info(f"Resizing from intermediate {current_w}x{current_h} to final {final_target_w}x{final_target_h}...")
        try:
            if final_target_w > 0 and final_target_h > 0:
                # LANCZOS is a high-quality filter for downsampling
                final_result_image = intermediate_result_image.resize((final_target_w, final_target_h), Image.Resampling.LANCZOS)
            else:
                gr.Warning(f"Invalid final target dimensions ({final_target_w}x{final_target_h}). Skipping final resize.")
                final_result_image = intermediate_result_image  # Keep the intermediate result
        except Exception as resize_e:
            logging.error(f"Could not resize intermediate image to final size: {resize_e}")
            gr.Warning(f"Failed to resize to final {final_upscale_factor}x. Returning intermediate {INTERNAL_PROCESSING_FACTOR}x result ({current_w}x{current_h}).")
            final_result_image = intermediate_result_image  # Fall back to the intermediate result
    else:
        logging.info(f"Intermediate size {current_w}x{current_h} matches final target size. No final resize needed.")

    logging.info(f"Inference successful. Final output size: {final_result_image.size}")

    # --- Base64 encoding for the API output ---
    base64_string = None
    if final_result_image:
        try:
            buffered = io.BytesIO()
            final_result_image.save(buffered, format="WEBP", quality=90)
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            base64_string = f"data:image/webp;base64,{img_str}"
            logging.info(f"Encoded result image to Base64 string (length: {len(base64_string)} chars).")
        except Exception as enc_err:
            logging.error(f"Failed to encode result image to Base64: {enc_err}", exc_info=True)

    # Return the original input and the FINAL processed image for the comparison slider
    return [[original_input_pil, final_result_image], seed, base64_string]
# --- Gradio Interface (Minor Text Updates) ---
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as demo:
    gr.Markdown(
        f"""
# ⚡ Flux.1-dev Upscaler ControlNet ⚡
Upscale images using the [Flux.1-dev Upscaler ControlNet](https://huggingface.co/jasperai/Flux.1-dev-Controlnet-Upscaler) model based on [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
Currently running on **{power_device}**. Hardware provided by Hugging Face 🤗.

**How it works:** This demo uses an internal processing scale of **{INTERNAL_PROCESSING_FACTOR}x** for potentially higher detail generation,
then resizes the result to your selected **Final Upscale Factor**. This aims for {INTERNAL_PROCESSING_FACTOR}x quality at your desired output resolution.

*Note*: Intermediate processing resolution is limited to approximately **{MAX_PIXEL_BUDGET/1_000_000:.1f} megapixels** ({int(MAX_PIXEL_BUDGET**0.5)}x{int(MAX_PIXEL_BUDGET**0.5)}) due to resource constraints.
The *diffusion process time* is mainly determined by this intermediate size, not the final output size.
"""
    )
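    # Concrete example of the behavior described above (with the current defaults):
    # a 256x256 input with a 2x final factor is generated at 1024x1024 (the internal
    # 4x pass) and then downscaled with LANCZOS to 512x512 for the final output.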
    with gr.Row():
        with gr.Column(scale=2):
            input_im = gr.Image(
                label="Input Image",
                type="pil",
                height=350,
                sources=["upload", "clipboard"],
            )
        with gr.Column(scale=1):
            # Renamed slider label for clarity
            upscale_factor_slider = gr.Slider(
                label="Final Upscale Factor",
                info=f"Output size relative to input. Internal processing uses {INTERNAL_PROCESSING_FACTOR}x quality.",
                minimum=1,
                maximum=INTERNAL_PROCESSING_FACTOR,  # Max is the internal processing factor
                step=1,
                value=2,  # Default to 2x output
            )
            num_inference_steps = gr.Slider(label="Inference Steps", minimum=4, maximum=50, step=1, value=15)
            controlnet_conditioning_scale = gr.Slider(
                label="ControlNet Conditioning Scale",
                info="Strength of ControlNet guidance",
                minimum=0.0,
                maximum=1.5,
                step=0.05,
                value=0.6,
            )
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                randomize_seed = gr.Checkbox(label="Random", value=True, scale=0, min_width=80)
                run_button = gr.Button("⚡ Upscale Image", variant="primary", scale=1)

    with gr.Row():
        result_slider = ImageSlider(
            label="Input / Output Comparison",
            type="pil",
            interactive=False,
            show_label=True,
            position=0.5,
        )
        output_seed = gr.Textbox(label="Seed Used", interactive=False, visible=True, scale=1)

    api_base64_output = gr.Textbox(label="API Base64 Output", interactive=False, visible=False)
    # --- Examples (Updated default factor if needed) ---
    example_dir = "examples"
    example_files = ["image_2.jpg", "image_4.jpg", "low_res_face.png", "low_res_landscape.png"]
    example_paths = [os.path.join(example_dir, f) for f in example_files if os.path.exists(os.path.join(example_dir, f))]

    if example_paths:
        gr.Examples(
            # Examples use the default of 2x for the final output factor;
            # values are ordered to match the argument order of `infer`
            examples=[[random.randint(0, MAX_SEED), True, path, 15, 2, 0.6] for path in example_paths],
            inputs=[
                seed,
                randomize_seed,
                input_im,
                num_inference_steps,
                upscale_factor_slider,
                controlnet_conditioning_scale,
            ],
            outputs=[result_slider, output_seed, api_base64_output],  # Must match the three values returned by `infer`
            fn=infer,
            cache_examples="lazy",
            label="Example Images (Click to Run with 2x Output)",
            run_on_click=True,
        )
    else:
        gr.Markdown(f"*No example images found in '{example_dir}' directory.*")
gr.Markdown("---") | |
gr.Markdown("**Disclaimer:** Demo for illustrative purposes. Users are responsible for generated content.") | |
# Connect button click | |
run_button.click( | |
fn=infer, | |
inputs=[ | |
seed, | |
randomize_seed, | |
input_im, | |
num_inference_steps, | |
upscale_factor_slider, # Use the slider value here | |
controlnet_conditioning_scale, | |
], | |
outputs=[result_slider, output_seed, api_base64_output], | |
api_name="upscale" | |
) | |
# Launch the Gradio app | |
demo.queue(max_size=10).launch(share=False, show_api=True) |
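
# --- Example client usage (illustrative sketch, not executed by this script) ---
# Assuming this app is deployed as a Space, a client could call the `upscale`
# endpoint defined above with `gradio_client` roughly as follows. The Space ID
# below is a placeholder; the argument order matches the `infer` signature.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("your-username/your-upscaler-space")  # hypothetical Space ID
#   images, used_seed, base64_webp = client.predict(
#       42,                                # seed
#       False,                             # randomize_seed
#       handle_file("low_res_input.png"),  # input_image
#       15,                                # num_inference_steps
#       2,                                 # final upscale factor
#       0.6,                               # controlnet_conditioning_scale
#       api_name="/upscale",
#   )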