# Hugging Face Spaces status: Runtime error
import os | |
import torch | |
import gradio as gr | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForSequenceClassification, | |
TextClassificationPipeline | |
) | |
# === Config ===
# Hub repository of the fine-tuned dialect classifier.
MODEL_ID = "Omartificial-Intelligence-Space/SA-BERT-Classifier"

# Optional Hub access token read from the environment (None when unset).
HF_TOKEN = os.environ.get("HUGGINGFACE_HUB_TOKEN")

# Device index handed to the pipeline: 0 means the CUDA device,
# -1 means the CPU.
if torch.cuda.is_available():
    DEVICE = 0
else:
    DEVICE = -1
# === Load model and tokenizer ===
# `token=` is the current authentication argument of `from_pretrained`;
# the older `use_auth_token=` spelling is deprecated in transformers.
# HF_TOKEN may be None when no token was provided in the environment.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID, token=HF_TOKEN
).to("cuda" if DEVICE == 0 else "cpu")  # match the device chosen in the config
# === Build pipeline ===
# Text-classification pipeline over the loaded model/tokenizer.
# top_k=None requests a score for every label; it supersedes the deprecated
# return_all_scores option.
pipeline = TextClassificationPipeline(
    tokenizer=tokenizer,
    model=model,
    top_k=None,
    device=DEVICE,
)
# === Inference function ===
def classify_dialect(text):
    """Classify *text* as Saudi or non-Saudi Arabic dialect.

    Args:
        text: Sentence entered in the Gradio textbox.

    Returns:
        A tuple ``(p_saudi, p_non_saudi, prediction)``. The first two items are
        the pipeline scores for class id 1 (Saudi) and class id 0 (non-Saudi),
        rounded to 4 decimals; a class missing from the output scores 0.0.
        ``prediction`` is "Saudi Dialect" when the Saudi score is strictly
        higher, otherwise "Non-Saudi Dialect".
    """
    results = pipeline(text)
    # Depending on the transformers version (and on whether top_k is passed at
    # call time), a single string yields either [[{label, score}, ...]] or a
    # flat [{label, score}, ...]; accept both shapes.
    if results and isinstance(results[0], list):
        results = results[0]

    # The pipeline names labels via model.config.id2label, so invert that
    # mapping instead of parsing "LABEL_<n>" strings; this also works when the
    # model config uses descriptive label names.
    label_to_id = {
        name: idx for idx, name in pipeline.model.config.id2label.items()
    }
    scores = {label_to_id[item["label"]]: item["score"] for item in results}

    p_non_saudi = scores.get(0, 0.0)
    p_saudi = scores.get(1, 0.0)
    prediction = "Saudi Dialect" if p_saudi > p_non_saudi else "Non-Saudi Dialect"
    return round(p_saudi, 4), round(p_non_saudi, 4), prediction
# === Gradio Interface ===
# One text input; outputs map positionally to classify_dialect's
# (p_saudi, p_non_saudi, prediction) return tuple.
# The UI strings below were mis-decoded (UTF-8 read as a Thai code page) and
# have been restored. NOTE(review): two decorative emoji in the description
# could not be recovered from the garbled bytes and were dropped.
demo = gr.Interface(
    fn=classify_dialect,
    # Placeholder: "Write a sentence in Arabic dialect here..."
    inputs=gr.Textbox(lines=2, placeholder="اكتب جملة باللهجة العربية هنا..."),
    outputs=[
        gr.Label(label="Saudi Dialect (Probability)"),
        gr.Label(label="Non-Saudi Dialect (Probability)"),
        gr.Textbox(label="Final Prediction"),
    ],
    title="🗣️ Saudi Dialect Classifier",
    # Arabic line: "A BERT model for classifying sentences as Saudi or non-Saudi dialect."
    description=(
        "نموذج BERT لتصنيف الجمل إلى لهجة سعودية أو غير سعودية."
        "\n\n👩‍💻 Deployed by **Ayesha Shafique** "
        "[LinkedIn](https://www.linkedin.com/in/aieeshashafique/)"
        "\n\nModel credit: "
        "[Omartificial-Intelligence-Space](https://huggingface.co/Omartificial-Intelligence-Space)"
    ),
    # NOTE(review): newer Gradio releases rename this option to
    # `flagging_mode`; confirm against the Gradio version the Space installs.
    allow_flagging="never",
)

# === Launch App ===
# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()