lfernandopg's picture
Update app.py
685da0f
import streamlit as st
import numpy as np
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
import torch
import tensorflow as tf
# st.cache(allow_output_mutation=True) is deprecated; st.cache_resource is the
# documented replacement for caching unserializable objects such as models.
# Fall back to the legacy decorator on Streamlit releases that predate it.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def get_model(model_name="lfernandopg/Proyecto-Transformers"):
    """Load the tokenizer and sequence-classification model, cached across reruns.

    Args:
        model_name: Hugging Face Hub repo id (or local path) passed to both
            ``from_pretrained`` calls. Defaults to the checkpoint this app was
            written for.

    Returns:
        A ``(tokenizer, model)`` tuple: a ``DistilBertTokenizer`` and a
        ``TFDistilBertForSequenceClassification`` loaded from ``model_name``.
    """
    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    model = TFDistilBertForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model
# Fetch the cached tokenizer/model pair, then render the text box and the
# trigger button (Streamlit lays widgets out in call order).
tokenizer, model = get_model()
user_input = st.text_area(label='Enter Text to Analyze')
button = st.button(label="Analyze")
# Maps the classifier's predicted class index to a human-readable job title.
# NOTE(review): presumably matches the label order used when fine-tuning
# lfernandopg/Proyecto-Transformers — confirm against model.config.id2label.
# Trailing spaces previously present on entries 19 and 20 have been removed so
# all labels are formatted consistently.
d = {
    0: 'Accountant',
    1: 'Actuary',
    2: 'Biologist',
    3: 'Chemist',
    4: 'Civil engineer',
    5: 'Computer programmer',
    6: 'Data scientist',
    7: 'Database administrator',
    8: 'Dentist',
    9: 'Economist',
    10: 'Environmental engineer',
    11: 'Financial analyst',
    12: 'IT manager',
    13: 'Mathematician',
    14: 'Mechanical engineer',
    15: 'Physician assistant',
    16: 'Psychologist',
    17: 'Statistician',
    18: 'Systems analyst',
    19: 'Technical writer',
    20: 'Web developer',
}
# Run the classifier only when the user clicks "Analyze".
if button:
    if not user_input or not user_input.strip():
        # Previously an empty submission was silently ignored; tell the user.
        st.warning("Please enter some text to analyze.")
    else:
        # encode() yields only input_ids; a single sequence needs no attention
        # mask, and truncation keeps long inputs within the model's max length.
        predict_input = tokenizer.encode(
            user_input,
            truncation=True,
            padding=True,
            return_tensors="tf",
        )
        # The first element of the model output holds the classification logits.
        logits = model(predict_input)[0]
        # argmax over the class axis; int() turns the numpy scalar into a
        # plain int for display and dictionary lookup.
        prediction_value = int(tf.argmax(logits, axis=1).numpy()[0])
        # Previously the class index was shown under the "Logits" label;
        # show the actual logits and label the index separately.
        st.write("Logits: ", logits.numpy())
        st.write("Predicted class index: ", prediction_value)
        st.write(
            "Prediction: ",
            d.get(prediction_value, f"Unknown label {prediction_value}"),
        )