"""Streamlit app that classifies the sentiment of user-entered text.

The input is normalized with ``preprocess``, tokenized, and scored by a
fine-tuned sequence-classification model; the most likely label and its
softmax probability are shown to the user.
"""
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
from scipy.special import softmax

# Class names in the order of the model's output logits.
# NOTE(review): assumed to match the label mapping used during fine-tuning — confirm.
LABELS = ["Negative", "Neutral", "Positive"]


def preprocess(text):
    """Replace @-mentions with ``@user`` and links with ``http``.

    The text is split on single spaces and re-joined with single spaces, so
    the original spacing is preserved.
    """
    new_text = []
    for t in text.split(" "):
        t = "@user" if t.startswith("@") and len(t) > 1 else t
        t = "http" if t.startswith("http") else t
        new_text.append(t)
    return " ".join(new_text)


@st.cache_resource()
def get_model():
    """Load the tokenizer and classifier once and reuse them across reruns.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    # NOTE(review): the tokenizer comes from the base checkpoint; presumably the
    # fine-tuned model was trained from xlnet-base-cased — confirm.
    tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")
    model = AutoModelForSequenceClassification.from_pretrained(
        "MrDdz/mytuned_test_trainer-base-cased1"
    )
    model.eval()  # inference only: make sure dropout is disabled
    return tokenizer, model


def predict_sentiment(text, tokenizer, model):
    """Classify *text* and return ``(label, probability)`` of the top class.

    Args:
        text: Raw user input.
        tokenizer: Tokenizer returned by ``get_model``.
        model: Sequence-classification model returned by ``get_model``.

    Returns:
        tuple[str, float]: The most likely entry of ``LABELS`` and its
        softmax probability in ``[0, 1]``.
    """
    # Normalize mentions/links the same way ``preprocess`` defines.
    # NOTE(review): assumes the fine-tuning data was normalized identically — confirm.
    encoded = tokenizer(preprocess(text), return_tensors="pt")
    # No gradients are needed for inference; this avoids building an autograd graph.
    with torch.no_grad():
        logits = model(**encoded).logits[0].numpy()
    probs = softmax(logits)
    best = int(np.argmax(probs))
    return LABELS[best], float(probs[best])


# Title and header image
st.write("# Sentiment Analysis App")
st.image("images.png", width=200)

# User input
text = st.text_input("Type here:")
button = st.button("Analyze")

tokenizer, model = get_model()

if button:
    if not text.strip():
        st.warning("Please enter some text to analyze.")
    else:
        sentiment, score = predict_sentiment(text, tokenizer, model)
        message = f"The sentiment is {sentiment} with a score of {score * 100:.2f}%!"
        if sentiment == "Positive":
            st.success(message)
        elif sentiment == "Negative":
            st.error(message)
        else:
            st.warning(message)