# PhoBERTNode / app.py
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, TFAutoModel
import tensorflow as tf
import numpy as np

app = Flask(__name__)

# Load PhoBERT (TensorFlow version).
MODEL_NAME = "vinai/phobert-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = TFAutoModel.from_pretrained(MODEL_NAME)

# PhoBERT accepts at most 256 tokens per sequence; consecutive chunks
# overlap by MAX_LEN - STRIDE = 128 tokens so context spanning a chunk
# boundary is not lost.
MAX_LEN = 256
STRIDE = 128
def split_text_into_chunks(text):
    """Tokenize text and split the ids into overlapping chunks of MAX_LEN."""
    tokens = tokenizer.encode(text, add_special_tokens=True)
    chunks = []
    for i in range(0, len(tokens), STRIDE):
        chunk = tokens[i:i + MAX_LEN]
        # Pad the final chunk so every chunk has the same length.
        if len(chunk) < MAX_LEN:
            chunk += [tokenizer.pad_token_id] * (MAX_LEN - len(chunk))
        chunks.append(chunk)
        # Stop once this window has reached the end of the token sequence.
        if i + MAX_LEN >= len(tokens):
            break
    return chunks
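
# Worked example of the windowing above: a 300-token document produces
# windows starting at i = 0 and i = 128, i.e. tokens[0:256] and
# tokens[128:300] (padded up to 256); the loop then stops because
# 128 + 256 >= 300.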

def embed_text(text):
    """Mean-pool PhoBERT hidden states per chunk, then average across chunks."""
    chunks = split_text_into_chunks(text)
    embeddings = []
    for chunk in chunks:
        input_ids = tf.constant([chunk])
        # Mask padding so it does not contribute to the pooled vector.
        attention_mask = tf.cast(input_ids != tokenizer.pad_token_id, tf.int32)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # (1, MAX_LEN, hidden_size)
        # Masked mean pooling over the token axis.
        mask = tf.cast(tf.expand_dims(attention_mask, -1), tf.float32)
        summed = tf.reduce_sum(hidden_states * mask, axis=1)
        count = tf.reduce_sum(mask, axis=1)
        mean_pooled = summed / count
        embeddings.append(mean_pooled.numpy()[0])
    # Average the per-chunk vectors into one document-level embedding.
    final_embedding = np.mean(embeddings, axis=0)
    return final_embedding.tolist()
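
# Design note: np.mean weights every chunk equally, regardless of how many
# real (non-pad) tokens it contains, so a short trailing chunk counts as
# much as a full one; a token-count-weighted average is a possible
# alternative if that matters for your corpus.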

@app.route('/embed', methods=['POST'])
def embed():
    # silent=True returns None instead of raising on a missing or invalid
    # JSON body, so the empty-text check below also covers that case.
    data = request.get_json(silent=True) or {}
    text = data.get('text', '')
    if not text:
        return jsonify({"error": "No text provided"}), 400
    embedding = embed_text(text)
    return jsonify({"embedding": embedding})

@app.route('/', methods=['GET'])
def index():
    return "PhoBERT vector API is running!"


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
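
# Example request (a sketch, assuming the server is reachable on
# localhost:7860; phobert-base produces a 768-dimensional vector):
#   curl -X POST http://localhost:7860/embed \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Xin chào Việt Nam"}'
# -> {"embedding": [0.012, -0.034, ...]}   # 768 floats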