nhull committed
Commit cc8348d · verified · 1 Parent(s): c82d132

Create app.py

Files changed (1)
  1. app.py +158 -0
app.py ADDED
@@ -0,0 +1,158 @@
+import gradio as gr
+from transformers import (
+    DistilBertTokenizerFast,
+    DistilBertForSequenceClassification,
+    AutoTokenizer,
+    AutoModelForSequenceClassification,
+)
+from huggingface_hub import hf_hub_download
+import torch
+import pickle
+import numpy as np
+
+# Load models and tokenizers
+models = {
+    "DistilBERT": {
+        "tokenizer": DistilBertTokenizerFast.from_pretrained("nhull/distilbert-sentiment-model"),
+        "model": DistilBertForSequenceClassification.from_pretrained("nhull/distilbert-sentiment-model"),
+    },
+    "Logistic Regression": {},  # Placeholder for logistic regression
+    "BERT Multilingual (NLP Town)": {
+        "tokenizer": AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment"),
+        "model": AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment"),
+    }
+}
+
+# Load logistic regression model and vectorizer
+logistic_regression_repo = "nhull/logistic-regression-model"
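+# NOTE: unpickling these artifacts requires a compatible scikit-learn version at runtime
+# (the model and TF-IDF vectorizer are assumed to have been trained and pickled with scikit-learn).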
+
+# Download and load logistic regression model
+log_reg_model_path = hf_hub_download(repo_id=logistic_regression_repo, filename="logistic_regression_model.pkl")
+with open(log_reg_model_path, "rb") as model_file:
+    log_reg_model = pickle.load(model_file)
+
+# Download and load TF-IDF vectorizer
+vectorizer_path = hf_hub_download(repo_id=logistic_regression_repo, filename="tfidf_vectorizer.pkl")
+with open(vectorizer_path, "rb") as vectorizer_file:
+    vectorizer = pickle.load(vectorizer_file)
+
+# Move HuggingFace models to device (if GPU is available)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+for model_data in models.values():
+    if "model" in model_data:
+        model_data["model"].to(device)
+
+# Functions for prediction
+def predict_with_distilbert(text):
+    tokenizer = models["DistilBERT"]["tokenizer"]
+    model = models["DistilBERT"]["model"]
+    encodings = tokenizer([text], padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = model(**encodings)
+    logits = outputs.logits
+    predictions = logits.argmax(axis=-1).cpu().numpy()
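+    # Classes are 0-indexed (0-4), so shift by 1 to report a 1-5 rating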
+    return int(predictions[0] + 1)
+
+def predict_with_logistic_regression(text):
+    transformed_text = vectorizer.transform([text])
+    predictions = log_reg_model.predict(transformed_text)
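+    # No +1 offset here; the pickled model is assumed to predict ratings already on the 1-5 scale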
+    return int(predictions[0])
+
+def predict_with_bert_multilingual(text):
+    tokenizer = models["BERT Multilingual (NLP Town)"]["tokenizer"]
+    model = models["BERT Multilingual (NLP Town)"]["model"]
+    encodings = tokenizer([text], padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = model(**encodings)
+    logits = outputs.logits
+    predictions = logits.argmax(axis=-1).cpu().numpy()
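+    # nlptown's model predicts 1-5 stars as classes 0-4, hence the +1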
+    return int(predictions[0] + 1)
+
+# Unified function for sentiment analysis and statistics
+def analyze_sentiment_and_statistics(text):
+    results = {
+        "DistilBERT": predict_with_distilbert(text),
+        "Logistic Regression": predict_with_logistic_regression(text),
+        "BERT Multilingual (NLP Town)": predict_with_bert_multilingual(text),
+    }
+
+    # Calculate statistics
+    scores = list(results.values())
+    min_score_model = min(results, key=results.get)
+    max_score_model = max(results, key=results.get)
+    average_score = np.mean(scores)
+
+    statistics = {
+        "Lowest Score": f"{results[min_score_model]} (Model: {min_score_model})",
+        "Highest Score": f"{results[max_score_model]} (Model: {max_score_model})",
+        "Average Score": f"{average_score:.2f}",
+    }
+    return results, statistics
+
+# Gradio Interface
+with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding: 20px; }") as demo:
+    gr.Markdown("# Sentiment Analysis App")
+    gr.Markdown(
+        "This app predicts the sentiment of the input text on a scale from 1 to 5 using multiple models and provides detailed statistics."
+    )
+
+    with gr.Row():
+        with gr.Column():
+            text_input = gr.Textbox(
+                label="Enter your text here:",
+                lines=3,
+                placeholder="Type your hotel/restaurant review here..."
+            )
+            sample_reviews = [
+                "The hotel was fantastic! Clean rooms and excellent service.",
+                "The food was horrible, and the staff was rude.",
+                "Amazing experience overall. Highly recommend!",
+                "It was okay, not great but not terrible either.",
+                "Terrible! The room was dirty, and the service was non-existent."
+            ]
+            sample_dropdown = gr.Dropdown(
+                choices=sample_reviews,
+                label="Or select a sample review:",
+                interactive=True
+            )
+
+            # Sync dropdown with text input
+            def update_textbox(selected_sample):
+                return selected_sample
+
+            sample_dropdown.change(
+                update_textbox,
+                inputs=[sample_dropdown],
+                outputs=[text_input]
+            )
+
+        with gr.Column():
+            analyze_button = gr.Button("Analyze Sentiment")
+
+    with gr.Row():
+        with gr.Column():
+            distilbert_output = gr.Textbox(label="Predicted Sentiment (DistilBERT)", interactive=False)
+            log_reg_output = gr.Textbox(label="Predicted Sentiment (Logistic Regression)", interactive=False)
+            bert_output = gr.Textbox(label="Predicted Sentiment (BERT Multilingual)", interactive=False)
+
+        with gr.Column():
+            statistics_output = gr.Textbox(label="Statistics (Lowest, Highest, Average)", interactive=False)
+
+    # Button to analyze sentiment and show statistics
+    def process_input_and_analyze(text_input):
+        results, statistics = analyze_sentiment_and_statistics(text_input)
+        return (
+            f"{results['DistilBERT']}",
+            f"{results['Logistic Regression']}",
+            f"{results['BERT Multilingual (NLP Town)']}",
+            f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\nAverage Score: {statistics['Average Score']}"
+        )
+
+    analyze_button.click(
+        process_input_and_analyze,
+        inputs=[text_input],
+        outputs=[distilbert_output, log_reg_output, bert_output, statistics_output]
+    )
+
+# Launch the app
+demo.launch()
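
For a quick check outside the Gradio UI, analyze_sentiment_and_statistics can be called directly. A minimal sketch, assuming the definitions from app.py are available in the current Python session (for example, pasted into a notebook before the blocking demo.launch() call) and that all three model repositories download successfully:

    # Hypothetical smoke test: score one sample review with all three models.
    review = "The hotel was fantastic! Clean rooms and excellent service."
    results, statistics = analyze_sentiment_and_statistics(review)
    print(results)     # per-model ratings on the 1-5 scale
    print(statistics)  # lowest/highest scoring model and the average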