# inference.py for logasanjeev/emotions-analyzer-bert (Hugging Face Hub)
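# Requires the torch, transformers, huggingface_hub, and emoji packages
# (for example: pip install torch transformers huggingface_hub emoji).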
import json
import re

import emoji
import torch
from huggingface_hub import hf_hub_download
from transformers import BertForSequenceClassification, BertTokenizer

def preprocess_text(text):
"""Preprocess the input text to match training conditions."""
text = re.sub(r'u/\w+', '[USER]', text)
text = re.sub(r'r/\w+', '[SUBREDDIT]', text)
text = re.sub(r'http[s]?://\S+', '[URL]', text)
text = emoji.demojize(text, delimiters=(" ", " "))
text = text.lower()
return text
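

# Quick illustration of preprocess_text (note that text.lower() also lowercases
# the placeholder tokens, and emoji.demojize rewrites any emoji to its text name):
#   preprocess_text("Thanks u/helper, see r/python at https://example.com")
#   -> "thanks [user], see [subreddit] at [url]"
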
def load_model_and_resources():
"""Load the model, tokenizer, emotion labels, and thresholds from Hugging Face."""
repo_id = "logasanjeev/emotions-analyzer-bert"
try:
model = BertForSequenceClassification.from_pretrained(repo_id)
tokenizer = BertTokenizer.from_pretrained(repo_id)
    except Exception as e:
        raise RuntimeError(f"Error loading model/tokenizer: {e}") from e
try:
thresholds_file = hf_hub_download(repo_id=repo_id, filename="optimized_thresholds.json")
with open(thresholds_file, "r") as f:
thresholds_data = json.load(f)
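        # Assumed layout of optimized_thresholds.json, inferred from the checks below:
        #   {"emotion_labels": ["admiration", ...], "thresholds": [0.55, ...]}
        # i.e. one tuned threshold per emotion label (threshold values are illustrative).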
if not (isinstance(thresholds_data, dict) and "emotion_labels" in thresholds_data and "thresholds" in thresholds_data):
raise ValueError("Unexpected format in optimized_thresholds.json. Expected a dictionary with keys 'emotion_labels' and 'thresholds'.")
emotion_labels = thresholds_data["emotion_labels"]
thresholds = thresholds_data["thresholds"]
    except Exception as e:
        raise RuntimeError(f"Error loading thresholds: {e}") from e
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
return model, tokenizer, emotion_labels, thresholds, device
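

# The model, tokenizer, labels, thresholds, and device are loaded lazily on the
# first call to predict_emotions() and cached in these module-level globals.
# load_model_and_resources() can also be called directly if the resources are
# needed outside predict_emotions().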
MODEL, TOKENIZER, EMOTION_LABELS, THRESHOLDS, DEVICE = None, None, None, None, None
def predict_emotions(text):
"""Predict emotions for the given text using the GoEmotions BERT model.
Args:
text (str): The input text to analyze.
Returns:
tuple: (predictions, processed_text)
- predictions (str): Formatted string of predicted emotions and their confidence scores.
- processed_text (str): The preprocessed input text.
"""
global MODEL, TOKENIZER, EMOTION_LABELS, THRESHOLDS, DEVICE
if MODEL is None:
MODEL, TOKENIZER, EMOTION_LABELS, THRESHOLDS, DEVICE = load_model_and_resources()
processed_text = preprocess_text(text)
encodings = TOKENIZER(
processed_text,
padding='max_length',
truncation=True,
max_length=128,
return_tensors='pt'
)
input_ids = encodings['input_ids'].to(DEVICE)
attention_mask = encodings['attention_mask'].to(DEVICE)
    with torch.no_grad():
        outputs = MODEL(input_ids, attention_mask=attention_mask)
        # Multi-label setup: apply a sigmoid per emotion and compare each
        # probability against that emotion's tuned threshold.
        probs = torch.sigmoid(outputs.logits).cpu().numpy()[0]
    predictions = []
    for i, (prob, thresh) in enumerate(zip(probs, THRESHOLDS)):
        if prob >= thresh:
            predictions.append((EMOTION_LABELS[i], round(float(prob), 4)))
predictions.sort(key=lambda x: x[1], reverse=True)
result = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in predictions]) or "No emotions predicted."
return result, processed_text
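

# Example programmatic usage (the scores shown are illustrative, not real model output):
#   result, processed = predict_emotions("I'm so happy for you!")
#   print(result)
#   # joy: 0.9532
#   # excitement: 0.6110
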
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Predict emotions using the GoEmotions BERT model.")
parser.add_argument("text", type=str, help="The input text to analyze for emotions.")
args = parser.parse_args()
result, processed = predict_emotions(args.text)
print(f"Input: {args.text}")
print(f"Processed: {processed}")
print("Predicted Emotions:")
print(result)
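
# Example command-line invocation:
#   python inference.py "Thanks so much, this made my day!"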