"""Minimal Flask service exposing local replacements for OpenAI endpoints:
multilingual-e5-base for embeddings (text-ada replacement) and a CTranslate2
conversion of flan-alpaca-xl for text generation (ChatGPT replacement)."""

import torch
from flask import Flask, request, jsonify
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from ctranslate2 import Translator


def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool the token embeddings, ignoring padding positions."""
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
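
# A minimal sketch of what average_pool computes, with made-up values
# (batch=1, seq_len=3, hidden=2); the padded third position is zeroed
# out, so only the two real tokens are averaged:
#
#   hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
#   mask = torch.tensor([[1, 1, 0]])
#   average_pool(hidden, mask)  # -> tensor([[2.0, 3.0]])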


# text-ada replacement
embedding_tokenizer = AutoTokenizer.from_pretrained(
    './multilingual-e5-base')
embedding_model = AutoModel.from_pretrained('./multilingual-e5-base')

# ChatGPT replacement: flan-alpaca-xl converted to CTranslate2,
# quantized to int8 for CPU inference
inference_tokenizer = AutoTokenizer.from_pretrained(
    "./ct2fast-flan-alpaca-xl")
inference_translator = Translator(
    "./ct2fast-flan-alpaca-xl", compute_type="int8", device="cpu")


app = Flask(__name__)


@app.route('/text-embedding', methods=['POST'])
def text_embedding():
    # Get the JSON data from the request
    data = request.get_json()
    input_text = data["input"]

    # Tokenize and embed the input, then mean-pool the hidden states
    batch_dict = embedding_tokenizer([input_text], max_length=512,
                                     padding=True, truncation=True,
                                     return_tensors='pt')
    with torch.no_grad():
        outputs = embedding_model(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state,
                              batch_dict['attention_mask'])

    # Create a JSON response
    response = {
        'embedding': embeddings[0].tolist()
    }

    return jsonify(response)
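
# Example request (a sketch; host/port assume the default Flask dev server):
#
#   curl -X POST http://127.0.0.1:5000/text-embedding \
#        -H "Content-Type: application/json" \
#        -d '{"input": "query: How tall is Mount Everest?"}'
#
# Note that the e5 model family expects a "query: " or "passage: " prefix
# on its inputs; the caller is responsible for supplying it.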


@app.route('/inference', methods=['POST'])
def inference():
    # Get the JSON data from the request
    data = request.get_json()
    input_text = data["input"]

    # Optional decoding budget, clamped to 1024 tokens (default 256)
    max_length = 256
    try:
        max_length = min(1024, int(data["max_length"]))
    except (KeyError, TypeError, ValueError):
        pass

    # CTranslate2 expects token strings rather than token ids
    input_tokens = inference_tokenizer.convert_ids_to_tokens(
        inference_tokenizer.encode(input_text))

    results = inference_translator.translate_batch(
        [input_tokens],
        max_input_length=0,  # 0 disables input truncation
        max_decoding_length=max_length,
        num_hypotheses=1,
        repetition_penalty=1.3,
        sampling_topk=30,
        sampling_temperature=1.1,
        use_vmap=True)

    output_tokens = results[0].hypotheses[0]
    output_text = inference_tokenizer.decode(
        inference_tokenizer.convert_tokens_to_ids(output_tokens))

    # Create a JSON response
    response = {
        'generated_text': output_text
    }

    return jsonify(response)
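
# Example request (a sketch; host/port assume the default Flask dev server):
#
#   curl -X POST http://127.0.0.1:5000/inference \
#        -H "Content-Type: application/json" \
#        -d '{"input": "Write a haiku about spring.", "max_length": 128}'
#
# Returns: {"generated_text": "..."}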


@app.route('/tokens-count', methods=['POST'])
def tokens_count():
    # Get the JSON data from the request
    data = request.get_json()
    input_text = data["input"]

    # Tokenize with the inference tokenizer so counts match /inference
    tokens = inference_tokenizer.convert_ids_to_tokens(
        inference_tokenizer.encode(input_text))

    # Create a JSON response
    response = {
        'tokens': tokens,
        'total': len(tokens)
    }

    return jsonify(response)
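
# Example request (a sketch; host/port assume the default Flask dev server):
#
#   curl -X POST http://127.0.0.1:5000/tokens-count \
#        -H "Content-Type: application/json" \
#        -d '{"input": "Hello world"}'
#
# Returns the SentencePiece token strings and their count, e.g.
# {"tokens": ["▁Hello", "▁world", "</s>"], "total": 3}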


if __name__ == '__main__':
    app.run()