import logging
import random
import warnings
import os
import io
import base64
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from gradio_imageslider import ImageSlider # Ensure this is installed: pip install gradio_imageslider
from PIL import Image, ImageOps # Import ImageOps for exif transpose
from huggingface_hub import snapshot_download

# --- Setup Logging and Device ---
logging.basicConfig(level=logging.INFO)
warnings.filterwarnings("ignore")

css = """
#col-container {
    margin: 0 auto;
    max-width: 512px; /* main column width */
}
.gradio-container {
    max-width: 900px !important; /* Control overall container width */
    margin: auto !important;
}
"""

if torch.cuda.is_available():
    power_device = "GPU"
    device = "cuda"
    torch_dtype = torch.bfloat16 # Use bfloat16 for GPU for better performance/memory
else:
    power_device = "CPU"
    device = "cpu"
    torch_dtype = torch.float32 # Use float32 for CPU

logging.info(f"Selected device: {device} | Data type: {torch_dtype}")

# --- Authentication and Model Download ---
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
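# FLUX.1-dev is a gated repository: the token must belong to an account that has
# accepted the model license. Typical setup (illustrative placeholder value):
#   export HUGGINGFACE_TOKEN=hf_xxxxxxxxxxxx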

# Define model IDs
flux_model_id = "black-forest-labs/FLUX.1-dev"
controlnet_model_id = "jasperai/Flux.1-dev-Controlnet-Upscaler"
local_model_dir = flux_model_id.split('/')[-1] # e.g., "FLUX.1-dev"
pipe = None

try:
    logging.info(f"Downloading base model: {flux_model_id}")
    model_path = snapshot_download(
        repo_id=flux_model_id,
        repo_type="model",
        ignore_patterns=["*.md", "*.gitattributes"],
        local_dir=local_model_dir,
        token=huggingface_token,
    )
    logging.info(f"Base model downloaded/verified in: {model_path}")

    logging.info(f"Loading ControlNet model: {controlnet_model_id}")
    controlnet = FluxControlNetModel.from_pretrained(
        controlnet_model_id, torch_dtype=torch_dtype
    ).to(device)
    logging.info("ControlNet model loaded.")

    logging.info("Loading FluxControlNetPipeline...")
    pipe = FluxControlNetPipeline.from_pretrained(
        model_path,
        controlnet=controlnet,
        torch_dtype=torch_dtype
    )
    pipe.to(device)
    logging.info("Pipeline loaded and moved to device.")

except Exception as e:
    logging.error(f"FATAL: Error during model loading: {e}", exc_info=True)
    raise SystemExit(f"Model loading failed: {e}")


# --- Constants ---
MAX_SEED = 2**32 - 1
MAX_PIXEL_BUDGET = 1280 * 1280
# Fixed internal processing scale: the model always generates at this factor for quality
INTERNAL_PROCESSING_FACTOR = 4
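# Worked example: MAX_PIXEL_BUDGET = 1280 * 1280 = 1,638,400 px (~1.6 MP).
# With INTERNAL_PROCESSING_FACTOR = 4, the largest input that avoids downscaling
# is 1,638,400 / 4**2 = 102,400 px, i.e. roughly 320x320.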

# --- Image Processing ---
def process_input(input_image):
    """Processes the input image for the pipeline.
       The pixel budget check uses the fixed INTERNAL_PROCESSING_FACTOR."""
    if input_image is None:
        raise gr.Error("Input image is missing!")
    try:
        input_image = ImageOps.exif_transpose(input_image)
        if input_image.mode != 'RGB':
            logging.info(f"Converting input image from {input_image.mode} to RGB")
            input_image = input_image.convert('RGB')
        w, h = input_image.size
    except AttributeError:
        raise gr.Error("Invalid input image format. Please provide a valid image file.")
    except Exception as img_err:
         raise gr.Error(f"Could not process input image: {img_err}")

    w_original, h_original = w, h
    if w == 0 or h == 0:
        raise gr.Error("Input image has zero width or height.")

    # Calculate target based on INTERNAL factor for budget check
    target_w_internal = w * INTERNAL_PROCESSING_FACTOR
    target_h_internal = h * INTERNAL_PROCESSING_FACTOR
    target_pixels_internal = target_w_internal * target_h_internal

    was_resized = False
    input_image_to_process = input_image.copy()

    # Check if the *intermediate* size exceeds the budget
    if target_pixels_internal > MAX_PIXEL_BUDGET:
        max_input_pixels = MAX_PIXEL_BUDGET / (INTERNAL_PROCESSING_FACTOR**2)
        current_input_pixels = w * h

        if current_input_pixels > max_input_pixels:
            input_scale_factor = (max_input_pixels / current_input_pixels) ** 0.5
            input_w_resized = int(w * input_scale_factor)
            input_h_resized = int(h * input_scale_factor)
            # Ensure minimum size of 8x8
            input_w_resized = max(8, input_w_resized)
            input_h_resized = max(8, input_h_resized)
            intermediate_w = input_w_resized * INTERNAL_PROCESSING_FACTOR
            intermediate_h = input_h_resized * INTERNAL_PROCESSING_FACTOR

            logging.warning(
                f"Requested {INTERNAL_PROCESSING_FACTOR}x intermediate output ({target_w_internal}x{target_h_internal}) exceeds budget. "
                f"Resizing input from {w}x{h} to {input_w_resized}x{input_h_resized}."
            )
            gr.Info(
                f"Intermediate {INTERNAL_PROCESSING_FACTOR}x size exceeds budget. Input resized to {input_w_resized}x{input_h_resized} "
                f"-> model generates ~{int(intermediate_w)}x{int(intermediate_h)}."
            )
            input_image_to_process = input_image_to_process.resize((input_w_resized, input_h_resized), Image.Resampling.LANCZOS)
            was_resized = True # Flag that original dimensions are lost for precise final scaling

    # Round processed input dimensions down to a multiple of 8
    w_proc, h_proc = input_image_to_process.size
    w_final_proc = max(8, w_proc - w_proc % 8)
    h_final_proc = max(8, h_proc - h_proc % 8)
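    # e.g., a 369x277 processed input rounds down to 368x272 (369 - 369 % 8 = 368, 277 - 277 % 8 = 272)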

    if (w_proc, h_proc) != (w_final_proc, h_final_proc):
        logging.info(f"Rounding processed input dimensions from {w_proc}x{h_proc} to {w_final_proc}x{h_final_proc}")
        input_image_to_process = input_image_to_process.resize((w_final_proc, h_final_proc), Image.Resampling.LANCZOS)

    return input_image_to_process, w_original, h_original, was_resized
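
# Example flow through process_input, using the constants above: a 640x480 input
# at 4x would need 2560x1920 = 4,915,200 px (over budget), so the input is scaled
# by sqrt(102,400 / 307,200) ~= 0.577 to 369x277, then rounded down to 368x272;
# the model then generates a 1472x1088 intermediate image.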

# --- Inference Function ---
@spaces.GPU(duration=75)
def infer(
    seed,
    randomize_seed,
    input_image,
    num_inference_steps,
    final_upscale_factor,  # user-selected final output scale
    controlnet_conditioning_scale,
    progress=gr.Progress(track_tqdm=True),
):
    global pipe
    if pipe is None:
        # gr.Error must be raised (not merely constructed) to show up in the UI
        raise gr.Error("Pipeline not loaded. Cannot perform inference.")

    original_input_pil = input_image # Keep ref for slider

    if input_image is None:
        gr.Warning("Please provide an input image.")
        return [[None, None], seed or 0, None]

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    # Ensure final_upscale_factor is an integer
    final_upscale_factor = int(final_upscale_factor)
    if final_upscale_factor > INTERNAL_PROCESSING_FACTOR:
        gr.Warning(f"Selected upscale factor ({final_upscale_factor}x) is larger than internal processing factor ({INTERNAL_PROCESSING_FACTOR}x). "
                   f"Results might not be optimal. Clamping final factor to {INTERNAL_PROCESSING_FACTOR}x for this run.")
        final_upscale_factor = INTERNAL_PROCESSING_FACTOR # Prevent upscaling *beyond* internal processing

    logging.info(
        f"Starting inference with seed: {seed}, "
        f"Internal Processing Factor: {INTERNAL_PROCESSING_FACTOR}x, "
        f"Final Output Factor: {final_upscale_factor}x, "
        f"Steps: {num_inference_steps}, CNet Scale: {controlnet_conditioning_scale}"
    )

    try:
        # process_input now implicitly uses INTERNAL_PROCESSING_FACTOR for budget checks
        processed_input_image, w_original, h_original, was_input_resized = process_input(
            input_image
        )
    except Exception as e:
        logging.error(f"Error processing input image: {e}", exc_info=True)
        # use a warning toast (gr.Error only displays when raised) so we can still
        # return the fallback outputs below
        gr.Warning(f"Error processing input image: {e}")
        return [[original_input_pil, None], seed, None]

    w_proc, h_proc = processed_input_image.size

    # Calculate control image dimensions using INTERNAL_PROCESSING_FACTOR
    control_image_w = w_proc * INTERNAL_PROCESSING_FACTOR
    control_image_h = h_proc * INTERNAL_PROCESSING_FACTOR

    # Clamp control image size if it *still* exceeds budget (e.g., due to rounding or small inputs)
    # This should be redundant when process_input has done its job, but it is a cheap failsafe.
    if control_image_w * control_image_h > MAX_PIXEL_BUDGET * 1.05: # Add a small margin
        scale_factor = (MAX_PIXEL_BUDGET / (control_image_w * control_image_h)) ** 0.5
        control_image_w = max(8, int(control_image_w * scale_factor))
        control_image_h = max(8, int(control_image_h * scale_factor))
        control_image_w = max(8, control_image_w - control_image_w % 8)
        control_image_h = max(8, control_image_h - control_image_h % 8)
        logging.warning(f"Control image dimensions clamped to {control_image_w}x{control_image_h} post-processing to fit budget.")
        gr.Warning(f"Control image dimensions further clamped to {control_image_w}x{control_image_h}.")

    logging.info(f"Resizing processed input {w_proc}x{h_proc} to control image {control_image_w}x{control_image_h} (using {INTERNAL_PROCESSING_FACTOR}x factor)")
    try:
        # Use the processed input image for control, resized to the intermediate size
        control_image = processed_input_image.resize((control_image_w, control_image_h), Image.Resampling.LANCZOS)
    except ValueError as resize_err:
        logging.error(f"Error resizing processed input to control image: {resize_err}")
        gr.Warning(f"Failed to prepare control image: {resize_err}")
        return [[original_input_pil, None], seed, None]

    generator = torch.Generator(device=device).manual_seed(seed)

    # --- Run the Pipeline at INTERNAL_PROCESSING_FACTOR scale ---
    gr.Info(f"Generating intermediate image at {INTERNAL_PROCESSING_FACTOR}x quality ({control_image_w}x{control_image_h})...")
    logging.info(f"Running pipeline with size: {control_image_w}x{control_image_h}")
    intermediate_result_image = None
    try:
        with torch.inference_mode():
            intermediate_result_image = pipe(
                prompt="",
                control_image=control_image, # Control image IS the intermediate size
                controlnet_conditioning_scale=float(controlnet_conditioning_scale),
                num_inference_steps=int(num_inference_steps),
                guidance_scale=0.0,
                height=control_image_h, # Target height for the model
                width=control_image_w,   # Target width for the model
                generator=generator,
            ).images[0]
        logging.info(f"Pipeline execution finished. Intermediate image size: {intermediate_result_image.size if intermediate_result_image else 'None'}")

    except torch.cuda.OutOfMemoryError as oom_error:
        logging.error(f"CUDA out of memory during pipeline execution: {oom_error}", exc_info=True)
        # warn instead of raising gr.Error so the fallback outputs below are returned
        gr.Warning(f"Ran out of GPU memory generating the intermediate {control_image_w}x{control_image_h} image.")
        if device == "cuda":
            torch.cuda.empty_cache()
        return [[original_input_pil, None], seed, None]
    except Exception as e:
        logging.error(f"Error during pipeline execution: {e}", exc_info=True)
        gr.Warning(f"Inference failed: {e}")
        return [[original_input_pil, None], seed, None]

    if not intermediate_result_image:
        logging.error("Intermediate result image is None after pipeline execution.")
        gr.Warning("Inference produced no result image.")
        return [[original_input_pil, None], seed, None]

    # --- Final Resizing to User's Desired Scale ---
    # Calculate final target dimensions based on ORIGINAL input size and FINAL upscale factor
    # If input was resized, we scale the *processed* input size instead, as original is unknown
    if was_input_resized:
        # Base final size on the downscaled input that was processed
        final_target_w = w_proc * final_upscale_factor
        final_target_h = h_proc * final_upscale_factor
        logging.warning(f"Input was downscaled. Final size based on processed input: {w_proc}x{h_proc} * {final_upscale_factor}x -> {final_target_w}x{final_target_h}")
        gr.Info(f"Input was downscaled. Final size target approx {final_target_w}x{final_target_h}.")
    else:
        # Base final size on the original input size
        final_target_w = w_original * final_upscale_factor
        final_target_h = h_original * final_upscale_factor

    final_result_image = intermediate_result_image
    current_w, current_h = intermediate_result_image.size

    # Only resize if the intermediate size doesn't match the final desired size
    if (current_w, current_h) != (final_target_w, final_target_h):
        logging.info(f"Resizing intermediate image from {current_w}x{current_h} to final target size {final_target_w}x{final_target_h} (using {final_upscale_factor}x factor)")
        gr.Info(f"Resizing from intermediate {current_w}x{current_h} to final {final_target_w}x{final_target_h}...")

        try:
            if final_target_w > 0 and final_target_h > 0:
                # LANCZOS resampling gives a high-quality final downscale
                final_result_image = intermediate_result_image.resize((final_target_w, final_target_h), Image.Resampling.LANCZOS)
            else:
                gr.Warning(f"Invalid final target dimensions ({final_target_w}x{final_target_h}). Skipping final resize.")
                final_result_image = intermediate_result_image  # keep intermediate
        except Exception as resize_e:
            logging.error(f"Could not resize intermediate image to final size: {resize_e}")
            gr.Warning(f"Failed to resize to final {final_upscale_factor}x. Returning intermediate {INTERNAL_PROCESSING_FACTOR}x result ({current_w}x{current_h}).")
            final_result_image = intermediate_result_image # Fallback to intermediate
    else:
        logging.info(f"Intermediate size {current_w}x{current_h} matches final target size. No final resize needed.")


    logging.info(f"Inference successful. Final output size: {final_result_image.size}")

    # --- Base64 Encoding ---
    base64_string = None
    if final_result_image:
        try:
            buffered = io.BytesIO()
            final_result_image.save(buffered, format="WEBP", quality=90)
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            base64_string = f"data:image/webp;base64,{img_str}"
            logging.info(f"Encoded result image to Base64 string (length: {len(base64_string)} chars).")
        except Exception as enc_err:
            logging.error(f"Failed to encode result image to Base64: {enc_err}", exc_info=True)

    # Return original input and the FINAL processed image
    return [[original_input_pil, final_result_image], seed, base64_string]


# --- Gradio Interface ---
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as demo:
    gr.Markdown(
        f"""
    # ⚡ Flux.1-dev Upscaler ControlNet ⚡
    Upscale images using the [Flux.1-dev Upscaler ControlNet](https://huggingface.co/jasperai/Flux.1-dev-Controlnet-Upscaler) model based on [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev).
    Currently running on **{power_device}**. Hardware provided by Hugging Face 🤗.

    **How it works:** This demo uses an internal processing scale of **{INTERNAL_PROCESSING_FACTOR}x** for potentially higher detail generation,
    then resizes the result to your selected **Final Upscale Factor**. This aims for {INTERNAL_PROCESSING_FACTOR}x quality at your desired output resolution.

    *Note*: Intermediate processing resolution is limited to approximately **{MAX_PIXEL_BUDGET/1_000_000:.1f} megapixels** ({int(MAX_PIXEL_BUDGET**0.5)}x{int(MAX_PIXEL_BUDGET**0.5)}) due to resource constraints.
    The *diffusion process time* is mainly determined by this intermediate size, not the final output size.
    """
    )

    with gr.Row():
        with gr.Column(scale=2):
            input_im = gr.Image(
                label="Input Image",
                type="pil",
                height=350,
                sources=["upload", "clipboard"],
            )
        with gr.Column(scale=1):
            upscale_factor_slider = gr.Slider(
                label="Final Upscale Factor",
                info=f"Output size relative to the input. Internal processing always runs at {INTERNAL_PROCESSING_FACTOR}x quality.",
                minimum=1,
                maximum=INTERNAL_PROCESSING_FACTOR,  # the final factor cannot exceed the internal factor
                step=1,
                value=2,
            )
            num_inference_steps = gr.Slider(label="Inference Steps", minimum=4, maximum=50, step=1, value=15)
            controlnet_conditioning_scale = gr.Slider(label="ControlNet Conditioning Scale", info="Strength of ControlNet guidance", minimum=0.0, maximum=1.5, step=0.05, value=0.6)
            with gr.Row():
                 seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                 randomize_seed = gr.Checkbox(label="Random", value=True, scale=0, min_width=80)
            run_button = gr.Button("⚡ Upscale Image", variant="primary", scale=1)

    with gr.Row():
        result_slider = ImageSlider(
            label="Input / Output Comparison",
            type="pil",
            interactive=False,
            show_label=True,
            position=0.5
            )

    output_seed = gr.Textbox(label="Seed Used", interactive=False, visible=True, scale=1)
    api_base64_output = gr.Textbox(label="API Base64 Output", interactive=False, visible=False)

    # --- Examples ---
    example_dir = "examples"
    example_files = ["image_2.jpg", "image_4.jpg", "low_res_face.png", "low_res_landscape.png"]
    example_paths = [os.path.join(example_dir, f) for f in example_files if os.path.exists(os.path.join(example_dir, f))]

    if example_paths:
        gr.Examples(
            # Example rows and `inputs` must follow the parameter order of `infer`:
            # (seed, randomize_seed, input_image, num_inference_steps,
            #  final_upscale_factor, controlnet_conditioning_scale)
            examples=[[random.randint(0, MAX_SEED), True, path, 15, 2, 0.6] for path in example_paths],
            inputs=[seed, randomize_seed, input_im, num_inference_steps, upscale_factor_slider, controlnet_conditioning_scale],
            outputs=[result_slider, output_seed, api_base64_output],  # infer returns three values
            fn=infer,
            cache_examples="lazy",  # lazy caching runs each example on first click
            label="Example Images (Click to Run with 2x Output)",
        )
    else:
        gr.Markdown(f"*No example images found in '{example_dir}' directory.*")

    gr.Markdown("---")
    gr.Markdown("**Disclaimer:** Demo for illustrative purposes. Users are responsible for generated content.")

    # Connect button click
    run_button.click(
        fn=infer,
        inputs=[
            seed,
            randomize_seed,
            input_im,
            num_inference_steps,
            upscale_factor_slider, # Use the slider value here
            controlnet_conditioning_scale,
        ],
        outputs=[result_slider, output_seed, api_base64_output],
        api_name="upscale"
    )
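
# A minimal client-side sketch for the "/upscale" endpoint above (illustrative;
# assumes the `gradio_client` package and a hypothetical local URL -- argument
# order mirrors `infer`'s inputs):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://localhost:7860")
#   images, seed_used, b64 = client.predict(
#       42,                          # seed
#       False,                       # randomize_seed
#       handle_file("low_res.png"),  # input image (hypothetical file)
#       15,                          # num_inference_steps
#       2,                           # final upscale factor
#       0.6,                         # controlnet_conditioning_scale
#       api_name="/upscale",
#   )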

# Launch the Gradio app
demo.queue(max_size=10).launch(share=False, show_api=True)