from functools import partial
import os
import re
# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
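    # No-op fallback so code decorated with @spaces.GPU runs unchanged outside ZeroGPU Spaces.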
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper
from transformers import pipeline as hf_pipeline
import litellm
class ModelPrediction:
    def __init__(self):
        self.model_name2pred_func = {
            "gpt-3.5": self._init_model_prediction("gpt-3.5"),
            "gpt-4o-mini": self._init_model_prediction("gpt-4o-mini"),
            "o1-mini": self._init_model_prediction("o1-mini"),
            "QwQ": self._init_model_prediction("QwQ"),
            "DeepSeek-R1-Distill-Llama-70B": self._init_model_prediction(
                "DeepSeek-R1-Distill-Llama-70B"
            ),
            "llama-8": self._init_model_prediction("llama-8"),
        }
        self._model_name = None
        self._pipeline = None
        self.base_prompt = (
            "Translate the following question into SQL code to be executed over the database to fetch the answer. Return the SQL code in ```sql ```\n"
            "Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )
    @property
    def pipeline(self):
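        # Lazily build and cache the HF text-generation pipeline on first use.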
        if self._pipeline is None:
            self._pipeline = hf_pipeline(
                task="text-generation",
                model=self._model_name,
                device_map="auto",
            )
        return self._pipeline
    def _reset_pipeline(self, model_name):
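        # Invalidate the cached pipeline whenever a different model is requested.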
        if self._model_name != model_name:
            print("Resetting pipeline with model", model_name)
            self._model_name = model_name
            self._pipeline = None
    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        # extract with regex everything between <answer> and </answer>
        # (tag name is an assumption; the angle-bracket tags were lost in extraction)
        matches = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if matches:
            return matches[-1].replace("```sql", "").replace("```", "").strip()
        else:
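            # fall back to the last ```sql ... ``` fenced block, if any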
            matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
            return matches[-1].strip() if matches else pred
    def make_prediction(self, question, db_schema, model_name, prompt=None):
        if model_name not in self.model_name2pred_func:
            raise ValueError(
                f"Model {model_name} not supported; supported models are "
                f"{list(self.model_name2pred_func.keys())}"
            )
        prompt = prompt or self.base_prompt
        prompt = prompt.format(question=question, db_schema=db_schema)
        print(prompt)
        prediction = self.model_name2pred_func[model_name](prompt)
        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        return prediction
   
    def predict_with_api(self, prompt, model_name) -> dict[str, str | float]:
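        # Route the call through litellm, which wraps the provider-specific APIs
        # behind a single completion interface.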
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }
    @spaces.GPU
    def predict_with_hf(self, prompt, model_name) -> dict[str, str | float]:
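        # Generate locally with a transformers pipeline; on ZeroGPU Spaces the
        # @spaces.GPU decorator requests a GPU for the duration of the call.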
        self._reset_pipeline(model_name)
        response = self.pipeline([{"role": "user", "content": prompt}])[0][
            "generated_text"
        ][-1]["content"]
        return {"response": response, "cost": 0.0}
    def _init_model_prediction(self, model_name):
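        # Resolve a short model alias to its fully qualified provider name and
        # pick the matching prediction backend (API vs. local HF pipeline).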
        predict_fun = self.predict_with_api
        if "gpt-3.5" in model_name:
            model_name = "openai/gpt-3.5-turbo-0125"
        elif "gpt-4o-mini" in model_name:
            model_name = "openai/gpt-4o-mini-2024-07-18"
        elif "o1-mini" in model_name:
            model_name = "openai/o1-mini-2024-09-12"
        elif "QwQ" in model_name:
            model_name = "together_ai/Qwen/QwQ-32B"
        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
        elif "llama-8" in model_name:
            model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
            predict_fun = self.predict_with_hf
        else:
            raise ValueError(f"Model {model_name} not supported")
        return partial(predict_fun, model_name=model_name)
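

if __name__ == "__main__":
    # Minimal usage sketch. Assumes the relevant API key (e.g. OPENAI_API_KEY
    # for the gpt-* models) is set in the environment; the question and schema
    # below are illustrative placeholders.
    predictor = ModelPrediction()
    result = predictor.make_prediction(
        question="How many users signed up in 2024?",
        db_schema="CREATE TABLE users (id INT, name TEXT, signup_date DATE)",
        model_name="gpt-4o-mini",
    )
    print(result["response_parsed"])
    print("cost:", result["cost"])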