# qatch-demo / prediction.py
from functools import partial
import os
import re
import time
# https://discuss.huggingface.co/t/issues-with-sadtalker-zerogpu-spaces-inquiry-about-community-grant/110625/10
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    # Outside a ZeroGPU Space, fall back to a no-op GPU decorator.
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper
from transformers import pipeline as hf_pipeline
import torch
import litellm
from tqdm import tqdm
import subprocess
# https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/132
# subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
# pipeline = hf_pipeline(
# "text-generation",
# model="meta-llama/Meta-Llama-3.1-8B-Instruct",
# model_kwargs={"torch_dtype": 'bfloat16'},
# device_map="auto",
# )
class ModelPrediction:
    def __init__(self):
        self.model_name2pred_func = {
            "gpt-3.5": self._init_model_prediction("gpt-3.5"),
            "gpt-4o-mini": self._init_model_prediction("gpt-4o-mini"),
            "llama-70": self._init_model_prediction("llama-70"),
            "DeepSeek-R1-Distill-Llama-70B": self._init_model_prediction(
                "DeepSeek-R1-Distill-Llama-70B"
            ),
            "llama-8": self._init_model_prediction("llama-8"),
        }
        self._model_name = None
        self._pipeline = None

        # Prompt template for text-to-SQL (semantic parsing, task 'SP').
        self.base_prompt = (
            "Translate the following question in SQL code to be executed over the database to fetch the answer. Return the sql code in ```sql ```\n"
            " Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )
        # Prompt template for direct question answering over the database (task 'QA').
        self.base_prompt_QA = (
            "Return the answer of the following question based on the provided database."
            " Return your answer as the result of a query executed over the database."
            " Namely, as a list of list where the first list represent the tuples and the second list the values in that tuple.\n"
            "Return the answer in answer tag as <answer> </answer>"
            " Question\n"
            "{question}\n"
            "Database Schema\n"
            "{db_schema}\n"
        )
    @staticmethod
    def _extract_answer_from_pred(pred: str) -> str:
        # First look for an <answer>...</answer> block (QA task), then fall back
        # to a ```sql ... ``` fenced block (SP task); otherwise return the raw prediction.
        matches = re.findall(r"<answer>(.*?)</answer>", pred, re.DOTALL)
        if matches:
            return matches[-1].replace("```", "").replace("sql", "").strip()
        else:
            matches = re.findall(r"```sql(.*?)```", pred, re.DOTALL)
            return matches[-1].strip() if matches else pred
    def make_prediction(self, question, db_schema, model_name, prompt=None, task="SP"):
        if model_name not in self.model_name2pred_func:
            raise ValueError(
                f"Model '{model_name}' is not supported. "
                f"Supported models are: {list(self.model_name2pred_func.keys())}"
            )

        if task == "SP":
            prompt = prompt or self.base_prompt
        else:
            prompt = prompt or self.base_prompt_QA
        prompt = prompt.format(question=question, db_schema=db_schema)

        start_time = time.time()
        prediction = self.model_name2pred_func[model_name](prompt)
        end_time = time.time()

        prediction["response_parsed"] = self._extract_answer_from_pred(
            prediction["response"]
        )
        prediction["time"] = end_time - start_time
        return prediction
    def predict_with_api(self, prompt, model_name):  # -> dict[str, Any | float]
        response = litellm.completion(
            model=model_name,
            messages=[{"role": "user", "content": prompt}],
            num_retries=2,
        )
        response_text = response["choices"][0]["message"]["content"]
        return {
            "response": response_text,
            "cost": response._hidden_params["response_cost"],
        }
    def _init_model_prediction(self, model_name):
        # Map the short model name to the provider-qualified litellm model id.
        predict_fun = self.predict_with_api
        if "gpt-3.5" in model_name:
            model_name = "openai/gpt-3.5-turbo-0125"
        elif "gpt-4o-mini" in model_name:
            model_name = "openai/gpt-4o-mini-2024-07-18"
        elif "DeepSeek-R1-Distill-Llama-70B" in model_name:
            model_name = "together_ai/deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
        elif "llama-8" in model_name:
            model_name = "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
        elif "llama-70" in model_name:
            model_name = "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
        else:
            raise ValueError(f"Model not supported: {model_name}")

        return partial(predict_fun, model_name=model_name)
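
# Illustrative usage sketch (not part of the Space's runtime path). It assumes
# the relevant provider API keys (e.g. OPENAI_API_KEY or TOGETHERAI_API_KEY)
# are available in the environment for litellm, and the toy schema and question
# below are made up purely for the example.
if __name__ == "__main__":
    predictor = ModelPrediction()
    toy_schema = "CREATE TABLE players (name TEXT, team TEXT, goals INTEGER);"
    result = predictor.make_prediction(
        question="Which player scored the most goals?",
        db_schema=toy_schema,
        model_name="gpt-4o-mini",
        task="SP",
    )
    print(result["response_parsed"])        # SQL extracted from the raw response
    print(result["cost"], result["time"])   # API cost and wall-clock latency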