import torch, torchvision, clip, time, math
import matplotlib.pyplot as plt
from model import encoder_image
from sentence import *

##### Get infos & cool facts to display while the game is loading
with open("infos.txt") as file:
    # One list entry per line of the file (line endings are kept).
    infos = list(file)

##### Get css
with open("style.css") as style:
    # Inline the whole stylesheet inside a <style> tag so it can be
    # prepended to the HTML fragments built below.
    css = "<style>" + style.read() + "</style>"

##### 'DIFFICULTY SWITCH' EVENT
def switch_difficulty(var_dict, html_loading):
    """Flip the difficulty setting and produce the loading outputs.

    ``var_dict["difficulty"]`` is replaced in place by ``1 - difficulty``
    (so 0 becomes 1 and 1 becomes 0), then ``loading(html_loading)`` is
    called to build the remaining outputs.

    Returns:
        A 4-tuple ``(var_dict, title, tip html, new hidden html value)``.
    """
    var_dict["difficulty"] = 1 - var_dict["difficulty"]

    ### loading() yields (title, tip html, new hidden html value)
    return (var_dict, *loading(html_loading))

##### 'LOADING' EVENT
def loading(html_loading=None):
    """Build the outputs displayed while loading.

    Args:
        html_loading: current content of the hidden HTML element.

    Returns:
        A 3-tuple ``(title html, tip html, new hidden html value)``.
    """
    hidden_zero = "<div style=\"display:none;\">0</div>"
    hidden_one  = "<div style=\"display:none;\">1</div>"

    ### Alternate the hidden content between 0 and 1 so that it always changes:
    ### the change triggers the .change event which, itself, will launch a new game
    new_value = hidden_one if html_loading == hidden_zero else hidden_zero

    ### Pick a random tip among the lines read from infos.txt
    tip = np.random.choice(infos)

    title    = "<h1 id=\"loading\">βŒ›Loading...</h1>"
    tip_html = css+"<div id=\"prediction\"><p id=\"infos\">"+tip+"</p></div>"
    return title, tip_html, new_value

##### 'NEW GAME' EVENT
def new_game(var_dict,img=None,first_game=False):
    """Reset the game state and pick a new target sentence.

    Args:
        var_dict: previous game state, or None. Only its "difficulty"
            entry is carried over; without a previous state the
            difficulty is 1.
        img: not used by this function.
        first_game: forwarded to ``iniSentence``.

    Returns:
        A 4-tuple ``(title html, prediction html, None, var_dict)`` where
        ``None`` is the canvas image value and ``var_dict`` is the fresh
        state.
    """
    print("\n----------Launching new game!")

    ### Keep the difficulty of the previous game when there is one
    difficulty = 1 if var_dict is None else var_dict["difficulty"]

    ### Fresh state (the previous dictionary is not modified)
    var_dict = dict(
        start_time=time.time(),
        total_time=0,
        found_words=[],
        target_sentence="",
        guessed_sentence="",
        parts=[],
        win=0,
        step=0,
        prev_steps=[],
        prev_norm=float("inf"),
        tip="",
        loading=False,
        revertedState=False,
        difficulty=difficulty,
    )

    target = iniSentence(var_dict,first_game=first_game)
    title  = "<h1>"+target+"</h1>"
    ### Return TITLE, PREDICTION TEXT, CANVAS IMG, VAR DICT
    return title, getHTML(var_dict,""), None, var_dict

##### PREDICTION TEXT HTML
def getHTML(var_dict,text,win=0):
    """Render the prediction area as an HTML string.

    Each space-separated word of ``text`` is compared with the word at the
    same position in ``var_dict["target_sentence"]``. Matching words are
    placed inside a ``<span>``, the others after it.

    Args:
        var_dict: game state; "target_sentence" is read, and "total_time"
            as well when ``win`` is 1.
        text: the text to display.
        win: 1 renders the winning screen (matched words plus total time),
            any other value renders the regular prediction.

    Returns:
        The stylesheet followed by a ``<div id="prediction">`` element.
    """
    ### Which parts of the sentence have been guessed?
    text_words   = text.split(" ")
    target_words = var_dict["target_sentence"].split(" ")
    matched, unmatched = [], []
    for position, word in enumerate(text_words):
        same_word = position < len(target_words) and word == target_words[position]
        (matched if same_word else unmatched).append(word + " ")
    guessed     = "".join(matched)
    not_guessed = "".join(unmatched)

    ### Display prediction
    if win == 1:
        total_time = var_dict["total_time"]
        minutes    = math.floor(total_time/60)
        seconds    = total_time%60
        ### The emoji depends on how long the game lasted
        if minutes < 1:
            emoji = "πŸ†πŸ˜" if seconds <= 30 else "πŸ˜„"
        else:
            emoji = {1: "😐", 2: "πŸ˜“"}.get(minutes, "😱")
        minutes_str = f"{minutes}m" if minutes > 0 else ""
        time_str    = f"Total time: {minutes_str}{seconds}s {emoji}"
        html        = "<p id=\"win\"><span>"+guessed+"</span><br>"+time_str+"</p>"
    else:
        html = "<p><span>"+guessed+"</span>"+not_guessed+"</p>"
    return css+"<div id=\"prediction\">"+html+"</div>"

##### DRAWING PROCESSING & GAME STATE UPDATE
def process_img(var_dict,img,title):
    """Process the current drawing and update the game state.

    Args:
        var_dict: game state dictionary; it is modified in place.
        img: the drawing, or None. Presumably a 2-D array of pixel values
            in 0..255 -- TODO confirm against the caller.
        title: not used by this function.

    Returns:
        A 3-tuple ``(prediction html, title html, var_dict)``.
    """
    ### Makes sure that start_time is updated for the first game
    if var_dict["start_time"] == -1:
        var_dict["start_time"] = time.time()

    ### No drawing: display an empty prediction
    if img is None:
        return getHTML(var_dict,"",win=0),"<h1>"+var_dict["target_sentence"]+"</h1>",var_dict

    ### Game already won: display the winning screen again
    if var_dict["win"] == 1:
        return getHTML(var_dict,var_dict["target_sentence"],win=1),"<h1>"+var_dict["target_sentence"]+"</h1>",var_dict

    print("-----Processing...")
    ### NOTE(review): the part is read before prevState() may run below -- confirm this is intended
    part  = var_dict["parts"][var_dict["step"]]
    image = torch.tensor(img).float() / 255

    ### Detect Cancel event: the image norm is larger than on the previous call
    norm = torch.norm(image)
    if norm > var_dict["prev_norm"]:
        print("---Cancel Event")
        prevState(var_dict)
    var_dict["prev_norm"] = norm

    ### Image preprocessing: center the image in a square filled with ones,
    ### resize it to (224,224) and repeat it over 3 channels
    height, width = image.shape[0], image.shape[1]
    long_edge     = max(height, width)
    short_edge    = min(height, width)
    canvas        = torch.ones(long_edge,long_edge)
    offset        = (long_edge - short_edge) // 2
    if width >= height:
        canvas[offset:offset+short_edge,:] = image
    else:
        canvas[:,offset:offset+short_edge] = image
    resize = torchvision.transforms.Resize((224,224))
    image  = resize(canvas.unsqueeze(0)).repeat(1,3,1,1)

    ### Computing cosine similarities (drawing<->text embeddings)
    with torch.no_grad():
        image_features = encoder_image(image)[0]
        text_features  = torch.tensor(part["embeddings"])
        image_features /= image_features.norm()
        similarities   = text_features @ image_features
        probs          = torch.softmax(similarities, dim=-1)

    ### Sort indexes by similarity (ascending) and keep the last 3
    top3_idxs = np.argsort(similarities)[-3:]
    preds     = [part["classes"][idx] for idx in top3_idxs]
    print(f"Top-3 Predictions: {preds}")
    print(f"Top-3 Probabilities: {probs[top3_idxs]}")

    ### Check if win (-1: bad guess, 0:progress=guessed sentence part, 1:win=guessed whole sentence)
    win = updateState(var_dict, preds)
    if win == -1:
        ### Display the prediction with the highest similarity
        text = preds[-1]
    elif win == 0:
        ### Display the guessed sentence followed by a placeholder for the current part
        part = var_dict["parts"][var_dict["step"]]
        text = var_dict["guessed_sentence"] + link_text(part,"something") + " something"
    elif win == 1:
        text = var_dict["guessed_sentence"]
        ### Record the total time only while it is still 0
        if var_dict["total_time"] == 0:
            var_dict["total_time"] = round(time.time() - var_dict["start_time"])
    return getHTML(var_dict,text,var_dict["win"]),"<h1>"+var_dict["target_sentence"]+"</h1>",var_dict