zaddyzaddy committed on
Commit 8fe825e
1 Parent(s): 01f6258

Upload 2 files

Files changed (2)
  1. handler.py +73 -0
  2. requirements.txt +10 -0
handler.py ADDED
@@ -0,0 +1,73 @@
+ from unsloth import FastLanguageModel
+ from typing import Dict, Any
+ import torch
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         max_seq_length = 2048  # choose any; Unsloth auto-supports RoPE scaling internally
+         dtype = None  # None for auto detection; float16 for Tesla T4/V100, bfloat16 for Ampere+
+         load_in_4bit = True  # use 4-bit quantization to reduce memory usage; can be False
+
+         self.model, self.tokenizer = FastLanguageModel.from_pretrained(
+             model_name=path,
+             max_seq_length=max_seq_length,
+             dtype=dtype,
+             load_in_4bit=load_in_4bit,
+             # token = "hf_...",  # use one if loading gated models like meta-llama/Llama-2-7b-hf
+         )
+
+         self.alpaca_prompt = """
+ ### Instruction:
+ {}
+
+ ### Input:
+ {}
+
+ ### Response:
+ {}"""
+
+         self.EOS_TOKEN = self.tokenizer.eos_token
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         data args:
+             inputs (:obj:`dict`) with the keys:
+                 input_text (:obj:`str`): the text to paraphrase
+                 lex_diversity (:obj:`int`, defaults to 80)
+                 order_diversity (:obj:`int`, defaults to 20)
+                 repetition_penalty (:obj:`float`, defaults to 1.0)
+                 use_cache (:obj:`bool`, defaults to False)
+                 max_length (:obj:`int`, defaults to 128)
+         Return:
+             A :obj:`dict` that will be serialized and returned
+         """
+         data = data.pop("inputs", data)
+         input_text = data.get("input_text", "")
+         lex_diversity = data.get("lex_diversity", 80)
+         order_diversity = data.get("order_diversity", 20)
+         repetition_penalty = data.get("repetition_penalty", 1.0)
+         use_cache = data.get("use_cache", False)
+         max_length = data.get("max_length", 128)
+
+         prediction = self.paraphrase(
+             input_text,
+             lex_diversity,
+             order_diversity,
+             repetition_penalty=repetition_penalty,
+             use_cache=use_cache,
+             max_length=max_length,
+         )
+
+         return {"prediction": prediction}
+
+     def paraphrase(self, input_text, lex_diversity, order_diversity, repetition_penalty, use_cache, max_length, **kwargs):
+         FastLanguageModel.for_inference(self.model)  # enable Unsloth's native 2x faster inference
+         inputs = self.tokenizer(
+             [
+                 self.alpaca_prompt.format(
+                     "You are an AI assistant, capable of paraphrasing any text into a human-like version of the text. Human writing often exhibits bursts and lulls, with a mix of long and short sentences.",  # instruction
+                     f"lexical = {lex_diversity}, order = {order_diversity} {input_text}",  # input
+                     "",  # output - leave this blank for generation!
+                 )
+             ],
+             return_tensors="pt",
+         ).to("cuda")
+
+         outputs = self.model.generate(
+             **inputs,
+             max_new_tokens=max_length,
+             use_cache=use_cache,
+             repetition_penalty=repetition_penalty,
+         )
+         # note: the decoded output still contains the formatted prompt
+         output_text = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+         return output_text
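
For reference, a minimal sketch of how this handler can be exercised locally. The model path and payload values below are hypothetical examples; on a Hugging Face Inference Endpoint the toolkit instantiates EndpointHandler with the repository path and calls it with the parsed request body.

# minimal local smoke test (hypothetical model path; requires a CUDA GPU)
from handler import EndpointHandler

handler = EndpointHandler(path="zaddyzaddy/your-model")  # hypothetical repo path

payload = {
    "inputs": {
        "input_text": "The quick brown fox jumps over the lazy dog.",
        "lex_diversity": 80,
        "order_diversity": 20,
        "repetition_penalty": 1.2,
        "max_length": 128,
    }
}

result = handler(payload)
print(result["prediction"][0])  # list of decoded strings, prompt included
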
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git
+ packaging
+ ninja
+ einops
+ flash-attn
+ xformers
+ trl
+ peft
+ accelerate
+ bitsandbytes