File size: 1,562 Bytes
562f239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from transformers import LlamaForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig, LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel, PeftConfig
import torch

class EndpointHandler:
    """Inference handler that loads a 4-bit quantized Llama causal LM and
    serves single-turn chat completions for a user-supplied prompt."""

    def __init__(self, model_path="."):
        """Load the model and tokenizer from *model_path*.

        Args:
            model_path: directory containing the model weights and tokenizer
                files (defaults to the current directory).
        """
        # 4-bit NF4 quantization with double quantization and bfloat16
        # compute keeps memory low enough for inference endpoints.
        self.model = LlamaForCausalLM.from_pretrained(
            model_path,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4',
            ),
        )
        # Llama-3-style chat models end an assistant turn with <|eot_id|>;
        # registering it as EOS makes generation stop at the end of the turn.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path,
                                                       eos_token="<|eot_id|>")
        # The base vocab has no dedicated pad token; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, request_data):
        """Handle one inference request.

        Args:
            request_data: dict with a "prompt" key holding the user message.

        Returns:
            dict with key "response" containing the generated assistant text.

        Raises:
            KeyError: if "prompt" is missing from request_data.
        """
        prompt = request_data["prompt"]

        chat = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": prompt},
        ]

        # Render via the model's chat template; add_generation_prompt appends
        # the assistant header so the model continues as the assistant.
        templated = self.tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        # FIX: keep inputs on the same device as the (possibly GPU-mapped)
        # model — the original left them on CPU.
        input_ids = self.tokenizer.encode(
            templated, return_tensors="pt"
        ).to(self.model.device)

        # FIX: the original used max_length=400, which bounds prompt +
        # completion together and silently truncates output for long prompts;
        # max_new_tokens bounds only the completion.
        output = self.model.generate(input_ids, max_new_tokens=400)

        # FIX: instead of decoding everything and string-replacing the prompt
        # back out (fragile — tokenization round-trips aren't byte-exact, and
        # the trailing <|eot_id|> was left in the response), decode only the
        # newly generated tokens and drop special tokens during decoding.
        new_tokens = output[0][input_ids.shape[-1]:]
        generated_text = self.tokenizer.decode(
            new_tokens, skip_special_tokens=True
        ).strip()

        return {"response": generated_text}