soro / handler.py
zaddyzaddy's picture
Upload 2 files
8fe825e verified
raw
history blame
No virus
2.88 kB
from unsloth import FastLanguageModel
from typing import Dict, List, Any
import torch
class EndpointHandler:
    """Inference-endpoint handler that paraphrases text with an Unsloth-loaded model.

    Prompts use an Alpaca-style "### Instruction / ### Input / ### Response"
    layout (presumably the format the model was fine-tuned on -- confirm before
    changing ``alpaca_prompt``).
    """

    # Fixed instruction placed in the "### Instruction:" slot of every prompt.
    PARAPHRASE_INSTRUCTION = "You are an AI assistant, capable of paraphrasing any text to a human-like version of the text. Human writing often exhibits bursts and lulls, with a mix of long and short sentences"

    def __init__(self, path=""):
        """Load the model and tokenizer from ``path`` and switch them to inference mode.

        Args:
            path: Model directory or hub id, passed as ``model_name`` to
                ``FastLanguageModel.from_pretrained``.
        """
        max_seq_length = 2048  # Unsloth supports RoPE scaling internally.
        dtype = None  # None = auto-detect (float16 for T4/V100, bfloat16 for Ampere+).
        load_in_4bit = True  # 4-bit quantization to reduce memory usage.
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=path,
            max_seq_length=max_seq_length,
            dtype=dtype,
            load_in_4bit=load_in_4bit,
            # Pass token="hf_..." here when loading a gated model.
        )
        # Enable Unsloth's faster inference path once, instead of on every request.
        FastLanguageModel.for_inference(self.model)
        # Byte-for-byte the same template as before: two "{}" slots
        # (instruction, input); generation continues after "### Response:".
        self.alpaca_prompt = (
            "\n"
            "### Instruction:\n"
            "{}\n"
            "### Input:\n"
            "{}\n"
            "### Response:\n"
        )
        # Public attribute; not used inside this class.
        self.EOS_TOKEN = self.tokenizer.eos_token

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload. Options are read from ``data["inputs"]`` when
                that key exists, otherwise from ``data`` itself. ``inputs`` may
                also be a plain string, used as the text to paraphrase with
                default options. Recognised keys: ``input_text`` (default ""),
                ``lex_diversity`` (default 80), ``order_diversity`` (default 20),
                ``repetition_penalty`` (default 1.0), ``use_cache`` (default
                False) and ``max_length`` (max new tokens, default 128).
                ``data`` is not modified.

        Returns:
            ``{"prediction": <list returned by paraphrase>}``.
        """
        request = self._parse_request(data)
        return {"prediction": self.paraphrase(**request)}

    @staticmethod
    def _parse_request(data: Dict[str, Any]) -> Dict[str, Any]:
        """Build the keyword arguments for ``paraphrase`` from a request payload."""
        payload = data.get("inputs", data)
        if payload is None:
            payload = {}
        elif isinstance(payload, str):
            # Conventional endpoint shape {"inputs": "<text>"}: text only, defaults otherwise.
            payload = {"input_text": payload}
        return {
            "input_text": payload.get("input_text", ""),
            "lex_diversity": payload.get("lex_diversity", 80),
            "order_diversity": payload.get("order_diversity", 20),
            "repetition_penalty": payload.get("repetition_penalty", 1.0),
            "use_cache": payload.get("use_cache", False),
            "max_length": payload.get("max_length", 128),
        }

    def paraphrase(self, input_text, lex_diversity, order_diversity, repetition_penalty, use_cache, max_length, **kwargs):
        """Generate a paraphrase of ``input_text``.

        Args:
            input_text: Text to paraphrase.
            lex_diversity: Value written into the prompt as ``lexical = ...``.
            order_diversity: Value written into the prompt as ``order = ...``.
            repetition_penalty: Forwarded to ``generate``.
            use_cache: Forwarded to ``generate``.
            max_length: Maximum number of *new* tokens (passed as ``max_new_tokens``).
            **kwargs: Accepted for compatibility; ignored.

        Returns:
            ``tokenizer.batch_decode`` output for the single prompt: a one-element
            list containing the prompt followed by the generated continuation
            (special tokens are not stripped).
        """
        prompt = self.alpaca_prompt.format(
            self.PARAPHRASE_INSTRUCTION,
            f"lexical = {lex_diversity}, order = {order_diversity} {input_text}",
            # Fills a third "{}" response slot if the template has one;
            # str.format ignores it with the current two-slot template.
            "",
        )
        inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_length,
            use_cache=use_cache,  # honour the caller's setting
            repetition_penalty=repetition_penalty,
        )
        return self.tokenizer.batch_decode(outputs)