import os
from functools import lru_cache

import cv2
import gradio as gr
import numpy as np
import torch
from transformers import DetrForObjectDetection, DetrImageProcessor

@lru_cache(maxsize=1)
def _load_detr(checkpoint="facebook/detr-resnet-50"):
    """Load the DETR model and processor once and reuse them across calls."""
    model = DetrForObjectDetection.from_pretrained(checkpoint)
    processor = DetrImageProcessor.from_pretrained(checkpoint)
    return model, processor


@lru_cache(maxsize=1)
def _load_face_cascade():
    """Load OpenCV's bundled frontal-face Haar cascade once."""
    return cv2.CascadeClassifier(
        os.path.join(cv2.data.haarcascades, "haarcascade_frontalface_default.xml")
    )


def _xyxy_to_xywh(box):
    """Convert an ``(x0, y0, x1, y1)`` box (floats or 0-d tensors) to int ``(x, y, w, h)``."""
    x0, y0, x1, y1 = (float(v) for v in box)
    return int(round(x0)), int(round(y0)), int(round(x1 - x0)), int(round(y1 - y0))


def _detect_face(image):
    """Return the largest frontal face in *image* as int ``(x, y, w, h)``, or None.

    Assumes *image* is an RGB/RGBA (or single-channel) uint8 array, as Gradio
    delivers for ``type="numpy"`` inputs.
    """
    cascade = _load_face_cascade()
    if cascade.empty():  # cascade file missing from this OpenCV build
        return None
    if image.ndim == 2:
        gray = image
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    faces = cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5)
    if len(faces) == 0:
        return None
    x, y, w, h = max(faces, key=lambda f: int(f[2]) * int(f[3]))
    return int(x), int(y), int(w), int(h)


def detect_face_and_neck(image, score_threshold=0.7):
    """Locate the regions used for placing jewelry.

    Args:
        image: Person photo as an ``(H, W, C)`` numpy array (RGB from Gradio).
        score_threshold: Minimum DETR confidence for a person detection;
            detections must score strictly above it.

    Returns:
        ``(face_box, neck_box)``. Each is an int ``(x, y, w, h)`` tuple or
        None when nothing was found. ``face_box`` is the largest face found
        by OpenCV's Haar cascade. ``neck_box`` is the highest-scoring DETR
        "person" box; it covers the whole person and is only a coarse
        stand-in for the neck.
    """
    model, processor = _load_detr()

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only; avoid building the autograd graph
        outputs = model(**inputs)

    # image.shape[:2] is (height, width), the order post-processing expects.
    target_sizes = torch.tensor([image.shape[:2]])
    results = processor.post_process_object_detection(
        outputs, threshold=score_threshold, target_sizes=target_sizes
    )[0]

    # Resolve labels by name rather than by hard-coded ids. The COCO classes
    # of this checkpoint have no "face" (id 2 is "bicycle"), so faces come
    # from the OpenCV cascade instead.
    id2label = model.config.id2label
    neck_box = None
    best_score = float("-inf")
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        score = float(score)
        if id2label.get(int(label)) == "person" and score > best_score:
            best_score = score
            # Post-processed boxes are (x0, y0, x1, y1); callers expect (x, y, w, h).
            neck_box = _xyxy_to_xywh(box)

    face_box = _detect_face(image)
    return face_box, neck_box

def place_jewelry(image, jewelry_image, position):
    """Paste *jewelry_image* onto a copy of *image* inside the box *position*.

    Args:
        image: Background image as an ``(H, W, C)`` array with ``C >= 3``
            (Gradio ``type="numpy"`` inputs are RGB). It is not modified.
        jewelry_image: Overlay image. If it has 4 channels, the 4th is used
            as an alpha mask, assumed to be 0-255 (uint8). Otherwise the
            overlay is pasted opaquely. Single-channel overlays are replicated
            across the three colour channels.
        position: ``(x, y, w, h)`` of the target box in pixels. Values may be
            floats or 0-d tensors; they are rounded to ints.

    Returns:
        A new array with the overlay blended in. Parts of the box outside
        *image* are clipped. If the box has non-positive size or lies
        entirely outside the image, an unmodified copy is returned.
    """
    x, y, w, h = (int(round(float(v))) for v in position)
    result = image.copy()  # never mutate the caller's (e.g. Gradio's) array
    if w <= 0 or h <= 0:
        return result

    # Clip the box to the image so the destination slice and the overlay agree in shape.
    img_h, img_w = result.shape[:2]
    x0, y0 = max(x, 0), max(y, 0)
    x1, y1 = min(x + w, img_w), min(y + h, img_h)
    if x1 <= x0 or y1 <= y0:
        return result

    resized = cv2.resize(jewelry_image, (w, h))
    if resized.ndim == 2:  # cv2.resize returns 2-D for single-channel input
        resized = np.repeat(resized[:, :, np.newaxis], 3, axis=2)
    # Keep only the part of the overlay that lands inside the image.
    overlay = resized[y0 - y:y1 - y, x0 - x:x1 - x]

    color = overlay[:, :, :3].astype(np.float64)
    if overlay.shape[2] == 4:
        alpha = overlay[:, :, 3:4].astype(np.float64) / 255.0
    else:
        alpha = 1.0  # no alpha channel: opaque paste
    target = result[y0:y1, x0:x1, :3].astype(np.float64)
    blended = color * alpha + target * (1.0 - alpha)
    result[y0:y1, x0:x1, :3] = blended.astype(result.dtype)
    return result

def tryon_jewelry(person_img, jewelry_img, jewelry_type):
    """Overlay a jewelry image on a person photo (Gradio click handler).

    Args:
        person_img: Person photo as a numpy array, or None if not uploaded.
        jewelry_img: Jewelry photo as a numpy array, or None if not uploaded.
        jewelry_type: Dropdown value; "Necklace" and "Earrings" are handled.

    Returns:
        None if either image is missing. Otherwise the composited image from
        ``place_jewelry``, or ``person_img`` unchanged (plus a Gradio warning
        toast) when the type is unsupported or no matching region was found.
    """
    if person_img is None or jewelry_img is None:
        return None

    if jewelry_type not in ("Necklace", "Earrings"):
        # Checked before detection so unsupported types skip the model call.
        gr.Warning(f"{jewelry_type} try-on is not supported yet; showing the original image.")
        return person_img

    face_box, neck_box = detect_face_and_neck(person_img)
    if jewelry_type == "Necklace":
        box, region = neck_box, "neck"
    else:
        # Assuming ears are part of the face box for simplicity.
        box, region = face_box, "face"

    if box is None:
        gr.Warning(f"No {region} detected; showing the original image.")
        return person_img
    return place_jewelry(person_img, jewelry_img, box)

# Gradio interface setup: CSS centers each column and caps its width.
css = """
#col-left, #col-mid, #col-right {
    margin: 0 auto;
    max-width: 430px;
}
"""

with gr.Blocks(css=css) as JewelryTryon:
    gr.HTML("<h1>Virtual Jewelry Try-On</h1>")

    with gr.Row():
        with gr.Column(elem_id="col-left"):
            imgs = gr.Image(label="Person image", sources='upload', type="numpy")
        with gr.Column(elem_id="col-mid"):
            garm_img = gr.Image(label="Jewelry image", sources='upload', type="numpy")
        with gr.Column(elem_id="col-right"):
            jewelry_type = gr.Dropdown(label="Jewelry Type", choices=['Necklace', 'Earrings', 'Ring'], value="Necklace")
            image_out = gr.Image(label="Result", show_share_button=False)
            run_button = gr.Button(value="Run")

    # Inputs are passed positionally to tryon_jewelry(person_img, jewelry_img, jewelry_type).
    run_button.click(fn=tryon_jewelry, inputs=[imgs, garm_img, jewelry_type], outputs=image_out)

# Launch the Gradio app only when run as a script, so importing this module
# (e.g. from tests or another app) does not start a blocking server.
if __name__ == "__main__":
    JewelryTryon.queue(api_open=False).launch(show_api=False)