nhull commited on
Commit
457d325
·
verified ·
1 Parent(s): e4627b7

Add Roberta model

Browse files
Files changed (1) hide show
  1. app.py +26 -1
app.py CHANGED
@@ -24,6 +24,10 @@ models = {
24
  "TinyBERT": {
25
  "tokenizer": AutoTokenizer.from_pretrained("elo4/TinyBERT-sentiment-model"),
26
  "model": AutoModelForSequenceClassification.from_pretrained("elo4/TinyBERT-sentiment-model"),
 
 
 
 
27
  }
28
  }
29
 
@@ -82,6 +86,16 @@ def predict_with_tinybert(text):
82
  predictions = logits.argmax(axis=-1).cpu().numpy()
83
  return int(predictions[0] + 1)
84
 
 
 
 
 
 
 
 
 
 
 
85
  # Unified function for sentiment analysis and statistics
86
  def analyze_sentiment_and_statistics(text):
87
  results = {
@@ -89,6 +103,7 @@ def analyze_sentiment_and_statistics(text):
89
  "Logistic Regression": predict_with_logistic_regression(text),
90
  "BERT Multilingual (NLP Town)": predict_with_bert_multilingual(text),
91
  "TinyBERT": predict_with_tinybert(text),
 
92
  }
93
 
94
  # Calculate statistics
@@ -156,6 +171,7 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
156
  log_reg_output = gr.Textbox(label="Predicted Sentiment (Logistic Regression)", interactive=False)
157
  bert_output = gr.Textbox(label="Predicted Sentiment (BERT Multilingual)", interactive=False)
158
  tinybert_output = gr.Textbox(label="Predicted Sentiment (TinyBERT)", interactive=False)
 
159
 
160
  with gr.Column():
161
  statistics_output = gr.Textbox(label="Statistics (Lowest, Highest, Average)", interactive=False)
@@ -169,6 +185,7 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
169
  f"{results['Logistic Regression']}",
170
  f"{results['BERT Multilingual (NLP Town)']}",
171
  f"{results['TinyBERT']}",
 
172
  f"Statistics:\n{statistics['Message']}\nAverage Score: {statistics['Average Score']}"
173
  )
174
  else: # Min and Max scores are present
@@ -177,13 +194,21 @@ with gr.Blocks(css=".gradio-container { max-width: 900px; margin: auto; padding:
177
  f"{results['Logistic Regression']}",
178
  f"{results['BERT Multilingual (NLP Town)']}",
179
  f"{results['TinyBERT']}",
 
180
  f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\nAverage Score: {statistics['Average Score']}"
181
  )
182
 
183
  analyze_button.click(
184
  process_input_and_analyze,
185
  inputs=[text_input],
186
- outputs=[distilbert_output, log_reg_output, bert_output, tinybert_output, statistics_output]
 
 
 
 
 
 
 
187
  )
188
 
189
  # Launch the app
 
24
  "TinyBERT": {
25
  "tokenizer": AutoTokenizer.from_pretrained("elo4/TinyBERT-sentiment-model"),
26
  "model": AutoModelForSequenceClassification.from_pretrained("elo4/TinyBERT-sentiment-model"),
27
+ },
28
+ "RoBERTa": {
29
+ "tokenizer": AutoTokenizer.from_pretrained("ordek899/roberta_1to5rating_pred_for_restaur_trained_on_hotels"),
30
+ "model": AutoModelForSequenceClassification.from_pretrained("ordek899/roberta_1to5rating_pred_for_restaur_trained_on_hotels"),
31
  }
32
  }
33
 
 
86
  predictions = logits.argmax(axis=-1).cpu().numpy()
87
  return int(predictions[0] + 1)
88
 
89
+ def predict_with_roberta_ordek899(text):
90
+ tokenizer = models["RoBERTa"]["tokenizer"]
91
+ model = models["RoBERTa"]["model"]
92
+ encodings = tokenizer([text], padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
93
+ with torch.no_grad():
94
+ outputs = model(**encodings)
95
+ logits = outputs.logits
96
+ predictions = logits.argmax(axis=-1).cpu().numpy()
97
+ return int(predictions[0] + 1)
98
+
99
  # Unified function for sentiment analysis and statistics
100
  def analyze_sentiment_and_statistics(text):
101
  results = {
 
103
  "Logistic Regression": predict_with_logistic_regression(text),
104
  "BERT Multilingual (NLP Town)": predict_with_bert_multilingual(text),
105
  "TinyBERT": predict_with_tinybert(text),
106
+ "RoBERTa": predict_with_roberta_ordek899(text),
107
  }
108
 
109
  # Calculate statistics
 
171
  log_reg_output = gr.Textbox(label="Predicted Sentiment (Logistic Regression)", interactive=False)
172
  bert_output = gr.Textbox(label="Predicted Sentiment (BERT Multilingual)", interactive=False)
173
  tinybert_output = gr.Textbox(label="Predicted Sentiment (TinyBERT)", interactive=False)
174
+ roberta_ordek_output = gr.Textbox(label="Predicted Sentiment (RoBERTa)", interactive=False)
175
 
176
  with gr.Column():
177
  statistics_output = gr.Textbox(label="Statistics (Lowest, Highest, Average)", interactive=False)
 
185
  f"{results['Logistic Regression']}",
186
  f"{results['BERT Multilingual (NLP Town)']}",
187
  f"{results['TinyBERT']}",
188
+ f"{results['RoBERTa']}",
189
  f"Statistics:\n{statistics['Message']}\nAverage Score: {statistics['Average Score']}"
190
  )
191
  else: # Min and Max scores are present
 
194
  f"{results['Logistic Regression']}",
195
  f"{results['BERT Multilingual (NLP Town)']}",
196
  f"{results['TinyBERT']}",
197
+ f"{results['RoBERTa']}",
198
  f"Statistics:\n{statistics['Lowest Score']}\n{statistics['Highest Score']}\nAverage Score: {statistics['Average Score']}"
199
  )
200
 
201
  analyze_button.click(
202
  process_input_and_analyze,
203
  inputs=[text_input],
204
+ outputs=[
205
+ distilbert_output,
206
+ log_reg_output,
207
+ bert_output,
208
+ tinybert_output,
209
+ roberta_ordek_output,
210
+ statistics_output
211
+ ]
212
  )
213
 
214
  # Launch the app