|
import torch |
|
import json |
|
from huggingface_hub import hf_hub_download |
|
import re |
|
import emoji |
|
from transformers import BertForSequenceClassification, BertTokenizer |
|
|
|
def preprocess_text(text):
    """Normalize raw input text the same way the training data was prepared.

    Reddit user mentions (``u/name``) and subreddit references (``r/name``)
    are collapsed to placeholder tokens, URLs become ``[URL]``, emojis are
    replaced by their textual names (space-delimited), and the result is
    lowercased.

    Args:
        text (str): Raw input text.

    Returns:
        str: The normalized text.
    """
    # Ordered placeholder substitutions; patterns deliberately mirror the
    # preprocessing used when the model was trained.
    substitutions = (
        (r'u/\w+', '[USER]'),
        (r'r/\w+', '[SUBREDDIT]'),
        (r'http[s]?://\S+', '[URL]'),
    )
    for pattern, token in substitutions:
        text = re.sub(pattern, token, text)
    # Turn emojis into their :name: words, padded with spaces instead of colons.
    text = emoji.demojize(text, delimiters=(" ", " "))
    return text.lower()
|
|
|
def load_model_and_resources():
    """Download the fine-tuned BERT model, tokenizer, labels, and thresholds.

    Everything is fetched from the ``logasanjeev/emotions-analyzer-bert``
    repository on the Hugging Face Hub. The model is moved to the best
    available device and switched to eval mode before being returned.

    Returns:
        tuple: (model, tokenizer, emotion_labels, thresholds, device).

    Raises:
        RuntimeError: If the model/tokenizer cannot be loaded, or if the
            thresholds file cannot be downloaded or has an unexpected format.
    """
    repo_id = "logasanjeev/emotions-analyzer-bert"

    # Model weights and tokenizer come straight from the Hub repo.
    try:
        model = BertForSequenceClassification.from_pretrained(repo_id)
        tokenizer = BertTokenizer.from_pretrained(repo_id)
    except Exception as e:
        raise RuntimeError(f"Error loading model/tokenizer: {str(e)}")

    # Per-label decision thresholds live in a sidecar JSON file in the repo.
    try:
        config_path = hf_hub_download(repo_id=repo_id, filename="optimized_thresholds.json")
        with open(config_path, "r") as fh:
            payload = json.load(fh)
        well_formed = (
            isinstance(payload, dict)
            and "emotion_labels" in payload
            and "thresholds" in payload
        )
        if not well_formed:
            raise ValueError("Unexpected format in optimized_thresholds.json. Expected a dictionary with keys 'emotion_labels' and 'thresholds'.")
        emotion_labels = payload["emotion_labels"]
        thresholds = payload["thresholds"]
    except Exception as e:
        raise RuntimeError(f"Error loading thresholds: {str(e)}")

    # Prefer GPU when available; inference only, so freeze the model.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    return model, tokenizer, emotion_labels, thresholds, device
|
|
|
# Module-level cache for the model and its resources. All start as None and
# are populated lazily by the first call to predict_emotions(), so importing
# this module does not trigger any downloads.
MODEL, TOKENIZER, EMOTION_LABELS, THRESHOLDS, DEVICE = None, None, None, None, None
|
|
|
def predict_emotions(text):
    """Predict emotions for the given text using the GoEmotions BERT model.

    Args:
        text (str): The input text to analyze.

    Returns:
        tuple: (predictions, processed_text)
            - predictions (str): One "emotion: score" line per predicted
              emotion, sorted by descending confidence, or
              "No emotions predicted." when nothing clears its threshold.
            - processed_text (str): The preprocessed input text.
    """
    global MODEL, TOKENIZER, EMOTION_LABELS, THRESHOLDS, DEVICE

    # Lazy one-time initialization of the shared model/tokenizer cache.
    if MODEL is None:
        MODEL, TOKENIZER, EMOTION_LABELS, THRESHOLDS, DEVICE = load_model_and_resources()

    processed_text = preprocess_text(text)

    # Tokenize to a fixed 128-token window, matching the training setup.
    encodings = TOKENIZER(
        processed_text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )

    with torch.no_grad():
        outputs = MODEL(
            encodings['input_ids'].to(DEVICE),
            attention_mask=encodings['attention_mask'].to(DEVICE),
        )
    # Multi-label head: independent sigmoid per label; take the only batch row.
    scores = torch.sigmoid(outputs.logits).cpu().numpy()[0]

    # Keep every emotion whose score clears its own optimized threshold.
    selected = [
        (EMOTION_LABELS[idx], round(score, 4))
        for idx, (score, cutoff) in enumerate(zip(scores, THRESHOLDS))
        if score >= cutoff
    ]
    selected.sort(key=lambda pair: pair[1], reverse=True)

    lines = [f"{emotion}: {confidence:.4f}" for emotion, confidence in selected]
    result = "\n".join(lines) or "No emotions predicted."
    return result, processed_text
|
|
|
if __name__ == "__main__":
    import argparse

    # Minimal CLI: a single positional argument holding the text to classify.
    cli = argparse.ArgumentParser(description="Predict emotions using the GoEmotions BERT model.")
    cli.add_argument("text", type=str, help="The input text to analyze for emotions.")
    options = cli.parse_args()

    predictions, cleaned = predict_emotions(options.text)
    print(f"Input: {options.text}")
    print(f"Processed: {cleaned}")
    print("Predicted Emotions:")
    print(predictions)