# Hugging Face Spaces status at time of capture: Runtime error
import os
import sys

import torch
from transformers import BertTokenizer, BertForSequenceClassification
# Command-line classifier: prints "Hate Speech" or "Non-Hate Speech" for the
# text given on the command line, using a fine-tuned BERT sequence classifier.

# Default directory holding the model and tokenizer files. This is an absolute
# path on the author's Windows machine, so on any other host (e.g. a Linux
# deployment) set the MODEL_DIR environment variable instead of editing code.
DEFAULT_MODEL_DIR = "C:/Users/Biruh/amharic_hate_speech_detection/amharic-hate-speech-backend"

# Text classified when no command-line input is provided.
DEFAULT_TEXT = "ምንም ጽሁፍ አልተሰጠም"


def resolve_model_dir(environ=None):
    """Return the model directory to load from.

    Uses the MODEL_DIR environment variable when it is set and non-empty,
    otherwise DEFAULT_MODEL_DIR.

    Args:
        environ: Mapping to read MODEL_DIR from; defaults to ``os.environ``.
    """
    env = os.environ if environ is None else environ
    return env.get("MODEL_DIR") or DEFAULT_MODEL_DIR


def read_input_text(argv):
    """Return the text to classify from a ``sys.argv``-style list.

    Every argument after the program name is joined with single spaces, so an
    unquoted multi-word sentence is classified as a whole instead of only its
    first word. With no arguments, DEFAULT_TEXT is returned.
    """
    words = argv[1:]
    return " ".join(words) if words else DEFAULT_TEXT


def label_for(prediction):
    """Map a predicted class index to its label (1 = hate speech, else not)."""
    return "Hate Speech" if prediction == 1 else "Non-Hate Speech"


def load_model(model_dir):
    """Load and return ``(tokenizer, model)`` from *model_dir*."""
    tokenizer = BertTokenizer.from_pretrained(model_dir)
    model = BertForSequenceClassification.from_pretrained(model_dir)
    # Inference only: make sure dropout is disabled.
    model.eval()
    return tokenizer, model


def predict(text, tokenizer, model):
    """Classify *text* and return its label string."""
    # Truncate to 512 tokens so long inputs do not exceed the encoder's limit.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    # No gradients are needed for inference.
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single input is a batch of one, so argmax yields one class index.
    return label_for(torch.argmax(logits, dim=-1).item())


def main(argv=None):
    """Entry point: classify the command-line text and print the label."""
    text = read_input_text(sys.argv if argv is None else argv)
    tokenizer, model = load_model(resolve_model_dir())
    print(predict(text, tokenizer, model))


if __name__ == "__main__":
    main()