from typing import Any
import gradio as gr
import pandas as pd
import json
import requests
from html.parser import HTMLParser

# Effective bits-per-weight (bpw) for each llama.cpp GGUF quantization type.
# Values are empirical averages (K-quants mix bit widths per tensor), used to
# estimate on-disk / in-VRAM model size as parameters * bpw / 8 bytes.
quants = {
    "Q2_K": 3.35,
    "Q3_K_S": 3.5,
    "Q3_K_M": 3.91,
    "Q3_K_L": 4.27,
    "Q4_0": 4.55,
    "Q4_K_S": 4.58,
    "Q4_K_M": 4.85,
    "Q5_0": 5.54,
    "Q5_K_S": 5.54,
    "Q5_K_M": 5.69,
    "Q6_K": 6.59,
    "Q8_0": 8.5,
}

class SvelteHydratorExtractor(HTMLParser):
    """Pulls the JSON payload out of a Svelte hydration tag's ``data-props``.

    After ``feed()`` is called with HTML containing a tag that has a
    ``data-props`` attribute, ``self.data`` holds that attribute's value
    (with any literal ``&quot;`` sequences turned back into ``"``).
    It stays ``None`` if no such attribute was seen.
    """

    def __init__(self):
        # Decoded data-props payload; None until handle_starttag finds one.
        self.data = None
        super().__init__()

    def handle_starttag(self, tag, attrs):
        for name, value in attrs:
            if name == "data-props":
                # Bug fix: original replaced "&quot:" (colon typo), which never
                # matches. HTMLParser usually unescapes attribute entities
                # already, but handle literal "&quot;" defensively too.
                self.data = value.replace("&quot;", '"')


def calc_model_size(parameters: int, quant: float) -> int:
    """Return the quantized model's weight footprint in bytes.

    ``parameters`` is the model's parameter count and ``quant`` the effective
    bits per weight, so the size is parameters * bpw / 8 bytes (floored).
    Bug fix: ``//`` with a float operand yields a float, contradicting the
    declared ``-> int`` return type; convert explicitly.
    """
    return int(parameters * quant // 8)


def get_model_config(hf_model: str) -> dict[str, Any]:
    """Fetch ``config.json`` for a Hugging Face model and attach its size.

    Adds a ``"parameters"`` key (parameter count, assuming the hosted
    checkpoint stores fp16 weights, i.e. 2 bytes per parameter).

    The total checkpoint size in bytes is resolved by trying, in order:
      1. ``model.safetensors.index.json`` -> ``metadata.total_size``
      2. ``pytorch_model.bin.index.json`` -> ``metadata.total_size``
      3. scraping the model page's Svelte hydration props for the
         safetensors parameter total (times 2 bytes).
    """
    config = requests.get(
        f"https://huggingface.co/{hf_model}/raw/main/config.json"
    ).json()
    model_size = 0
    try:
        # Bug fix: key was misspelled "metadta", which always raised KeyError
        # and forced the slower fallbacks below.
        model_size = requests.get(
            f"https://huggingface.co/{hf_model}/raw/main/model.safetensors.index.json"
        ).json()["metadata"]["total_size"]
    except (requests.RequestException, ValueError, KeyError):
        try:
            model_size = requests.get(
                f"https://huggingface.co/{hf_model}/raw/main/pytorch_model.bin.index.json"
            ).json()["metadata"]["total_size"]
        except (requests.RequestException, ValueError, KeyError):
            # No index file available: scrape the model page instead.
            model_page = requests.get(
                f"https://huggingface.co/{hf_model}"
            ).text
            param_props_idx = model_page.find('data-target="ModelSafetensorsParams"')
            if param_props_idx != -1:
                # Isolate the opening <div ...> tag carrying data-props.
                param_props_start = model_page.rfind("<div", 0, param_props_idx)
                param_props_end = model_page.find(">", param_props_idx)
                extractor = SvelteHydratorExtractor()
                extractor.feed(model_page[param_props_start:param_props_end + 1])
                # Parameter count * 2 bytes (fp16).
                model_size = (
                    json.loads(
                        extractor.data
                    )["safetensors"]["total"]
                    * 2
                )
            else:
                # Older page layout: the totals live on the ModelHeader props.
                param_props_idx = model_page.find('data-target="ModelHeader"')
                param_props_start = model_page.rfind("<div", 0, param_props_idx)
                param_props_end = model_page.find(">", param_props_idx)
                extractor = SvelteHydratorExtractor()
                extractor.feed(model_page[param_props_start:param_props_end + 1])
                model_size = (
                    json.loads(
                        extractor.data
                    )["model"]["safetensors"]["total"]
                    * 2
                )

    # assume fp16 weights: bytes / 2 = parameter count
    config["parameters"] = model_size / 2
    return config


def calc_input_buffer_size(model_config, context: int) -> float:
    """Estimate the inference input-buffer size in bytes for ``context`` tokens.

    Combines a fixed overhead, an embedding-width term, and per-token costs.
    (Heuristic llama.cpp buffer estimate — not an exact figure.)
    """
    fixed_overhead = 4096
    embedding_term = 2048 * model_config["hidden_size"]
    per_token_term = context * (4 + 2048)
    return fixed_overhead + embedding_term + per_token_term


def calc_compute_buffer_size(model_config, context: int) -> float:
    """Estimate the compute (scratch) buffer size in bytes for ``context`` tokens.

    Scales linearly with context (per KiB of tokens) and with the attention
    head count; the result is expressed in MiB-sized units.
    """
    heads = model_config["num_attention_heads"]
    context_scale = context / 1024 * 2 + 0.75
    return context_scale * heads * 1024 * 1024


def calc_context_size(model_config, context: int) -> float:
    """Estimate the KV-cache size in bytes for ``context`` tokens.

    Accounts for grouped-query attention (GQA): the KV embedding width is the
    hidden size divided by the query/KV head ratio. The final factor of
    2 * 2 covers the K and V tensors at 2 bytes per element (fp16).
    """
    q_heads = model_config["num_attention_heads"]
    kv_heads = model_config["num_key_value_heads"]
    gqa_ratio = q_heads / kv_heads
    kv_embed_width = model_config["hidden_size"] / gqa_ratio
    element_count = kv_embed_width * (model_config["num_hidden_layers"] * context)
    return 2 * element_count * 2


def calc(model_base, context, quant_size):
    """Compute VRAM estimates for a model + context at a given quantization.

    Args:
        model_base: Hugging Face repo id of the unquantized model.
        context: desired context length in tokens.
        quant_size: either a named GGUF quant (a key of ``quants``) or a
            bits-per-weight number as text/float (e.g. "4.5" for exl2).

    Returns:
        (model_size_gb, context_size_gb, total_gb), each rounded to 2 decimals.
    """
    model_config = get_model_config(model_base)
    try:
        quant_bpw = float(quant_size)
    except (TypeError, ValueError):
        # Not numeric: look up the named GGUF quant's bits-per-weight.
        # (Bare except replaced with the two exceptions float() can raise.)
        quant_bpw = quants[quant_size]

    model_size = round(
        calc_model_size(model_config["parameters"], quant_bpw) / 1000 / 1000 / 1000, 2
    )
    context_size = round(
        (
            calc_input_buffer_size(model_config, context)
            + calc_context_size(model_config, context)
            + calc_compute_buffer_size(model_config, context)
        )
        / 1000
        / 1000
        / 1000,
        2,
    )

    return model_size, context_size, round(model_size + context_size, 2)


title = "GGUF VRAM Calculator"

# Build and launch the Gradio UI. Defaults are computed once at startup so the
# output fields are pre-populated with a real example.
with gr.Blocks(title=title, theme=gr.themes.Monochrome()) as app:
    default_model = "mistralai/Mistral-7B-v0.1"
    default_quant = "Q4_K_S"
    default_context = 8192
    default_size = calc(default_model, default_context, default_quant)
    default_model_size = default_size[0]
    default_context_size = default_size[1]

    # Typos fixed in the user-facing text: "superseeded" -> "superseded",
    # "is will not be" -> "will not be".
    gr.Markdown(
        f"# {app.title}\n## This space has been superseded by the [NyxKrage/LLM-Model-VRAM-Calculator](https://huggingface.co/spaces/NyxKrage/LLM-Model-VRAM-Calculator), which has model search built in, and doesn't rely on gradio\nThis is meant only as a guide and will not be 100% accurate, this also does not account for anything that might be running in the background on your system or CUDA system memory fallback on Windows"
    )
    model = gr.Textbox(
        value=default_model,
        label="Enter Unquantized HF Model Name (e.g. mistralai/Mistral-7B-v0.1)",
    )
    context = gr.Number(
        minimum=1, value=default_context, label="Desired Context Size (Tokens)"
    )
    quant = gr.Dropdown(
        choices=list(quants.keys()),
        value=default_quant,
        allow_custom_value=True,
        label="Enter GGUF Quant (e.g. Q4_K_S) or the specific BPW for other quantization schemes such as exl2 (e.g. 4.5)",
    )
    btn = gr.Button(value="Submit", variant="primary")
    btn.click(
        calc,
        inputs=[
            model,
            context,
            quant,
        ],
        outputs=[
            gr.Number(
                label="Model Size (GB)",
                value=default_size[0],
            ),
            gr.Number(
                label="Context Size (GB)",
                value=default_size[1],
            ),
            gr.Number(
                label="Total Size (GB)",
                value=default_size[2],
            ),
        ],
    )

    app.launch()