Update handler.py (commit 149288b, verified)
from typing import Dict, List, Any
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch


class EndpointHandler:
    def __init__(self, path: str = ""):
        model_name = "microsoft/Phi-3.5-mini-instruct"
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load the base model in half precision and attach the fine-tuned adapter.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16
        ).to(device)
        model.load_adapter("cafierom/Phi-3.5-mini-instruct-Gen-TF-Mottos")

        # The pipeline reuses the already-placed model, so no device argument is needed.
        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`): the prompt to generate from.
        Return:
            A :obj:`list` of :obj:`dict` that will be serialized and returned.
        """
        # Fall back to the raw payload if no "inputs" key is present.
        inputs = data.pop("inputs", data)
        prediction = self.pipeline(inputs)
        return prediction
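

# A minimal, hypothetical local smoke test (not part of the deployed Inference
# Endpoint): it instantiates the handler and sends the same payload shape the
# endpoint receives, {"inputs": "<prompt>"}. The prompt text below is an
# illustrative assumption, not something defined in this repository.
if __name__ == "__main__":
    handler = EndpointHandler()
    output = handler({"inputs": "Write a short motto about thermodynamics:"})
    print(output)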