from functools import partial
import os
import re
import time

# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    # Outside ZeroGPU Spaces, provide a stub so @spaces.GPU becomes a no-op decorator.
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper

from transformers import pipeline as hf_pipeline
import litellm
from tqdm import tqdm


class ModelPrediction:
    def __init__(self):
        self.model_name2pred_func = {
            "gpt-3.5": self._init_model_prediction("gpt-3.5"),
            "gpt-4o-mini": self._init_model_prediction("gpt-4o-mini"),
            "o1-mini": self._init_model_prediction("o1-mini"),
            "QwQ": self._init_model_prediction("QwQ"),
            "DeepSeek-R1-Distill-Llama-70B": self._init_model_prediction(
                "DeepSeek-R1-Distill-Llama-70B"
            ),
            "llama-8": self._init_model_prediction("llama-8"),
        }
        self._model_name = None
        self._pipeline = None
        self.base_prompt = (
            "Translate the following question into SQL code to be executed over the database to fetch the answer. Return the SQL code in ```sql ```\n"
            "Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )

    @property
    def pipeline(self):
        # Lazily build the Hugging Face pipeline for the currently selected model.
        if self._pipeline is None:
            self._pipeline = hf_pipeline(
                task="text-generation",
                model=self._model_name,
                device_map="auto",
            )
        return self._pipeline

    def _reset_pipeline(self, model_name):
        if self._model_name != model_name:
            self._model_name = model_name
            self._pipeline = None

    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        # Extract with a regex everything between <answer> and </answer>.
        matches = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if matches:
            return matches[-1].replace("```", "").replace("sql", "").strip()
        else:
            matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
            return matches[-1].strip() if matches else pred
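
    # Illustrative example (added note, not in the original source): with
    #   pred = "<answer>```sql\nSELECT name FROM users\n```</answer>"
    # the method returns "SELECT name FROM users"; without <answer> tags it falls
    # back to the last ```sql ... ``` block, and otherwise returns pred unchanged.
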
    def make_prediction(self, question, db_schema, model_name, prompt=None):
        if model_name not in self.model_name2pred_func:
            raise ValueError(
                f"Model not supported; supported models are "
                f"{list(self.model_name2pred_func.keys())}"
            )
        prompt = prompt or self.base_prompt
        prompt = prompt.format(question=question, db_schema=db_schema)
        start_time = time.time()
        prediction = self.model_name2pred_func[model_name](prompt)
        end_time = time.time()
        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        prediction["time"] = end_time - start_time
        return prediction

    def predict_with_api(self, prompt, model_name):  # -> dict[str, Any | float]
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }

    def predict_with_hf(self, prompt, model_name):  # -> dict[str, Any | float]
        self._reset_pipeline(model_name)
        response = self.pipeline([{"role": "user", "content": prompt}])[0][
            "generated_text"
        ][-1]["content"]
        return {"response": response, "cost": 0.0}
    def _init_model_prediction(self, model_name):
        predict_fun = self.predict_with_api
        if "gpt-3.5" in model_name:
            model_name = "openai/gpt-3.5-turbo-0125"
        elif "gpt-4o-mini" in model_name:
            model_name = "openai/gpt-4o-mini-2024-07-18"
        elif "o1-mini" in model_name:
            model_name = "openai/o1-mini-2024-09-12"
        elif "QwQ" in model_name:
            model_name = "together_ai/Qwen/QwQ-32B"
        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
        elif "llama-8" in model_name:
            model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
            predict_fun = self.predict_with_hf
        else:
            raise ValueError("Model forbidden")
        return partial(predict_fun, model_name=model_name)
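
# Minimal usage sketch (not part of the original module): the question, schema,
# and model choice below are hypothetical, and the call assumes the relevant API
# keys for litellm (e.g. OPENAI_API_KEY) are set in the environment.
if __name__ == "__main__":
    predictor = ModelPrediction()
    result = predictor.make_prediction(
        question="How many users signed up in 2023?",
        db_schema="CREATE TABLE users (id INTEGER, name TEXT, signup_date DATE);",
        model_name="gpt-4o-mini",
    )
    print(result["response_parsed"], result["cost"], result["time"])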