# Provenance: commit af0d9a9 by natishanau,
# "Initial FastAPI backend for Amharic Fake News Detector".
import os
import sys

import torch
from transformers import BertTokenizer, BertForSequenceClassification
# Command-line hate-speech classifier for Amharic text.
#
# Usage:  python <this script> "<Amharic text>"
# Prints exactly one line to stdout: "Hate Speech" or "Non-Hate Speech".

# Directory holding the fine-tuned model and tokenizer files.
# Set AMHARIC_MODEL_DIR to point elsewhere; the original developer path is kept
# as the fallback so existing setups behave exactly as before.
model_dir = os.environ.get(
    "AMHARIC_MODEL_DIR",
    "C:/Users/Biruh/amharic_hate_speech_detection/amharic-hate-speech-backend",
)

# Load tokenizer and model from the same directory.
tokenizer = BertTokenizer.from_pretrained(model_dir)
model = BertForSequenceClassification.from_pretrained(model_dir)

# Read the text to classify from the command line. All arguments are joined so
# that unquoted (shell-split) input is not silently cut to its first word; a
# single quoted argument is passed through unchanged.
if len(sys.argv) > 1:
    text = " ".join(sys.argv[1:])
else:
    # Placeholder used when no input is supplied (Amharic, roughly
    # "no text was given").
    text = "ምንም ጽሁፍ አልተሰጠም"

# Tokenize into PyTorch tensors. Truncation caps the input at max_length=512
# tokens (presumably the model's positional limit — confirm against config).
inputs = tokenizer(
    text,
    return_tensors="pt",
    truncation=True,
    padding=True,
    max_length=512,
)

# Inference only: disable gradient tracking.
with torch.no_grad():
    outputs = model(**inputs)

# One input text -> one row of logits; argmax over the last dimension gives the
# class id (1 = Hate Speech, 0 = Non-Hate Speech).
logits = outputs.logits
prediction = torch.argmax(logits, dim=-1).item()

# Emit the label on stdout (presumably parsed by the calling backend, so keep
# these strings stable).
print("Hate Speech" if prediction == 1 else "Non-Hate Speech")