import spaces
import gradio as gr
import numpy as np
import random
import torch
from diffusers import (
    DiffusionPipeline, StableDiffusion3Pipeline, FluxPipeline, PixArtSigmaPipeline,
    AuraFlowPipeline, Kandinsky3Pipeline, HunyuanDiTPipeline,
    LuminaText2ImgPipeline, SanaPipeline,AutoPipelineForText2Image
)
import gc
import os
import psutil
import threading
from pathlib import Path
import shutil
import time
import glob
from datetime import datetime
from PIL import Image

from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline
from onediffusion.models.denoiser.nextdit import NextDiT
from onediffusion.dataset.utils import get_closest_ratio, ASPECT_RATIO_512

#import os
#cache_dir = '/workspace/hf_cache'


# Module-wide constants and runtime settings.
MAX_SEED = np.iinfo(np.int32).max  # upper bound for the seed slider and random seeds
MAX_IMAGE_SIZE = 1024  # upper bound for the width/height sliders
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 when running on CUDA, full float32 precision on CPU.
TORCH_DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
OUTPUT_DIR = "generated_images"  # where generated PNGs are written and listed from
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Registry of selectable models: display name -> how to load it.
MODEL_CONFIGS = {
    "OneDiffusion": dict(
        repo_id="lehduong/OneDiffusion",
        pipeline_class=OneDiffusionPipeline,
        # cache_dir=cache_dir,  # optional custom download location
    ),
}

# Loaded pipelines, filled on first use and keyed by model name.
pipes = {}
# One independent lock per model; generate_image holds it while loading and running that model.
model_locks = {name: threading.Lock() for name in MODEL_CONFIGS}


def get_process_memory():
    """Return the resident set size (RSS) of this process in GiB (1024**3 bytes)."""
    bytes_per_gib = 1024 ** 3
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / bytes_per_gib


def clear_torch_cache():
    """Free PyTorch's cached CUDA memory; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


def remove_cache_dir(model_name):
    """Delete ``~/.cache/huggingface/diffusers/<org>--<name>`` for ``model_name``.

    The folder name is the model's ``repo_id`` with ``/`` replaced by ``--``.
    Missing folders are skipped and deletion errors are ignored.

    NOTE(review): current huggingface_hub downloads live under
    ``~/.cache/huggingface/hub/models--<org>--<name>`` (and honour HF_HOME),
    so this path may not be where the weights actually are. Confirm before
    relying on it to free disk space.
    """
    folder_name = "--".join(MODEL_CONFIGS[model_name]['repo_id'].split('/'))
    target = Path.home().joinpath('.cache', 'huggingface', 'diffusers', folder_name)
    if not target.exists():
        return
    shutil.rmtree(target, ignore_errors=True)


def deep_cleanup(model_name, pipe):
    """Best-effort teardown of a loaded pipeline and the resources it holds.

    Steps:
    1. Unregister the pipeline from ``pipes``.
    2. Move it to CPU and delete its instance attributes.
    3. Clear the CUDA cache and run garbage collection.
    4. Remove the model's cache directory via ``remove_cache_dir``.

    Exceptions from these steps are printed rather than propagated.

    Args:
        model_name: Key of the pipeline in ``pipes`` / ``MODEL_CONFIGS``.
        pipe: The pipeline object to tear down (normally ``pipes[model_name]``).
    """
    # Unregister first: if any later step raises, a half-dismantled pipeline
    # must not stay cached in ``pipes`` and be handed to the next request.
    pipes.pop(model_name, None)

    # Moving to CPU is its own best-effort step so a failure here does not
    # prevent the references below from being dropped.
    try:
        if hasattr(pipe, 'to'):
            pipe.to('cpu')
    except Exception as e:
        print(f"Error during cleanup of {model_name}: {str(e)}")

    try:
        # Drop every instance attribute so the pipeline's components become
        # unreachable through this object.
        for attr_name in list(pipe.__dict__.keys()):
            delattr(pipe, attr_name)

        clear_torch_cache()

        # Several passes so objects freed by one collection can be collected next.
        for _ in range(3):
            gc.collect()

        remove_cache_dir(model_name)

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        # Short pause to let the cleanup settle before returning.
        time.sleep(1)

    except Exception as e:
        print(f"Error during cleanup of {model_name}: {str(e)}")

    finally:
        # Final garbage collection
        gc.collect()
        clear_torch_cache()


def load_pipeline(model_name):
    """Load the pipeline for ``model_name`` and prepare it for inference.

    On CUDA, model CPU offload is enabled when the pipeline supports it.
    Otherwise the pipeline is moved to ``DEVICE``. The change in process RSS
    caused by loading is printed.

    Args:
        model_name: Key into ``MODEL_CONFIGS``.

    Returns:
        The ready-to-use pipeline.

    Raises:
        KeyError: if ``model_name`` is not in ``MODEL_CONFIGS``.
    """
    initial_memory = get_process_memory()
    config = MODEL_CONFIGS[model_name]

    if model_name == "Kandinsky":
        # Only reachable if a "Kandinsky" entry is added to MODEL_CONFIGS
        # (the lookup above raises KeyError for unknown names).
        print("Kandinsky Special")
        pipe = AutoPipelineForText2Image.from_pretrained(
            "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
        )
    else:
        pipe = config["pipeline_class"].from_pretrained(
            config["repo_id"],
            torch_dtype=TORCH_DTYPE,
            # cache_dir=cache_dir
        )

    if DEVICE == "cuda" and hasattr(pipe, 'enable_model_cpu_offload'):
        # Model CPU offload manages device placement itself. diffusers moves
        # the pipeline back to CPU when enabling it, so a prior .to("cuda")
        # only adds a wasted transfer and a higher VRAM peak.
        pipe.enable_model_cpu_offload()
    else:
        # No CUDA (offload would install CUDA hooks) or no offload support.
        pipe = pipe.to(DEVICE)

    final_memory = get_process_memory()
    print(f"Memory used by {model_name}: {final_memory - initial_memory:.2f} GB")

    return pipe


def save_generated_image(image, model_name, prompt):
    """Save ``image`` as a PNG in OUTPUT_DIR and return its path.

    The file name is ``<YYYYmmdd_HHMMSS>_<model_name>_<prompt part>.png``.
    The prompt part keeps only alphanumerics, spaces, ``-`` and ``_`` from the
    first 30 characters of ``prompt``, with surrounding whitespace stripped.
    """
    allowed_punctuation = (' ', '-', '_')
    safe_chars = [ch for ch in prompt[:30] if ch.isalnum() or ch in allowed_punctuation]
    prompt_part = "".join(safe_chars).strip()
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    destination = os.path.join(OUTPUT_DIR, f"{stamp}_{model_name}_{prompt_part}.png")
    image.save(destination)
    return destination


def get_generated_images():
    """List the PNGs in OUTPUT_DIR, most recent ``os.path.getctime`` first.

    Each entry is a dict with:
    - ``path``: the full file path.
    - ``name``: the base name.
    - ``date``: the ctime formatted as ``%Y-%m-%d %H:%M:%S``.
    - ``size``: the file size in KiB, e.g. ``"2.0 KB"``.
    """
    pattern = os.path.join(OUTPUT_DIR, "*.png")
    newest_first = sorted(glob.glob(pattern), key=os.path.getctime, reverse=True)
    details = []
    for path in newest_first:
        created = datetime.fromtimestamp(os.path.getctime(path))
        details.append({
            "path": path,
            "name": os.path.basename(path),
            "date": created.strftime("%Y-%m-%d %H:%M:%S"),
            "size": f"{os.path.getsize(path) / 1024:.1f} KB",
        })
    return details


def generate_image(
        model_name,
        prompt,
        negative_prompt="",
        seed=42,
        randomize_seed=False,
        width=1024,
        height=1024,
        guidance_scale=4.5,
        num_inference_steps=40,
        progress=gr.Progress(track_tqdm=True)
):
    """Generate one image with ``model_name`` and save it to OUTPUT_DIR.

    The pipeline is loaded on first use and cached in ``pipes``. Loading and
    generation both run while holding ``model_locks[model_name]``.

    Args:
        model_name: Key into ``MODEL_CONFIGS``.
        prompt: User prompt. For OneDiffusion it is sent to the pipeline with
            a ``[[text2image]] `` prefix. The saved file is named from the
            unprefixed prompt.
        negative_prompt: Passed through to the pipeline.
        seed: Generator seed; replaced by a random one when ``randomize_seed``.
        randomize_seed: Draw a fresh seed from ``[0, MAX_SEED]``.
        width, height: Output size passed to the pipeline.
        guidance_scale, num_inference_steps: Passed through to the pipeline.
        progress: Gradio progress tracker (tracks tqdm bars); not used directly.

    Returns:
        ``(image, seed)`` where ``seed`` is the integer seed actually used.

    Raises:
        Any exception from loading or generation is re-raised. Before that,
        the cached pipeline (if any) is torn down with ``deep_cleanup``.
    """
    with model_locks[model_name]:
        try:
            if model_name not in pipes:
                pipes[model_name] = load_pipeline(model_name)

            pipe = pipes[model_name]

            # UI sliders may hand over a float; Generator.manual_seed needs an int.
            seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

            generator = torch.Generator(DEVICE).manual_seed(seed)
            print(f"Generating image with {model_name}...")

            # OneDiffusion receives a task tag; keep the user's prompt intact
            # for the file name so the tag doesn't eat the 30-char budget.
            model_prompt = prompt
            if model_name == "OneDiffusion":
                model_prompt = "[[text2image]] " + prompt

            image = pipe(
                prompt=model_prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator,
            ).images[0]

            filepath = save_generated_image(image, model_name, prompt)
            print(f"Saved image to: {filepath}")

            return image, seed

        except Exception as e:
            print(f"Error with {model_name}: {str(e)}")
            if model_name in pipes:
                deep_cleanup(model_name, pipes[model_name])
            raise


# Gradio Interface
# Custom CSS: centre the main column and cap its width at 1024px.
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 1024px;\n"
    "}\n"
)

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Multi-Model Image Generation")

        # Prompt entry and the button that starts generation.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Generate", scale=0, variant="primary")

        # Generation parameters; the same values are used for every model.
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=512,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=7.5,
                    step=0.1,
                    value=4.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=40,
                )

        memory_indicator = gr.Markdown("Current memory usage: 0 GB")

        with gr.Row():
            # Left column: one tab per model with its result image and seed.
            with gr.Column(scale=2):
                with gr.Tabs() as tabs:
                    results = {}
                    seeds = {}
                    for model_name in MODEL_CONFIGS.keys():
                        with gr.Tab(model_name):
                            results[model_name] = gr.Image(label=f"{model_name} Result")
                            seeds[model_name] = gr.Number(label="Seed used", visible=True)
            # Right column: gallery of saved images. It must be a sibling of the
            # results column inside the Row (not nested in it) for the 2:1
            # ``scale`` split to lay the two out side by side.
            with gr.Column(scale=1):
                gr.Markdown("### Generated Images")
                file_gallery = gr.Gallery(
                    label="Generated Images",
                    show_label=False,
                    elem_id="file_gallery",
                    columns=3,
                    height=800,
                    visible=True
                )
                refresh_button = gr.Button("Refresh Gallery")



    def update_gallery():
        """Return ``(path, caption)`` pairs for the gallery component.

        Entries follow the order of ``get_generated_images()``. Each caption is
        the file name and its formatted date on separate lines.
        """
        gallery_items = []
        for entry in get_generated_images():
            caption = "\n".join((entry["name"], entry["date"]))
            gallery_items.append((entry["path"], caption))
        return gallery_items


    @spaces.GPU(duration=400)
    def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps,
                     progress=gr.Progress()):
        """Run every configured model in turn, streaming results to the UI.

        After each model finishes, successfully or not, yields a flat list
        ``[image_0, seed_0, image_1, seed_1, ...]`` in ``MODEL_CONFIGS`` order.
        The list has exactly one value per entry of ``output_components``.
        A model that fails contributes ``None`` for both its image and seed.
        """
        outputs = [None] * (len(MODEL_CONFIGS) * 2)
        for idx, model_name in enumerate(MODEL_CONFIGS):
            print(f"IMAGE GENERATING {model_name} ")
            try:
                image, used_seed = generate_image(
                    model_name, prompt, negative_prompt, seed,
                    randomize_seed, width, height, guidance_scale,
                    num_inference_steps, progress
                )
            except Exception as e:
                print(f"Error generating with {model_name}: {str(e)}")
                image, used_seed = None, None
            else:
                print(f"IMAGE GENERATIED {model_name} ")

            # Report the seed actually used; it differs from ``seed`` when
            # randomize_seed is set.
            outputs[idx * 2] = image
            outputs[idx * 2 + 1] = used_seed
            # Yield a copy so later in-place updates cannot alter a value the
            # UI is still processing.
            yield list(outputs)


    # Flat [image, seed, image, seed, ...] in MODEL_CONFIGS order; this must
    # match the order of the values yielded by generate_all.
    output_components = []
    for model_name in MODEL_CONFIGS.keys():
        output_components.extend([results[model_name], seeds[model_name]])

    def _memory_status():
        """Return the memory-indicator text for the current process RSS."""
        return f"Current memory usage: {get_process_memory():.2f} GB"

    generate_event = run_button.click(
        fn=generate_all,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=output_components,
    )
    # generate_all only feeds the per-model outputs. Refresh the gallery (new
    # files were saved) and the memory readout once it has finished.
    generate_event.then(
        fn=update_gallery,
        inputs=[],
        outputs=[file_gallery],
    )
    generate_event.then(
        fn=_memory_status,
        inputs=[],
        outputs=[memory_indicator],
    )

    refresh_button.click(
        fn=update_gallery,
        inputs=[],
        outputs=[file_gallery],
    )

    demo.load(
        fn=update_gallery,
        inputs=[],
        outputs=[file_gallery],
    )
    demo.load(
        fn=_memory_status,
        inputs=[],
        outputs=[memory_indicator],
    )

if __name__ == "__main__":
    # Listen on all network interfaces, not just localhost.
    launch_options = {"server_name": '0.0.0.0'}
    demo.launch(**launch_options)