from html import escape
import re
import streamlit as st
import pandas as pd, numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
from st_clickable_images import clickable_images

# Suffixes of the CLIP checkpoints to load ("openai/clip-vit-<suffix>").
# Currently disabled alternatives: "base-patch32", "base-patch16",
# "large-patch14".
MODEL_NAMES = ["large-patch14-336"]


@st.cache(allow_output_mutation=True)
def load():
    """Load metadata tables, CLIP models/processors and image embeddings.

    Returns:
        A ``(models, processors, df, embeddings)`` tuple where ``models`` and
        ``processors`` are keyed by model name, ``df`` maps corpus index
        (0 -> ``data.csv``, 1 -> ``data2.csv``) to a DataFrame, and
        ``embeddings[name][corpus_index]`` is the embedding matrix read from
        disk with every row scaled to unit L2 norm.
    """

    def unit_rows(matrix):
        # Divide each row by its own Euclidean norm.
        return matrix / np.linalg.norm(matrix, axis=1, keepdims=True)

    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
    models, processors, embeddings = {}, {}, {}
    for name in MODEL_NAMES:
        checkpoint = f"openai/clip-vit-{name}"
        models[name] = CLIPModel.from_pretrained(checkpoint).eval()
        processors[name] = CLIPProcessor.from_pretrained(checkpoint)
        # Read both corpora first, then normalise them.
        raw = (
            np.load(f"embeddings-vit-{name}.npy"),
            np.load(f"embeddings2-vit-{name}.npy"),
        )
        embeddings[name] = {
            corpus_index: unit_rows(matrix)
            for corpus_index, matrix in enumerate(raw)
        }
    return models, processors, df, embeddings


# Load models, processors, metadata and embeddings once at import time.
models, processors, df, embeddings = load()

# Attribution line appended to each tooltip, keyed by corpus index
# (0 = Unsplash, 1 = Movies).
source = {
    0: "\nSource: Unsplash",
    1: "\nSource: The Movie Database (TMDB)",
}


def compute_text_embeddings(list_of_strings, name):
    """Embed texts with the CLIP model ``name`` and L2-normalise each row.

    Args:
        list_of_strings: Texts to embed, one embedding row per text.
        name: Key into the module-level ``processors`` / ``models`` dicts.

    Returns:
        A NumPy array with one unit-norm row per input text.
    """
    # truncation=True caps the input at the tokenizer's maximum length, so an
    # over-long query is shortened instead of producing a sequence the text
    # encoder cannot handle.
    inputs = processors[name](
        text=list_of_strings,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        result = models[name].get_text_features(**inputs).numpy()
    return result / np.linalg.norm(result, axis=1, keepdims=True)


def _normalized_similarity(corpus_embeddings, query_embeddings):
    """Score every corpus row against every query embedding.

    Each column (one per query embedding) is shifted by its median and then
    divided by its maximum, so columns are on a comparable scale before they
    are combined.
    """
    scores = corpus_embeddings @ query_embeddings.T
    scores = scores - np.median(scores, axis=0)
    return scores / np.max(scores, axis=0, keepdims=True)


def image_search(query, corpus, name, n_results=24):
    """Rank the images of ``corpus`` against a free-form query.

    Query syntax:
      * parts separated by ";" are combined (an image must match all of them:
        the per-image score is the minimum over the parts);
      * a part of the form ``[Unsplash:123]`` or ``[Movies:123]`` uses the
        stored embedding of that image, optionally followed by extra text;
      * everything after ``"EXCLUDING "`` is a ";"-separated list of negative
        text queries whose (positive) similarity is subtracted.

    Whitespace around each part is ignored and empty parts are skipped. An
    image reference whose index lies outside the embedding matrix is ignored.

    Args:
        query: The raw query string.
        corpus: ``"Unsplash"`` selects corpus 0, anything else corpus 1.
        name: Model name used to index ``embeddings`` and to embed text.
        n_results: Maximum number of results to return.

    Returns:
        A list of ``(path, tooltip, row_index)`` tuples, best match first.
        The list is empty when the query contains no usable term.
    """
    k = 0 if corpus == "Unsplash" else 1
    corpus_embeddings = embeddings[name][k]
    splitted_query = query.split("EXCLUDING ")
    image_reference = re.compile(r"\[(Movies|Unsplash):(\d{1,5})\](.*)")

    positive_parts = []
    for positive_query in splitted_query[0].split(";"):
        # Strip first: re.match anchors at the start, so a leading space
        # (as in "cat; [Unsplash:12]") would otherwise hide the reference.
        positive_query = positive_query.strip()
        if not positive_query:
            continue
        match = image_reference.match(positive_query)
        if match is None:
            positive_parts.append(compute_text_embeddings([positive_query], name))
            continue
        corpus2, idx, remainder = match.groups()
        idx, remainder = int(idx), remainder.strip()
        k2 = 0 if corpus2 == "Unsplash" else 1
        referenced = embeddings[name][k2][idx : idx + 1, :]
        # An out-of-range index yields a zero-row slice; skipping it avoids a
        # zero-size reduction further down.
        if referenced.shape[0] > 0:
            positive_parts.append(referenced)
        if remainder:
            positive_parts.append(compute_text_embeddings([remainder], name))

    dot_product = None
    if positive_parts:
        positive_scores = _normalized_similarity(
            corpus_embeddings, np.concatenate(positive_parts, axis=0)
        )
        dot_product = np.min(positive_scores, axis=1)

    negative_queries = [
        part.strip() for part in " ".join(splitted_query[1:]).split(";")
    ]
    negative_queries = [part for part in negative_queries if part]
    if negative_queries:
        negative_scores = _normalized_similarity(
            corpus_embeddings, compute_text_embeddings(negative_queries, name)
        )
        # Only similarities above the median count against an image.
        penalty = np.max(np.maximum(negative_scores, 0), axis=1)
        dot_product = -penalty if dot_product is None else dot_product - penalty

    if dot_product is None:
        # Nothing to rank by.
        return []

    results = np.argsort(dot_product)[-1 : -n_results - 1 : -1]
    frame = df[k]
    return [
        (
            frame.iloc[i]["path"],
            frame.iloc[i]["tooltip"] + source[k],
            i,
        )
        for i in results
    ]


# Markdown shown in the sidebar: title, usage hint and credits.
description = """
# Semantic image search

**Enter your query and hit enter**

*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*

*Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
"""

# Markdown shown inside the sidebar's "Advanced use" expander; it describes
# the query syntax understood by image_search().
howto = """
- Click on an image to use it as a query and find similar images
- Several queries, including one based on an image, can be combined (use "**;**" as a separator)
- If the input includes "**EXCLUDING**", the part right of it will be used as a negative query
"""

# CSS properties passed to clickable_images() as the style of the container
# that wraps the result thumbnails.
div_style = {
    "display": "flex",
    "justify-content": "center",
    "flex-wrap": "wrap",
}


def main():
    """Render the search page and turn image clicks into new queries.

    The page consists of a sidebar (description and usage notes), a text
    input for the query, a corpus selector and a grid of clickable results
    computed with the last entry of ``MODEL_NAMES``.

    Clicking a result stores ``"[<corpus>:<row index>]"`` under
    ``st.session_state["query"]`` and reruns the script so the text input
    picks it up. Each handled click is recorded under
    ``st.session_state["last_clicked"]``; a click that was already handled is
    not acted on again, otherwise a click that leaves the query unchanged
    (e.g. clicking the image the query already refers to) would rerun the
    script endlessly.
    """
    st.markdown(
        """
              <style>
              .block-container{
                max-width: 1200px;
              }
              div.row-widget.stRadio > div{
                flex-direction:row;
                display: flex;
                justify-content: center;
              }
              div.row-widget.stRadio > div > label{
                margin-left: 5px;
                margin-right: 5px;
              }
              .row-widget {
                margin-top: -25px;
              }
              section>div:first-child {
                padding-top: 30px;
              }
              div.reportview-container > section:first-child{
                max-width: 320px;
              }
              #MainMenu {
                visibility: hidden;
              }
              footer {
                visibility: hidden;
              }
              </style>""",
        unsafe_allow_html=True,
    )
    st.sidebar.markdown(description)
    with st.sidebar.expander("Advanced use"):
        st.markdown(howto)

    _, c, _ = st.columns((1, 3, 1))
    if "query" in st.session_state:
        query = c.text_input("", value=st.session_state["query"])
    else:
        query = c.text_input("", value="clouds at sunset")
    corpus = st.radio("", ["Unsplash", "Movies"])

    name1 = MODEL_NAMES[-1]

    if not query:
        return

    results1 = image_search(query, corpus, name1)
    # The key depends on the query, so a new query yields a fresh component.
    component_key = query + corpus + name1 + "1"
    clicked1 = clickable_images(
        [result[0] for result in results1],
        titles=[result[1] for result in results1],
        div_style=div_style,
        img_style={"margin": "2px", "height": "200px"},
        key=component_key,
    )
    if clicked1 < 0:
        return

    # Identify the click by component key and position so the same click is
    # only handled once, even if the component keeps reporting it on reruns.
    click_token = (component_key, clicked1)
    if (
        "last_clicked" in st.session_state
        and st.session_state["last_clicked"] == click_token
    ):
        return
    st.session_state["last_clicked"] = click_token
    st.session_state["query"] = f"[{corpus}:{results1[clicked1][2]}]"
    st.experimental_rerun()


# Build the page when this file is executed as a script.
if __name__ == "__main__":
    main()