Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,7 +9,6 @@ from huggingface_hub import hf_hub_download
|
|
| 9 |
import torch
|
| 10 |
import pickle
|
| 11 |
import numpy as np
|
| 12 |
-
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 13 |
|
| 14 |
# Load models and tokenizers
|
| 15 |
models = {
|
|
@@ -76,13 +75,12 @@ def predict_with_bert_multilingual(text):
|
|
| 76 |
def predict_with_tinybert(text):
|
| 77 |
tokenizer = models["TinyBERT"]["tokenizer"]
|
| 78 |
model = models["TinyBERT"]["model"]
|
| 79 |
-
encodings = tokenizer([text], padding=True, truncation=True, max_length=
|
| 80 |
with torch.no_grad():
|
| 81 |
outputs = model(**encodings)
|
| 82 |
logits = outputs.logits
|
| 83 |
predictions = logits.argmax(axis=-1).cpu().numpy()
|
| 84 |
-
return int(predictions[0])
|
| 85 |
-
|
| 86 |
|
| 87 |
# Unified function for sentiment analysis and statistics
|
| 88 |
def analyze_sentiment_and_statistics(text):
|
|
@@ -95,22 +93,28 @@ def analyze_sentiment_and_statistics(text):
|
|
| 95 |
|
| 96 |
# Calculate statistics
|
| 97 |
scores = list(results.values())
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
return results, statistics
|
| 108 |
|
| 109 |
# Gradio Interface
|
| 110 |
with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding: 20px; }") as demo:
|
| 111 |
gr.Markdown("# Sentiment Analysis App")
|
| 112 |
gr.Markdown(
|
| 113 |
-
"This app predicts the sentiment of the input text on a scale from 1 to 5 using multiple models and provides
|
| 114 |
)
|
| 115 |
|
| 116 |
with gr.Row():
|
|
@@ -150,7 +154,7 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
|
|
| 150 |
with gr.Column():
|
| 151 |
distilbert_output = gr.Textbox(label="Predicted Sentiment (DistilBERT)", interactive=False)
|
| 152 |
log_reg_output = gr.Textbox(label="Predicted Sentiment (Logistic Regression)", interactive=False)
|
| 153 |
-
bert_output = gr.Textbox(label="Predicted Sentiment (BERT Multilingual)", interactive=False)
|
| 154 |
tinybert_output = gr.Textbox(label="Predicted Sentiment (TinyBERT)", interactive=False)
|
| 155 |
|
| 156 |
with gr.Column():
|
|
@@ -159,13 +163,22 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
|
|
| 159 |
# Button to analyze sentiment and show statistics
|
| 160 |
def process_input_and_analyze(text_input):
|
| 161 |
results, statistics = analyze_sentiment_and_statistics(text_input)
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
analyze_button.click(
|
| 171 |
process_input_and_analyze,
|
|
@@ -173,7 +186,5 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
|
|
| 173 |
outputs=[distilbert_output, log_reg_output, bert_output, tinybert_output, statistics_output]
|
| 174 |
)
|
| 175 |
|
| 176 |
-
|
| 177 |
-
|
| 178 |
# Launch the app
|
| 179 |
demo.launch()
|
|
|
|
| 9 |
import torch
|
| 10 |
import pickle
|
| 11 |
import numpy as np
|
|
|
|
| 12 |
|
| 13 |
# Load models and tokenizers
|
| 14 |
models = {
|
|
|
|
| 75 |
def predict_with_tinybert(text):
    """Return the TinyBERT sentiment prediction for ``text`` as a plain int.

    The tokenizer and model are looked up under ``models["TinyBERT"]``. The
    text is tokenized as a one-element batch (padded, truncated to 512
    tokens, returned as PyTorch tensors) and moved to the module-level
    ``device`` before the forward pass.

    Args:
        text: The input string to classify.

    Returns:
        The argmax class index of the logits for the single input, plus one.
    """
    entry = models["TinyBERT"]
    tokenize = entry["tokenizer"]
    classifier = entry["model"]

    batch = tokenize(
        [text],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    batch = batch.to(device)

    # Inference only: no gradient tracking is needed for a forward pass.
    with torch.no_grad():
        logits = classifier(**batch).logits

    class_indices = logits.argmax(axis=-1).cpu().numpy()
    # NOTE(review): the +1 presumably shifts 0-based class indices onto the
    # app's 1-to-5 scale -- confirm against the model's label mapping.
    return int(class_indices[0] + 1)
|
|
|
|
| 84 |
|
| 85 |
# Unified function for sentiment analysis and statistics
|
| 86 |
def analyze_sentiment_and_statistics(text):
|
|
|
|
| 93 |
|
| 94 |
# Calculate statistics
|
| 95 |
scores = list(results.values())
|
| 96 |
+
if all(score == scores[0] for score in scores): # Check if all predictions are the same
|
| 97 |
+
statistics = {
|
| 98 |
+
"Message": "All models predict the same score.",
|
| 99 |
+
"Average Score": f"{scores[0]:.2f}",
|
| 100 |
+
}
|
| 101 |
+
else:
|
| 102 |
+
min_score_model = min(results, key=results.get)
|
| 103 |
+
max_score_model = max(results, key=results.get)
|
| 104 |
+
average_score = np.mean(scores)
|
| 105 |
+
|
| 106 |
+
statistics = {
|
| 107 |
+
"Lowest Score": f"{results[min_score_model]} (Model: {min_score_model})",
|
| 108 |
+
"Highest Score": f"{results[max_score_model]} (Model: {max_score_model})",
|
| 109 |
+
"Average Score": f"{average_score:.2f}",
|
| 110 |
+
}
|
| 111 |
return results, statistics
|
| 112 |
|
| 113 |
# Gradio Interface
|
| 114 |
with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding: 20px; }") as demo:
|
| 115 |
gr.Markdown("# Sentiment Analysis App")
|
| 116 |
gr.Markdown(
|
| 117 |
+
"This app predicts the sentiment of the input text on a scale from 1 to 5 using multiple models and provides basic statistics."
|
| 118 |
)
|
| 119 |
|
| 120 |
with gr.Row():
|
|
|
|
| 154 |
with gr.Column():
|
| 155 |
distilbert_output = gr.Textbox(label="Predicted Sentiment (DistilBERT)", interactive=False)
|
| 156 |
log_reg_output = gr.Textbox(label="Predicted Sentiment (Logistic Regression)", interactive=False)
|
| 157 |
+
bert_output = gr.Textbox(label="Predicted Sentiment (BERT Multilingual)", interactive=False)
|
| 158 |
tinybert_output = gr.Textbox(label="Predicted Sentiment (TinyBERT)", interactive=False)
|
| 159 |
|
| 160 |
with gr.Column():
|
|
|
|
| 163 |
# Button to analyze sentiment and show statistics
|
| 164 |
def process_input_and_analyze(text_input):
    """Run the analysis for ``text_input`` and format it for the output widgets.

    Calls ``analyze_sentiment_and_statistics`` and turns its ``results`` and
    ``statistics`` dictionaries into five strings.

    Args:
        text_input: The text to analyze.

    Returns:
        A 5-tuple of strings: the predictions for DistilBERT, Logistic
        Regression, BERT Multilingual (NLP Town) and TinyBERT, in that
        order, followed by a multi-line statistics summary.
    """
    results, statistics = analyze_sentiment_and_statistics(text_input)

    # A "Message" key marks the case where every model produced the same
    # score; otherwise lowest/highest entries are present instead.
    all_models_agree = "Message" in statistics

    model_names = (
        "DistilBERT",
        "Logistic Regression",
        "BERT Multilingual (NLP Town)",
        "TinyBERT",
    )
    per_model_text = tuple(f"{results[name]}" for name in model_names)

    if all_models_agree:
        summary = f"Statistics:\n{statistics['Message']}\nAverage Score: {statistics['Average Score']}"
    else:
        summary = f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\nAverage Score: {statistics['Average Score']}"

    return (*per_model_text, summary)
|
| 182 |
|
| 183 |
analyze_button.click(
|
| 184 |
process_input_and_analyze,
|
|
|
|
| 186 |
outputs=[distilbert_output, log_reg_output, bert_output, tinybert_output, statistics_output]
|
| 187 |
)
|
| 188 |
|
|
|
|
|
|
|
| 189 |
# Launch the app
|
| 190 |
demo.launch()
|