File size: 2,316 Bytes
0df58e7
f24ab5f
4591cc5
fa938a5
b83aafc
0df58e7
 
 
9e99cfd
817f818
 
0df58e7
fa938a5
 
 
 
 
 
 
0df58e7
 
 
d2d5c5b
0df58e7
 
 
 
950facf
bbfcf64
0df58e7
 
 
bbfcf64
b83aafc
fa938a5
b83aafc
 
 
fa938a5
b83aafc
0df58e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa938a5
 
 
 
 
 
 
 
 
e847bd1
 
 
 
 
 
 
817f818
 
 
 
 
 
 
 
 
c5a9400
b83aafc
 
 
 
 
 
 
 
11f0536
 
5b79173
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
from scipy.special import softmax


# NOTE: the sentiment-analysis helpers (preprocess, get_model) are defined below, after the UI widgets.




# # Setting the page configurations
# st.set_page_config(
#     page_title="Sentiment Analysis App",
#     page_icon=":smile:",
#     layout="wide",
#     initial_sidebar_state="auto",
# )

# --- Page header -------------------------------------------------------------
# The string is Markdown: the leading "# " renders as the page heading.
st.write("\n# Sentiment Analysis App\n")

# Banner picture shown under the heading (relative path, 200 px wide).
image = st.image("images.png", width=200)

# --- User input --------------------------------------------------------------
# `text` holds the current contents of the input box; `button` is truthy only
# on the rerun triggered by clicking "Analyze". Both are read by the
# prediction section further down.
text = st.text_input("Type here:")
button = st.button('Analyze')

# Class-index reference kept from an earlier revision:
#   0 = Negative, 1 = Neutral, 2 = Positive

# --- Styling -----------------------------------------------------------------
# Raw CSS injected into the page; unsafe_allow_html is required for the
# <style> tag to be emitted as HTML rather than escaped.
APP_CSS = """
<style>
body {
    background-color: #f5f5f5;
}
h1 {
    color: #4e79a7;
}
</style>
"""
st.markdown(APP_CSS, unsafe_allow_html=True)

# Show sentiment output
# if text:
#     sentiment, score = predict_sentiment(text)
#     if sentiment == "Positive":
#         st.success(f"The sentiment is {sentiment} with a score of {score*100:.2f}%!")
#     elif sentiment == "Negative":
#         st.error(f"The sentiment is {sentiment} with a score of {score*100:.2f}%!")
#     else:
#         st.warning(f"The sentiment is {sentiment} with a score of {score*100:.2f}%!")

def preprocess(text):
    """Mask user mentions and links in *text*.

    The text is split on single spaces and each token is rewritten:

    * a token that starts with ``@`` and is longer than one character
      becomes ``@user``;
    * a token that starts with ``http`` becomes ``http``;
    * anything else is kept unchanged.

    Tokens are re-joined with single spaces, so runs of spaces in the input
    (which produce empty tokens) are preserved.
    """

    def _mask(token):
        # A lone "@" is not treated as a mention.
        if len(token) > 1 and token.startswith('@'):
            return '@user'
        if token.startswith('http'):
            return 'http'
        return token

    return " ".join(_mask(token) for token in text.split(" "))

@st.cache_resource()
def get_model():
    """Load the tokenizer and the sequence-classification model.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the returned
    objects instead of calling ``from_pretrained`` again on each script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple: the ``xlnet-base-cased`` tokenizer and
        the ``MrDdz/mytuned_test_trainer-base-cased1`` classifier.
    """
    # NOTE(review): the tokenizer and the model come from different
    # checkpoints -- confirm the model was fine-tuned with this tokenizer.
    loaded_tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
    loaded_model = AutoModelForSequenceClassification.from_pretrained(
        "MrDdz/mytuned_test_trainer-base-cased1"
    )
    return loaded_tokenizer, loaded_model
# Obtain the tokenizer and classifier from the cached loader.
tokenizer, model = get_model()

# Classify only when there is text in the box AND the Analyze button was
# clicked on this rerun.
if text and button:
    # Encode the input as PyTorch tensors ready to be unpacked into the model.
    text_sample = tokenizer(text, padding='max_length', return_tensors='pt')

    # Inference only: disable autograd so no computation graph is built for
    # the forward pass. Tensors produced here do not require grad, so they can
    # be converted to NumPy directly.
    with torch.no_grad():
        output = model(**text_sample)

    # output[0] is the logits tensor (same data as output.logits); the second
    # [0] selects the row for the single input text.
    scores_ = softmax(output[0][0].numpy())

    # Assumes class indices 0/1/2 correspond to these labels, in this order.
    labels = ['Negative', 'Neutral', 'Positive']
    scores = {label: float(score) for label, score in zip(labels, scores_)}

    st.write("Prediction :", scores)