import numpy as np
import torch
import gradio as gr
from PIL import Image
from net.CIDNet import CIDNet
import torchvision.transforms as transforms
import torch.nn.functional as F
import safetensors.torch as sf
import imquality.brisque as brisque
from loss.niqe_utils import *
import spaces
from huggingface_hub import hf_hub_download
import json

def from_pretrained(cls, pretrained_model_name_or_path: str):
    """Load ``model.safetensors`` from a Hugging Face Hub repo into an existing model.

    Despite its name, ``cls`` is a model *instance* (the caller passes the
    already-constructed ``eval_net``); its parameters are overwritten in place.

    Args:
        cls: Module instance whose ``load_state_dict`` receives the weights.
        pretrained_model_name_or_path: Hub repo id, e.g.
            ``"Fediory/HVI-CIDNet-SICE"``.

    Returns:
        The same module instance that was passed in, for convenience.

    Raises:
        Propagates any exception from ``hf_hub_download`` or
        ``safetensors.torch.load_file`` (e.g. missing repo or file).
    """
    model_id = str(pretrained_model_name_or_path)

    # Only the weights file is fetched. The model is already constructed by the
    # caller, so the repo's config.json is not needed (it was previously
    # downloaded and parsed but never used).
    model_file = hf_hub_download(repo_id=model_id, filename="model.safetensors", repo_type="model")
    state_dict = sf.load_file(model_file)

    # strict=False: parameters missing from the checkpoint keep their current
    # values and unexpected keys are ignored, so key mismatches are NOT reported.
    cls.load_state_dict(state_dict, strict=False)
    return cls

# Single shared model instance, built once at import time and moved to CUDA.
# Checkpoint weights are loaded into it on each request by process_image.
eval_net = CIDNet().cuda()
# Turn on both gating flags of the `trans` submodule. Chained assignment binds
# targets left to right, so `gated` is set before `gated2`.
# NOTE(review): the effect of these flags is defined in net/CIDNet.py; confirm there.
eval_net.trans.gated = eval_net.trans.gated2 = True

def _pad_to_multiple(img, factor=8):
    """Pad a batched image tensor on the bottom/right so H and W are multiples of ``factor``.

    ``img`` is treated as (N, C, H, W): its last two dims are padded.
    Reflect padding is preferred, but ``F.pad`` requires a reflect pad smaller
    than the padded dimension. Very small images therefore fall back to
    replicate padding instead of raising.
    """
    h, w = img.shape[-2], img.shape[-1]
    # Amount needed to reach the next multiple of `factor` (0 if already aligned).
    padh = (-h) % factor
    padw = (-w) % factor
    if padh == 0 and padw == 0:
        return img
    mode = 'reflect' if padh < h and padw < w else 'replicate'
    return F.pad(img, (0, padw, 0, padh), mode)


@spaces.GPU(duration=120)
def process_image(input_img,score,model_path,gamma=1.0,alpha_s=1.0,alpha_i=1.0):
    """Enhance a low-light image with HVI-CIDNet and optionally score it with NIQE.

    Args:
        input_img: PIL image from the Gradio input, or ``None`` if nothing was uploaded.
        score: ``'Yes'`` to compute NIQE on the result; any other value skips it.
        model_path: Checkpoint suffix. The Hub repo
            ``"Fediory/HVI-CIDNet-" + model_path`` is loaded on every call.
        gamma: Exponent applied to the input tensor (values in [0, 1] from
            ``ToTensor``) before it is fed to the network.
        alpha_s: Written to ``eval_net.trans.alpha_s`` before inference.
        alpha_i: Written to ``eval_net.trans.alpha`` before inference.

    Returns:
        ``(image, score_or_message)``. On success this is the enhanced PIL image
        plus the NIQE score, or ``0`` when scoring is off. When a required input
        is missing, it is the input (or ``None``) plus a user-facing message.
    """
    if input_img is None:
        return None, "Please upload a low-light image."
    if model_path is None:
        return input_img,"Please choose a model weights."

    # Keep autograd off for loading and inference without flipping the
    # thread-global grad mode permanently.
    with torch.no_grad():
        from_pretrained(eval_net, "Fediory/HVI-CIDNet-" + model_path)
        eval_net.eval()

        img = transforms.ToTensor()(input_img).unsqueeze(0)
        h, w = img.shape[2], img.shape[3]
        # Pad H/W to multiples of 8 (presumably required by the network's
        # down/upsampling) and crop back to the original size afterwards.
        img = _pad_to_multiple(img, factor=8)

        # These are attributes of the shared module-level model, set per request.
        # NOTE(review): concurrent requests would race on them; confirm the
        # Space serves one request at a time.
        eval_net.trans.alpha_s = alpha_s
        eval_net.trans.alpha = alpha_i
        output = eval_net(img.cuda() ** gamma)

        # Clamp to a valid image range, drop the padding, and move the result
        # to the CPU for PIL conversion.
        output = torch.clamp(output, 0, 1)[:, :, :h, :w].cpu()

    enhanced_img = transforms.ToPILImage()(output.squeeze(0))
    if score == 'Yes':
        score_niqe = calculate_niqe(np.array(enhanced_img))
        return enhanced_img, score_niqe
    return enhanced_img, 0


# Local weights directory; not read by any live code in this file.
directory = "weights"
# Checkpoint suffixes offered in the "Model Weights" radio. process_image loads
# the Hub repo "Fediory/HVI-CIDNet-<name>" for the selected entry.
pth_files = [
    'Generalization',
    'Sony-Total-Dark',
    'LOL-Blur',
    'SICE',
    'LOLv2-real-bestSSIM',
    'LOLv2-real-bestPSNR',
    'LOLv2-syn-wperc',
    'LOLv2-syn-woperc',
    'LOLv1-wperc',
    'LOLv1-woperc',
]


# Input order must match process_image's positional parameters:
# (input_img, score, model_path, gamma, alpha_s, alpha_i).
interface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(label="Low-light Image", type="pil"),
        # value='No' makes the initial UI state match the default promised in
        # the info text. process_image skips NIQE for anything other than 'Yes'.
        gr.Radio(choices=['Yes','No'],value='No',label="Image Score",info="Calculate NIQE, default is \"No\"."),
        # No default on purpose: process_image asks the user to pick one.
        gr.Radio(choices=pth_files,label="Model Weights",info="Choose your model. The best models are \"SICE\" and \"Generalization\"."),
        gr.Slider(0.1,5,label="gamma curve",step=0.01,value=1.0, info="Lower is lighter, best range is [0.5,2.5]."),
        gr.Slider(0,2,label="Alpha-s",step=0.01,value=1.0, info="Higher is more saturated."),
        gr.Slider(0.1,2,label="Alpha-i",step=0.01,value=1.0, info="Higher is lighter.")
    ],
    outputs=[
        gr.Image(label="Result", type="pil"),
        gr.Textbox(label="NIQE",info="Lower is better.")
    ],
    title="HVI-CIDNet (Low-Light Image Enhancement)",
    description="The demo of paper \"HVI: A New Color Space for Low-light Image Enhancement\"",
    allow_flagging="never"
)

interface.launch(share=False)