|
from transformers import LlamaForCausalLM, AutoTokenizer |
|
from transformers import BitsAndBytesConfig, LlamaForCausalLM, LlamaTokenizer |
|
from peft import PeftModel, PeftConfig |
|
import torch |
|
|
|
class EndpointHandler:
    """Inference endpoint wrapper around a 4-bit quantized Llama chat model.

    Loads the model and tokenizer once at construction time and exposes
    ``__call__`` for per-request chat-style text generation.
    """

    def __init__(self, model_path="."):
        """Load the model in 4-bit NF4 quantization plus its tokenizer.

        Args:
            model_path: Directory (or hub id) containing the model weights
                and tokenizer files. Defaults to the current directory.
        """
        self.model = LlamaForCausalLM.from_pretrained(
            model_path,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4',
            ),
        )
        # '<|eot_id|>' is the Llama-3 end-of-turn marker; registering it as
        # EOS makes generate() stop at the end of the assistant's turn.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            eos_token="<|eot_id|>",
        )
        # The tokenizer ships without a pad token; reuse EOS so generate()
        # does not warn or fail when padding is needed.
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, request_data):
        """Generate an assistant reply for the request's ``prompt`` field.

        Args:
            request_data: Mapping with a ``"prompt"`` key holding the user text.

        Returns:
            dict: ``{"response": <generated assistant text>}`` with the prompt
            and all special tokens stripped.

        Raises:
            KeyError: If ``request_data`` has no ``"prompt"`` key.
        """
        prompt = request_data["prompt"]

        chat = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": prompt}
        ]

        templated = self.tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        # The chat template already emits <|begin_of_text|>; disable the
        # tokenizer's automatic special tokens to avoid a double BOS.
        # Move ids onto the model's device — the 4-bit quantized model
        # typically lives on GPU and generate() needs matching devices.
        input_ids = self.tokenizer.encode(
            templated, return_tensors="pt", add_special_tokens=False
        ).to(self.model.device)

        output = self.model.generate(input_ids, max_length=400)

        # Strip the prompt by slicing off the input tokens rather than
        # string-replacing the decoded text: detokenization does not
        # reliably round-trip the templated string, so replace() could
        # silently leave the prompt in the response. Decoding with
        # skip_special_tokens=True also drops the trailing <|eot_id|>.
        new_tokens = output[0][input_ids.shape[-1]:]
        generated_text = self.tokenizer.decode(
            new_tokens, skip_special_tokens=True
        ).strip()

        return {"response": generated_text}