import operator

import datasets
import pandas as pd
from huggingface_hub import HfApi
from ragatouille import RAGPretrainedModel

INDEX_DIR_PATH = ".ragatouille/colbert/indexes/ICLR2024-papers-abstract-index/"

# Download the prebuilt abstract index (stored as a Hub dataset repo) into
# INDEX_DIR_PATH, then load it as the retriever used by PaperList.search.
api = HfApi()
api.snapshot_download(
    repo_type="dataset",
    repo_id="ICLR2024/ICLR2024-papers-abstract-index",
    local_dir=INDEX_DIR_PATH,
)
ABSTRACT_RETRIEVER = RAGPretrainedModel.from_index(INDEX_DIR_PATH)

# One throwaway query at import time so the retriever is initialized before
# the first real search request.
ABSTRACT_RETRIEVER.search("LLM")


class PaperList:
    """Table of ICLR 2024 papers with title/type/resource filters and abstract search.

    Attributes (set in ``__init__``):
        df_raw: merged paper + stats data as returned by ``get_df``.
        df_prettified: ``df_raw`` rendered for display by ``prettify``.
    """

    # (display column name, display datatype) pairs, in display order.
    COLUMN_INFO = (
        ["Title", "str"],
        ["Authors", "str"],
        ["Type", "str"],
        ["Paper page", "markdown"],
        ["👍", "number"],
        ["💬", "number"],
        ["OpenReview", "markdown"],
        ["Project page", "markdown"],
        ["GitHub", "markdown"],
        ["Spaces", "markdown"],
        ["Models", "markdown"],
        ["Datasets", "markdown"],
        ["claimed", "markdown"],
    )

    def __init__(self) -> None:
        self.df_raw = self.get_df()
        self.df_prettified = self.prettify(self.df_raw)

    @staticmethod
    def get_df() -> pd.DataFrame:
        """Load the paper list and left-join the per-paper stats on ``id``.

        Integer stat columns of papers without stats are filled with the
        sentinel ``-1`` (rendered as blank by ``prettify``). A ``paper_page``
        column is derived from ``arxiv_id``: the Hugging Face paper URL, or
        ``""`` when there is no arXiv id.
        """
        df = datasets.load_dataset("ICLR2024/ICLR2024-papers", split="train").to_pandas()
        df = df.merge(
            right=datasets.load_dataset("ICLR2024/ICLR2024-paper-stats", split="train").to_pandas(),
            on="id",
            how="left",
        )
        keys = ["n_authors", "n_linked_authors", "upvotes", "num_comments"]
        # The left join leaves NaN for papers missing from the stats dataset.
        df[keys] = df[keys].fillna(-1).astype(int)
        # A missing arxiv_id may surface as None or NaN; NaN is truthy, so a plain
        # truthiness test would yield ".../papers/nan". Require a non-empty str.
        df["paper_page"] = df["arxiv_id"].apply(
            lambda arxiv_id: (
                f"https://huggingface.co/papers/{arxiv_id}" if isinstance(arxiv_id, str) and arxiv_id else ""
            )
        )
        return df

    @staticmethod
    def create_link(text: str, url: str) -> str:
        """Return an HTML anchor for ``url`` labelled ``text`` that opens in a new tab."""
        return f'<a href="{url}" target="_blank">{text}</a>'

    @staticmethod
    def _join_hub_links(repo_ids, url_prefix: str) -> str:
        """Return newline-joined links to ``url_prefix + repo_id``, each labelled by its repo id."""
        return "\n".join(PaperList.create_link(repo_id, f"{url_prefix}{repo_id}") for repo_id in repo_ids)

    @staticmethod
    def prettify(df: pd.DataFrame) -> pd.DataFrame:
        """Render raw paper rows (as produced by ``get_df``) for display.

        URLs become ``<a>`` tags (empty string when there is no URL), Hub repo
        ids become newline-separated links, and ``-1`` sentinel stats become
        empty strings. The result has exactly the ``COLUMN_INFO`` columns, in order.
        """
        rows = []
        for _, row in df.iterrows():
            if row.n_linked_authors == -1:
                claimed_paper = ""
            else:
                n_authors = "" if row.n_authors == -1 else row.n_authors
                author_linked = "✅" if row.n_linked_authors > 0 else ""
                claimed_paper = f"{row.n_linked_authors}/{n_authors} {author_linked}"

            rows.append(
                {
                    "Title": row["title"],
                    "Authors": ", ".join(row["authors"]),
                    "Type": row["type"],
                    # Only link when a paper page exists; otherwise the cell would show
                    # an empty-href anchor labelled with the missing arXiv id.
                    "Paper page": (
                        PaperList.create_link(row["arxiv_id"], row["paper_page"]) if row["paper_page"] else ""
                    ),
                    "Project page": (
                        PaperList.create_link("Project page", row["project_page"]) if row["project_page"] else ""
                    ),
                    "👍": "" if row.upvotes == -1 else row.upvotes,
                    "💬": "" if row.num_comments == -1 else row.num_comments,
                    "OpenReview": PaperList.create_link("OpenReview", row["OpenReview"]),
                    "GitHub": "\n".join(PaperList.create_link("GitHub", url) for url in row["GitHub"]),
                    "Spaces": PaperList._join_hub_links(row["Space"], "https://huggingface.co/spaces/"),
                    "Models": PaperList._join_hub_links(row["Model"], "https://huggingface.co/"),
                    "Datasets": PaperList._join_hub_links(row["Dataset"], "https://huggingface.co/datasets/"),
                    "claimed": claimed_paper,
                }
            )
        return pd.DataFrame(rows, columns=PaperList.get_column_names())

    @staticmethod
    def get_column_names() -> list[str]:
        """Return the display column names from ``COLUMN_INFO``, in display order."""
        return list(map(operator.itemgetter(0), PaperList.COLUMN_INFO))

    def get_column_datatypes(self, column_names: list[str]) -> list[str]:
        """Return the ``COLUMN_INFO`` datatype for each name; raises ``KeyError`` for unknown names."""
        mapping = dict(self.COLUMN_INFO)
        return [mapping[name] for name in column_names]

    def search(  # noqa: C901
        self,
        title_search_query: str,
        abstract_search_query: str,
        max_num_to_retrieve: int,
        filter_names: list[str],
        presentation_type: str,
        columns_names: list[str],
    ) -> pd.DataFrame:
        """Filter ``df_raw`` and return the prettified result.

        Args:
            title_search_query: case-insensitive literal substring matched against titles.
            abstract_search_query: when non-empty, keep only papers returned by the
                abstract retriever, ordered by retrieval rank.
            max_num_to_retrieve: ``k`` passed to the abstract retriever.
            filter_names: any of "Paper page", "GitHub", "Space", "Model", "Dataset";
                keeps only papers that have that resource.
            presentation_type: keep only papers of this ``type``; "(ALL)" disables the filter.
            columns_names: display columns to return, in this order.
        """
        df = self.df_raw.copy()
        # As ragatouille uses str for document_id
        df["id"] = df["id"].astype(str)

        # Filter by title. regex=False: the query is free user text, and characters
        # such as "(" or "+" would otherwise be parsed as a regex and raise re.error.
        df = df[df["title"].str.contains(title_search_query, case=False, regex=False, na=False)]

        # Filter by presentation type
        if presentation_type != "(ALL)":
            df = df[df["type"] == presentation_type]

        # Filter by available resources
        if "Paper page" in filter_names:
            df = df[df["paper_page"] != ""]
        for name in ("GitHub", "Space", "Model", "Dataset"):
            if name in filter_names:
                df = df[df[name].apply(len) > 0]

        # Filter by abstract, keeping the retriever's ranking
        if abstract_search_query:
            results = ABSTRACT_RETRIEVER.search(abstract_search_query, k=max_num_to_retrieve)
            remaining_ids = set(df["id"])
            # dict.fromkeys de-duplicates while keeping first-occurrence (rank) order;
            # ids already removed by the filters above are dropped.
            found_ids = [
                paper_id
                for paper_id in dict.fromkeys(x["document_id"] for x in results)
                if paper_id in remaining_ids
            ]
            df = df[df["id"].isin(found_ids)].set_index("id").reindex(index=found_ids).reset_index()

        df_prettified = self.prettify(df)
        return df_prettified.loc[:, columns_names]