from unsloth import FastLanguageModel
from typing import Dict, List, Any
import torch

class EndpointHandler:
    def __init__(self, path=""):
        max_seq_length = 2048  # Choose any; Unsloth auto-supports RoPE scaling internally.
        dtype = None  # None for auto detection; float16 for Tesla T4/V100, bfloat16 for Ampere+.
        load_in_4bit = True  # Use 4-bit quantization to reduce memory usage. Can be False.

        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_name=path,
            max_seq_length=max_seq_length,
            dtype=dtype,
            load_in_4bit=load_in_4bit,
            # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
        )

        # Standard Alpaca-style prompt; the final slot is left blank at
        # inference time so the model generates the response.
        self.alpaca_prompt = """### Instruction:
{}

### Input:
{}

### Response:
{}"""
            
        self.EOS_TOKEN = self.tokenizer.eos_token

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj:`dict`): payload with the generation parameters below
                input_text (:obj:`str`): text to paraphrase
                lex_diversity (:obj:`int`, defaults to 80)
                order_diversity (:obj:`int`, defaults to 20)
                repetition_penalty (:obj:`float`, defaults to 1.0)
                use_cache (:obj:`bool`, defaults to False)
                max_length (:obj:`int`, defaults to 128): max new tokens to generate
        Return:
            A :obj:`dict` with a "prediction" key; it will be serialized and returned
        """
        
        # Inference Endpoints nest the request parameters under "inputs".
        data = data.pop("inputs", data)
        input_text = data.get("input_text", "")
        lex_diversity = data.get("lex_diversity", 80)
        order_diversity = data.get("order_diversity", 20)
        repetition_penalty = data.get("repetition_penalty", 1.0)
        use_cache = data.get("use_cache", False)
        max_length = data.get("max_length", 128)
        
        prediction = self.paraphrase(
            input_text, 
            lex_diversity, 
            order_diversity, 
            repetition_penalty=repetition_penalty, 
            use_cache=use_cache, 
            max_length=max_length
        )

        prediction = {'prediction': prediction}
        return prediction

    def paraphrase(self, input_text, lex_diversity, order_diversity, repetition_penalty, use_cache, max_length, **kwargs):
        """Paraphrase `input_text`, steering lexical and word-order diversity via the prompt."""
        FastLanguageModel.for_inference(self.model)  # Enable native 2x faster inference
        inputs = self.tokenizer(
            [
                self.alpaca_prompt.format(
                    "You are an AI assistant, capable of paraphrasing any text to a human-like version of the text. Human writing often exhibits bursts and lulls, with a mix of long and short sentences",  # instruction
                    f"lexical = {lex_diversity}, order = {order_diversity} {input_text}",
                    "",  # output - leave this blank for generation!
                )
            ], return_tensors="pt").to("cuda")
        
        outputs = self.model.generate(**inputs, max_new_tokens=max_length, use_cache=use_cache, repetition_penalty=repetition_penalty)
        # Decode the full sequences (prompt plus generated response).
        output_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return output_text
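

# -----------------------------------------------------------------------------
# Local smoke test: a minimal sketch showing how the handler is invoked.
# The path and example payload below are illustrative assumptions; in
# production, Hugging Face Inference Endpoints instantiate EndpointHandler
# with the model repository path and call it with each request body.
# Requires a CUDA-capable GPU and the unsloth package.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # assumes the model weights live in the repo root
    example_payload = {
        "inputs": {
            "input_text": "The quick brown fox jumps over the lazy dog.",
            "lex_diversity": 80,
            "order_diversity": 20,
            "repetition_penalty": 1.0,
            "use_cache": False,
            "max_length": 128,
        }
    }
    print(handler(example_payload))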