from functools import partial
import os
import re
import time
# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    # Outside a ZeroGPU Space, provide a no-op stand-in so @spaces.GPU still works.
    class spaces:
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper
from transformers import pipeline as hf_pipeline
import torch
import litellm
from tqdm import tqdm
import subprocess

# https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/132
# subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

# pipeline = hf_pipeline(
#     "text-generation",
#     model="meta-llama/Meta-Llama-3.1-8B-Instruct",
#     model_kwargs={"torch_dtype": 'bfloat16'},
#     device_map="auto",
# )


class ModelPrediction:
    def __init__(self):
        self.model_name2pred_func = {
            "gpt-3.5": self._init_model_prediction("gpt-3.5"),
            "gpt-4o-mini": self._init_model_prediction("gpt-4o-mini"),
            "llama-70": self._init_model_prediction("llama-70"),
            "DeepSeek-R1-Distill-Llama-70B": self._init_model_prediction(
                "DeepSeek-R1-Distill-Llama-70B"
            ),
            "llama-8": self._init_model_prediction("llama-8"),
        }
        self._model_name = None
        self._pipeline = None
        self.base_prompt = (
            "Translate the following question into SQL code to be executed over the database to fetch the answer."
            " Return the SQL code in ```sql ```\n"
            "Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )
        self.base_prompt_QA = (
            "Return the answer to the following question based on the provided database."
            " Return your answer as the result of a query executed over the database,"
            " namely as a list of lists, where the outer list contains the result tuples"
            " and each inner list contains the values of one tuple.\n"
            "Return the answer in an answer tag as <answer> </answer>\n"
            "Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )

    def _extract_answer_from_pred(self, pred: str) -> str:
        # Extract with a regex everything between <answer> and </answer>;
        # fall back to a ```sql ... ``` block, otherwise return the raw prediction.
        matches = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if matches:
            return matches[-1].replace("```", "").replace("sql", "").strip()
        else:
            matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
            return matches[-1].strip() if matches else pred
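    # Illustrative behaviour of _extract_answer_from_pred (assumed model outputs,
    # not taken from this Space):
    #   "<answer>[[3]]</answer>"                    -> "[[3]]"
    #   "```sql\nSELECT count(*) FROM users;\n```"  -> "SELECT count(*) FROM users;"
    # If neither pattern matches, the raw prediction is returned unchanged.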

    def make_prediction(self, question, db_schema, model_name, prompt=None, task="SP"):
        if model_name not in self.model_name2pred_func:
            raise ValueError(
                "Model not supported",
                "supported models are",
                self.model_name2pred_func.keys(),
            )
        if task == "SP":
            prompt = prompt or self.base_prompt
        else:
            prompt = prompt or self.base_prompt_QA
        prompt = prompt.format(question=question, db_schema=db_schema)

        start_time = time.time()
        prediction = self.model_name2pred_func[model_name](prompt)
        end_time = time.time()

        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        prediction["time"] = end_time - start_time
        return prediction

    def predict_with_api(self, prompt, model_name):  # -> dict[str, Any | float]
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }

    def _init_model_prediction(self, model_name):
        predict_fun = self.predict_with_api
        # Map the short model aliases to fully qualified litellm model identifiers.
        if "gpt-3.5" in model_name:
            model_name = "openai/gpt-3.5-turbo-0125"
        elif "gpt-4o-mini" in model_name:
            model_name = "openai/gpt-4o-mini-2024-07-18"
        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
        elif "llama-8" in model_name:
            model_name = "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
        elif "llama-70" in model_name:
            model_name = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
        else:
            raise ValueError("Model forbidden")
        return partial(predict_fun, model_name=model_name)
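

# Minimal usage sketch (an assumption, not part of the original Space code). It assumes
# the provider API keys that litellm reads from the environment are set, e.g.
# OPENAI_API_KEY for the OpenAI models; the question and schema below are made-up
# placeholders for illustration only.
if __name__ == "__main__":
    predictor = ModelPrediction()
    example_schema = "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT);"
    result = predictor.make_prediction(
        question="How many users are registered?",
        db_schema=example_schema,
        model_name="gpt-4o-mini",
        task="SP",
    )
    print(result["response_parsed"], result["cost"], result["time"])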