from flask import Flask, request, jsonify, render_template
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import re
from flask_cors import CORS  # Enable CORS

# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Allow requests from frontend apps

# Choose your model: 'bert-base-uncased' or 'GroNLP/hateBERT'
MODEL_NAME = 'bert-base-uncased'  # Change to 'GroNLP/hateBERT' if needed

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()
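# Note: loading a base checkpoint with AutoModelForSequenceClassification
# attaches a fresh, randomly initialized classification head (transformers
# warns about this at load time), so the labels below are only meaningful
# once the model has been fine-tuned for this two-class task.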
# Two-class labels only
LABELS = ['Safe', 'Cyberbullying']

# Offensive trigger words
TRIGGER_WORDS = [
    "gago", "pokpok", "yawa", "linte", "ulol", "tangina", "bilat", "putang", "tarantado", "bobo",
    "yudipota", "law-ay", "bilatibay", "hayop"
]

# Detect trigger words in input text
def find_triggers(text):
    found = []
    for word in TRIGGER_WORDS:
        if re.search(rf"\b{re.escape(word)}\b", text, re.IGNORECASE):
            found.append(word)
    return found
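# Illustrative example: find_triggers("Ikaw gago ka ba?") -> ["gago"].
# Matching is case-insensitive and anchored on word boundaries, so the
# same letters inside a longer word such as "gagong" would not match.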
# Predict function
def predict_text(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = F.softmax(logits, dim=1)
        confidence, predicted_class = torch.max(probs, dim=1)

    # Fallback: if the model head has more than two classes, an
    # out-of-range index defaults to "Safe"
    label_index = predicted_class.item()
    if label_index >= len(LABELS):
        label_index = 0  # default to "Safe"
    label = LABELS[label_index]
    confidence_score = round(confidence.item(), 4)
    triggers = find_triggers(text)

    # Override model prediction if offensive triggers found
    if triggers and label == "Safe":
        label = "Cyberbullying"

    return {
        "label": label,
        "confidence": confidence_score,
        "triggers": triggers
    }
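# Illustrative call: predict_text("you are so bobo") returns something like
#   {"label": "Cyberbullying", "confidence": 0.53, "triggers": ["bobo"]}
# Since "bobo" is in TRIGGER_WORDS, the label is forced to "Cyberbullying"
# even when the model itself predicted "Safe".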
# Serve frontend
@app.route('/')
def index():
    return render_template('index.html')  # Ensure templates/index.html exists
# API endpoint
@app.route('/predict', methods=['POST'])
def predict_api():
    try:
        data = request.get_json()
        text = data.get("text", "")
        if not text.strip():
            return jsonify({"error": "No text provided"}), 400
        result = predict_text(text)
        return jsonify(result)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# Run server
if __name__ == "__main__":
    app.run(debug=True)
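# Example client (illustrative; assumes the server is running on Flask's
# default http://127.0.0.1:5000 and the /predict route defined above):
#
#   import requests
#   resp = requests.post("http://127.0.0.1:5000/predict", json={"text": "ikaw gago"})
#   print(resp.json())  # e.g. {"label": "Cyberbullying", "confidence": ..., "triggers": ["gago"]}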