# app.py — Saudi dialect classifier (Gradio Space)
# NOTE(review): the original paste carried Hugging Face web-UI scrape artifacts here
# (author badge, "Update app.py", commit hash, raw/history/blame links, file size);
# they were not Python and have been folded into this comment so the file parses.
import os
import torch
import gradio as gr
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TextClassificationPipeline
)
# === Config ===
MODEL_ID = "Omartificial-Intelligence-Space/SA-BERT-Classifier"
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")  # optional; None falls back to anonymous access
# transformers/torch pipeline convention: 0 = first CUDA device, -1 = CPU.
DEVICE = 0 if torch.cuda.is_available() else -1

# === Load model and tokenizer ===
# `use_auth_token` is deprecated (removed in transformers v5); the
# replacement keyword is `token` and accepts the same value.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID, token=HF_TOKEN
).to("cuda" if DEVICE == 0 else "cpu")

# === Build pipeline ===
pipeline = TextClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    device=DEVICE,
    top_k=None,  # return scores for ALL labels (replaces deprecated return_all_scores)
)
# === Inference function ===
def classify_dialect(text):
    """Run the dialect classifier on *text*.

    Returns a tuple ``(p_saudi, p_non_saudi, prediction)`` where the
    probabilities are rounded to 4 decimals. Pipeline labels look like
    ``LABEL_0`` / ``LABEL_1``; index 1 is taken as the Saudi-dialect class
    and index 0 as non-Saudi.
    """
    label_scores = {}
    for entry in pipeline(text)[0]:
        class_index = int(entry["label"].split("_")[-1])
        label_scores[class_index] = entry["score"]

    p_saudi = label_scores.get(1, 0.0)
    p_non_saudi = label_scores.get(0, 0.0)

    if p_saudi > p_non_saudi:
        prediction = "Saudi Dialect"
    else:
        prediction = "Non-Saudi Dialect"

    return round(p_saudi, 4), round(p_non_saudi, 4), prediction
# === Gradio Interface ===
# Components are built up-front so the Interface call stays readable.
_demo_input = gr.Textbox(lines=2, placeholder="ุงูƒุชุจ ุฌู…ู„ุฉ ุจุงู„ู„ู‡ุฌุฉ ุงู„ุนุฑุจูŠุฉ ู‡ู†ุง...")
_demo_outputs = [
    gr.Label(label="Saudi Dialect (Probability)"),
    gr.Label(label="Non-Saudi Dialect (Probability)"),
    gr.Textbox(label="Final Prediction"),
]
_demo_description = "๐Ÿ” ู†ู…ูˆุฐุฌ BERT ู„ุชุตู†ูŠู ุงู„ุฌู…ู„ ุฅู„ู‰ ู„ู‡ุฌุฉ ุณุนูˆุฏูŠุฉ ุฃูˆ ุบูŠุฑ ุณุนูˆุฏูŠุฉ.\n\n๐Ÿ‘ฉ‍๐Ÿ’ป Deployed by **Ayesha Shafique** [LinkedIn](https://www.linkedin.com/in/aieeshashafique/)\n\n๐ŸŒ Model credit: [Omartificial-Intelligence-Space](https://huggingface.co/Omartificial-Intelligence-Space)"

demo = gr.Interface(
    fn=classify_dialect,
    inputs=_demo_input,
    outputs=_demo_outputs,
    title="๐Ÿ—ฃ๏ธ Saudi Dialect Classifier",
    description=_demo_description,
    allow_flagging="never",  # hide the flagging button entirely
)
# === Launch App ===
# Start the Gradio server only when executed as a script, not on import
# (Spaces and local `python app.py` both hit this path).
if __name__ == "__main__":
    demo.launch()