GaniduA commited on
Commit
aeced00
·
verified ·
1 Parent(s): 0aa3756

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +13 -0
handler.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
2
+
3
+ class EndpointHandler:
4
+ def __init__(self, model_path: str, task="text-generation"):
5
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
6
+ self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
7
+ self.pipe = pipeline(task=task, model=self.model, tokenizer=self.tokenizer)
8
+
9
+ def __call__(self, inputs: dict) -> dict:
10
+ prompt = inputs.get("inputs", "")
11
+ params = inputs.get("parameters", {})
12
+ outputs = self.pipe(prompt, **params)
13
+ return {"generated_text": outputs[0]["generated_text"]}