# flan-t5 / main.py
from flask import Flask, request, jsonify

import torch
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from ctranslate2 import Translator


def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    # Mean-pool token embeddings, zeroing out padding positions so they
    # do not contribute to the average.
    last_hidden = last_hidden_states.masked_fill(
        ~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
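
# Shape sketch (illustrative, not part of the original file): with a batch of
# 2 sequences of length 4 and hidden size 8, average_pool maps (2, 4, 8)
# hidden states and a (2, 4) attention mask to (2, 8) sentence embeddings:
#   hs = torch.randn(2, 4, 8)
#   mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
#   average_pool(hs, mask).shape  # torch.Size([2, 8])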


# Embedding model: a local replacement for OpenAI's text-embedding-ada
embeddingTokenizer = AutoTokenizer.from_pretrained(
    './multilingual-e5-base')
embeddingModel = AutoModel.from_pretrained('./multilingual-e5-base')

# Generation model: a local replacement for ChatGPT, served with CTranslate2
inferenceTokenizer = AutoTokenizer.from_pretrained(
    "./ct2fast-flan-alpaca-xl")
inferenceTranslator = Translator(
    "./ct2fast-flan-alpaca-xl", compute_type="int8", device="cpu")

app = Flask(__name__)


@app.route('/text-embedding', methods=['POST'])
def text_embedding():
    # Get the JSON payload from the request
    data = request.get_json()
    input_text = data["input"]

    # Tokenize, run the encoder, and mean-pool to a single vector
    batch_dict = embeddingTokenizer([input_text], max_length=512,
                                    padding=True, truncation=True,
                                    return_tensors='pt')
    with torch.no_grad():
        outputs = embeddingModel(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state,
                              batch_dict['attention_mask'])

    # Create a JSON response
    response = {
        'embedding': embeddings[0].tolist()
    }
    return jsonify(response)
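
# Example call (illustrative; assumes the default Flask port 5000):
#   curl -X POST http://localhost:5000/text-embedding \
#        -H 'Content-Type: application/json' \
#        -d '{"input": "Hello world"}'
# -> {"embedding": [0.0123, -0.0456, ...]}  (768 floats for e5-base)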


@app.route('/inference', methods=['POST'])
def inference():
    # Get the JSON payload from the request
    data = request.get_json()
    input_text = data["input"]

    # Clamp the requested decoding length to at most 1024, defaulting to 256
    max_length = 256
    try:
        max_length = min(1024, int(data["max_length"]))
    except (KeyError, TypeError, ValueError):
        pass

    # CTranslate2 expects token strings rather than token ids
    input_tokens = inferenceTokenizer.convert_ids_to_tokens(
        inferenceTokenizer.encode(input_text))
    results = inferenceTranslator.translate_batch(
        [input_tokens],
        max_input_length=0,
        max_decoding_length=max_length,
        num_hypotheses=1,
        repetition_penalty=1.3,
        sampling_topk=30,
        sampling_temperature=1.1,
        use_vmap=True)
    output_tokens = results[0].hypotheses[0]
    output_text = inferenceTokenizer.decode(
        inferenceTokenizer.convert_tokens_to_ids(output_tokens))

    # Create a JSON response
    response = {
        'generated_text': output_text
    }
    return jsonify(response)
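
# Example call (illustrative; assumes the default Flask port 5000):
#   curl -X POST http://localhost:5000/inference \
#        -H 'Content-Type: application/json' \
#        -d '{"input": "Summarize: Flask is a Python web microframework.",
#             "max_length": 128}'
# -> {"generated_text": "..."}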


@app.route('/tokens-count', methods=['POST'])
def tokens_count():
    # Get the JSON payload from the request
    data = request.get_json()
    input_text = data["input"]

    # Tokenize with the generation tokenizer so counts match /inference
    tokens = inferenceTokenizer.convert_ids_to_tokens(
        inferenceTokenizer.encode(input_text))

    # Create a JSON response
    response = {
        'tokens': tokens,
        'total': len(tokens)
    }
    return jsonify(response)
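
# Example call (illustrative; assumes the default Flask port 5000; the exact
# token strings depend on the model's SentencePiece vocabulary):
#   curl -X POST http://localhost:5000/tokens-count \
#        -H 'Content-Type: application/json' \
#        -d '{"input": "Hello world"}'
# -> {"tokens": ["▁Hello", "▁world", "</s>"], "total": 3}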


if __name__ == '__main__':
    app.run()