from functools import partial
import os
import re
import time

# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    # Outside ZeroGPU Spaces, stand in for `spaces` with a no-op GPU decorator.
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper


from transformers import pipeline as hf_pipeline
import torch
import litellm
from tqdm import tqdm
import subprocess

# https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/132
# subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

# pipeline = hf_pipeline(
#     "text-generation",
#     model="meta-llama/Meta-Llama-3.1-8B-Instruct",
#     model_kwargs={"torch_dtype": 'bfloat16'},
#     device_map="auto",
# )


class ModelPrediction:
    def __init__(self):
        # Map each supported model alias to a prediction function bound to the
        # provider-qualified litellm model id.
        self.model_name2pred_func = {
            "gpt-3.5": self._init_model_prediction("gpt-3.5"),
            "gpt-4o-mini": self._init_model_prediction("gpt-4o-mini"),
            "llama-70": self._init_model_prediction("llama-70"),
            "DeepSeek-R1-Distill-Llama-70B": self._init_model_prediction(
                "DeepSeek-R1-Distill-Llama-70B"
            ),
            "llama-8": self._init_model_prediction("llama-8"),
        }

        self._model_name = None
        self._pipeline = None
        # Default prompt for semantic parsing (SP): question -> SQL.
        self.base_prompt = (
            "Translate the following question into SQL code to be executed over the database to fetch the answer. Return the SQL code in ```sql ```\n"
            "Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )
        # Default prompt for question answering (QA): question -> query result.
        self.base_prompt_QA = (
            "Return the answer to the following question based on the provided database."
            " Return your answer as the result of a query executed over the database,"
            " namely as a list of lists, where each inner list is one tuple and contains that tuple's values.\n"
            "Return the answer in an answer tag as <answer> </answer>\n"
            "Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )

    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        # Extract with regex everything between <answer> and </answer>; if no
        # tag is present, fall back to the last ```sql ``` block, then to the
        # raw prediction.
        matches = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if matches:
            return matches[-1].replace("```", "").replace("sql", "").strip()
        else:
            matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
            return matches[-1].strip() if matches else pred

    def make_prediction(self, question, db_schema, model_name, prompt=None, task="SP"):
        if model_name not in self.model_name2pred_func:
            raise ValueError(
                "Model not supported",
                "supported models are",
                self.model_name2pred_func.keys(),
            )
        if task == "SP":
            prompt = prompt or self.base_prompt
        else:
            prompt = prompt or self.base_prompt_QA

        prompt = prompt.format(question=question, db_schema=db_schema)
        start_time = time.time()
        prediction = self.model_name2pred_func[model_name](prompt)
        end_time = time.time()
        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        prediction["time"] = end_time - start_time
        return prediction

    def predict_with_api(self, prompt, model_name):  # -> dict[str, Any | float]
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }

    def _init_model_prediction(self, model_name):
        # Resolve the short alias to the provider-qualified litellm model id
        # and return a prediction function bound to it.
        predict_fun = self.predict_with_api
        if "gpt-3.5" in model_name:
            model_name = "openai/gpt-3.5-turbo-0125"
        elif "gpt-4o-mini" in model_name:
            model_name = "openai/gpt-4o-mini-2024-07-18"
        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
        elif "llama-8" in model_name:
            model_name = "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
        elif "llama-70" in model_name:
            model_name = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
        else:
            raise ValueError("Model not supported", model_name)

        return partial(predict_fun, model_name=model_name)