import streamlit as st
from transformers import (
    pipeline,
    AutoModelForTokenClassification,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForMaskedLM
)
import pandas as pd
from spacy import displacy
# -------------------- PAGE CONFIG --------------------
st.set_page_config(
    page_title="PuoBERTa Multi-Task Demo",
    page_icon="🔤",
    layout="wide"
)
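# Note: st.set_page_config must be the first Streamlit call in the script,
# which is why it sits ahead of the header and the model loading below.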
# -------------------- HEADER --------------------
col1, col2, col3 = st.columns([1, 2, 1])
with col2:
    try:
        st.image("logo_transparent_small.png", width=300)
    except Exception:
        st.write("🔤 PuoBERTa")
st.title("PuoBERTa Multi-Task Demo")
st.markdown("""
A comprehensive demo of Setswana language models covering:

- **Mask Filling**: Fill in missing words in sentences
- **POS Tagging**: Identify parts of speech
- **Named Entity Recognition**: Extract entities such as people, places, and organizations
- **News Classification**: Classify news articles by category
""")
st.markdown("---")
# -------------------- SIDEBAR --------------------
st.sidebar.header("Model Information")
st.sidebar.markdown("""
**Authors**: Vukosi Marivate, Moseli Mots'Oehli, Valencia Wagner, Richard Lastrucci, Isheanesu Dzingirai

**Paper**: [PuoBERTa: Training and evaluation of a curated language model for Setswana](https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17) | [Preprint/Arxiv](https://arxiv.org/abs/2310.09141)

**Huggingface Space Creators**: Vukosi Marivate, Zion Van Wyk, Unarine Netshifhefhe, Thapelo Sindane

**Models Used**:
- [dsfsi/PuoBERTa (Mask Filling - Pretrained Model)](https://huggingface.co/dsfsi/PuoBERTa)
- [dsfsi/PuoBERTa-POS (POS Tagging)](https://huggingface.co/dsfsi/PuoBERTa-POS)
- [dsfsi/PuoBERTa-NER (Named Entity Recognition)](https://huggingface.co/dsfsi/PuoBERTa-NER)
- [dsfsi/PuoBERTa-News (News Classification)](https://huggingface.co/dsfsi/PuoBERTa-News)
""")
# -------------------- CACHING FUNCTIONS --------------------
@st.cache_resource
def load_mask_filling_model():
    try:
        tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa")
        model = AutoModelForMaskedLM.from_pretrained("dsfsi/PuoBERTa")
        # Create the pipeline; top_k=5 returns the five best candidates per mask
        pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, top_k=5)
        # Debug: print mask token for verification
        print(f"Mask token being used: {tokenizer.mask_token}")
        return pipe
    except Exception as e:
        st.error(f"Failed to load mask filling model: {str(e)}")
        return None
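
# Note: for an input containing one mask, the fill-mask pipeline returns a list
# of candidate dicts with "score", "token", "token_str" and "sequence" keys; for
# several masks it returns one such list per mask (handled in the tab below).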
@st.cache_resource
def load_pos_model():
    tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-POS")
    model = AutoModelForTokenClassification.from_pretrained("dsfsi/PuoBERTa-POS")
    return pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple")

@st.cache_resource
def load_ner_model():
    tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-NER")
    model = AutoModelForTokenClassification.from_pretrained("dsfsi/PuoBERTa-NER")
    return pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple")

@st.cache_resource
def load_news_classification_model():
    tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-News")
    model = AutoModelForSequenceClassification.from_pretrained("dsfsi/PuoBERTa-News")
    # top_k=None scores every label (replaces the deprecated return_all_scores=True)
    return pipeline("text-classification", model=model, tokenizer=tokenizer, top_k=None)
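
# Note: for a single string the classifier returns its all-label scores nested
# in a list, so the news tab reads results[0]; if your transformers version
# returns a flat list of dicts instead, iterate over results directly.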
# -------------------- UTILITY FUNCTIONS --------------------
def get_correct_mask_token(text, tokenizer):
    """Normalise common mask token spellings to the tokenizer's own mask token."""
    mask_token = tokenizer.mask_token
    # Replace common mask token formats with the correct one
    text = text.replace("[MASK]", mask_token)
    text = text.replace("<mask>", mask_token)
    return text
# Usage (see the mask filling tab below):
#   corrected_input = get_correct_mask_token(mask_input, mask_filler.tokenizer)
#   results = mask_filler(corrected_input)
def merge_entities(output):
    """Merge consecutive entities of the same type into a single span."""
    merged = []
    for i, ent in enumerate(output):
        if i > 0 and ent["start"] == output[i-1]["end"] and ent["entity_group"] == output[i-1]["entity_group"]:
            merged[-1]["word"] += ent["word"]
            merged[-1]["end"] = ent["end"]
        else:
            merged.append(ent)
    return merged
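
# Illustrative example (hypothetical offsets): two touching spans of the same group,
#   {"word": "Gabo", "start": 10, "end": 14, "entity_group": "LOC"} and
#   {"word": "rone", "start": 14, "end": 18, "entity_group": "LOC"},
# come back as one entity {"word": "Gaborone", "start": 10, "end": 18, "entity_group": "LOC"}.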
def create_spacy_display(text, entities, task_type="ner"):
    """Create a spaCy displaCy-style visualisation for the entities."""
    spacy_display = {"text": text, "ents": [], "title": None}
    for ent in entities:
        label = ent["entity_group"]
        if task_type == "ner" and label == "PER":
            label = "PERSON"
        spacy_display["ents"].append({
            "start": ent["start"],
            "end": ent["end"],
            "label": label
        })
    # Define colors for different entity types
    colors = {
        # POS colors
        "PRON": "#FF9999",
        "VERB": "#99FF99",
        "DET": "#9999FF",
        "PROPN": "#FFFF99",
        "CCONJ": "#FFCC99",
        "PUNCT": "#CCCCCC",
        "NUM": "#FFCCFF",
        "NOUN": "#FFB366",
        "ADJ": "#B366FF",
        "ADP": "#66FFB3",
        # NER colors
        "PERSON": "#85DCDF",
        "PER": "#85DCDF",
        "LOC": "#DF85DC",
        "ORG": "#DCDF85",
        "MISC": "#85ABDF"
    }
    try:
        html = displacy.render(spacy_display, style="ent", manual=True, minify=True,
                               options={"colors": colors})
        styled_html = f"""
        <style>mark.entity {{ display: inline-block; }}</style>
        <div style='overflow-x:auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;'>
        {html}
        </div>
        """
        return styled_html
    except Exception:
        return "<p>Error rendering visualization</p>"
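
# Note: manual=True lets displaCy render a pre-built {"text", "ents"} dict
# directly instead of a spaCy Doc, so no Setswana spaCy model is required.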
def get_input_text(tab_name, examples):
    """Get input text based on the selected input method."""
    input_method = st.radio(
        "Select Input Method",
        ['Example Text', 'Write Text', 'Upload File'],
        key=f"{tab_name}_input_method"
    )
    if input_method == 'Example Text':
        return st.selectbox("Example Sentences", examples, key=f"{tab_name}_examples")
    elif input_method == 'Write Text':
        return st.text_area("Enter text", height=100, key=f"{tab_name}_text_input")
    elif input_method == 'Upload File':
        uploaded = st.file_uploader("Upload text or CSV file", type=["txt", "csv"], key=f"{tab_name}_file")
        if uploaded:
            if uploaded.name.endswith('.csv'):
                df = pd.read_csv(uploaded)
                st.write("CSV Preview:", df.head())
                col = st.selectbox("Choose column with text", df.columns, key=f"{tab_name}_csv_col")
                return "\n".join(df[col].dropna().astype(str).tolist())
            else:
                return str(uploaded.read(), "utf-8")
    return ""
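
# Note: the helper falls back to an empty string, so every caller gets a str
# and can safely call .strip() on the result whichever input method is chosen.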
# -------------------- TABS --------------------
tab1, tab2, tab3, tab4 = st.tabs(["🎭 Mask Filling", "🏷️ POS Tagging", "🔍 Named Entity Recognition", "📰 News Classification"])
# -------------------- MASK FILLING TAB --------------------
with tab1:
    st.header("Mask Filling")
    st.write("Fill in the blanks in Setswana sentences using the `<mask>` token.")
    mask_examples = [
        "Ke rata go <mask> dijo tsa Batswana.",
        "Botswana ke naga e e <mask> mo Afrika Borwa.",
        "Bana ba <mask> sekolo ka Mosupologo.",
        "Re tshwanetse go <mask> tikologo ya rona."
    ]
    mask_input = get_input_text("mask", mask_examples)
    if st.button("Fill Masks", key="mask_button") and mask_input.strip():
        # Accept the [MASK] spelling as well: convert it first, then validate
        if "[MASK]" in mask_input:
            mask_input = mask_input.replace("[MASK]", "<mask>")
            st.info("Converted [MASK] to <mask> format")
        if "<mask>" not in mask_input:
            st.warning("Please include the <mask> token in your text.")
        else:
            with st.spinner("Filling masks..."):
                try:
                    mask_filler = load_mask_filling_model()
                    corrected_input = get_correct_mask_token(mask_input, mask_filler.tokenizer)
                    results = mask_filler(corrected_input)
                    st.subheader("Predictions")
                    # One mask -> a list of dicts; several masks -> one list per mask
                    per_mask = results if results and isinstance(results[0], list) else [results]
                    for m, candidates in enumerate(per_mask, 1):
                        if len(per_mask) > 1:
                            st.write(f"**Mask {m}:**")
                        for i, result in enumerate(candidates, 1):
                            confidence = result['score'] * 100
                            st.write(f"**{i}.** {result['sequence']} (confidence: {confidence:.1f}%)")
                except Exception as e:
                    st.error(f"Error: {str(e)}")
                    # Debug information
                    st.info(f"Input text: {mask_input}")
                    try:
                        mask_filler = load_mask_filling_model()
                        st.info(f"Model mask token: {mask_filler.tokenizer.mask_token}")
                    except Exception:
                        pass
# -------------------- POS TAGGING TAB --------------------
with tab2:
    st.header("Parts of Speech Tagging")
    st.write("Identify grammatical parts of speech in Setswana text.")
    pos_examples = [
        "Moso ono mo dikgang tsa ura le ura, o tsoga le Oarabile Moamogwe go simolola ka 05:00 - 10:00",
        "Batho ba le bantsi ba rata go bala dikgang tsa Setswana.",
        "Ke ithutile Setswana kwa sekolong sa me.",
        "Dikgomo di ja bojang mo tshimong."
    ]
    pos_input = get_input_text("pos", pos_examples)
    if st.button("Run POS Tagging", key="pos_button") and pos_input.strip():
        with st.spinner("Running POS tagging..."):
            try:
                pos_tagger = load_pos_model()
                output = pos_tagger(pos_input)
                entities = merge_entities(output)
                if entities:
                    # Display results table
                    df = pd.DataFrame(entities)[['word', 'entity_group', 'score', 'start', 'end']]
                    df['score'] = df['score'].round(4)
                    st.subheader("POS Tags")
                    st.dataframe(df, use_container_width=True)
                    # Visual display
                    st.subheader("Visual Display")
                    html = create_spacy_display(pos_input, entities, "pos")
                    st.markdown(html, unsafe_allow_html=True)
                else:
                    st.info("No POS tags identified.")
            except Exception as e:
                st.error(f"Error: {str(e)}")
# -------------------- NER TAB --------------------
with tab3:
    st.header("Named Entity Recognition")
    st.write("Extract named entities like people, places, and organizations from Setswana text.")
    ner_examples = [
        "Oarabile Moamogwe o tswa Gaborone mme o bereka kwa University of Botswana.",
        "Motswana yo o tumileng Mpho Balopi o ne a kopana le Rre Khama kwa Presidential Palace.",
        "Botswana Democratic Party e ne ya kopana le African National Congress.",
        "Bank of Botswana e mo Gaborone e laola economy ya naga."
    ]
    ner_input = get_input_text("ner", ner_examples)
    if st.button("Run NER", key="ner_button") and ner_input.strip():
        with st.spinner("Running NER..."):
            try:
                ner_pipeline = load_ner_model()
                output = ner_pipeline(ner_input)
                entities = merge_entities(output)
                if entities:
                    # Display results table
                    df = pd.DataFrame(entities)[['word', 'entity_group', 'score', 'start', 'end']]
                    df['score'] = df['score'].round(4)
                    st.subheader("Named Entities")
                    st.dataframe(df, use_container_width=True)
                    # Visual display
                    st.subheader("Visual Display")
                    html = create_spacy_display(ner_input, entities, "ner")
                    st.markdown(html, unsafe_allow_html=True)
                else:
                    st.info("No named entities found.")
            except Exception as e:
                st.error(f"Error: {str(e)}")
# -------------------- NEWS CLASSIFICATION TAB --------------------
with tab4:
    st.header("News Classification")
    st.write("Classify Setswana news articles into different categories.")
    # Category mapping: model labels (English) to Setswana display names
    categories = {
        "arts_culture_entertainment_and_media": "Botsweretshi, setso, boitapoloso le bobegakgang",
        "crime_law_and_justice": "Bosenyi, molao le bosiamisi",
        "disaster_accident_and_emergency_incident": "Masetlapelo, kotsi le tiragalo ya maemo a tshoganyetso",
        "economy_business_and_finance": "Ikonomi, tsa kgwebo le tsa ditšhelete",
        "education": "Thuto",
        "environment": "Tikologo",
        "health": "Boitekanelo",
        "politics": "Dipolotiki",
        "religion_and_belief": "Bodumedi le tumelo",
        "society": "Setšhaba"
    }
    news_examples = [
        "Puso ya Botswana e solofeditse gore e tla oketsa dithuso tsa thuto mo dikolong tsa poraemari.",
        "Dipalo tsa bosenyi di oketsegile mo torong ya Gaborone ka pakeng tse di fetileng.",
        "Setšhaba sa Botswana se keteka matsalo a Rre le Mme ba ba ratanang thata.",
        "Boemelo jwa economy ya Botswana bo tsweletse sentle ka ngwaga ono."
    ]
    news_input = get_input_text("news", news_examples)
    if st.button("Classify News", key="news_button") and news_input.strip():
        with st.spinner("Classifying news..."):
            try:
                classifier = load_news_classification_model()
                results = classifier(news_input)
                # Process results
                predictions = {}
                for pred in results[0]:
                    category_en = pred['label']
                    category_tn = categories.get(category_en, category_en)
                    predictions[category_tn] = round(pred['score'], 4)
                # Sort by confidence
                sorted_predictions = dict(sorted(predictions.items(), key=lambda x: x[1], reverse=True))
                st.subheader("Classification Results")
                # Display the top five categories as progress bars
                for category, confidence in list(sorted_predictions.items())[:5]:
                    st.write(f"**{category}**")
                    st.progress(confidence)
                    st.write(f"Confidence: {confidence:.1%}")
                    st.write("")
                # Display full results table
                with st.expander("View All Categories"):
                    results_df = pd.DataFrame([
                        {"Category": cat, "Confidence": conf}
                        for cat, conf in sorted_predictions.items()
                    ])
                    st.dataframe(results_df, use_container_width=True)
            except Exception as e:
                st.error(f"Error: {str(e)}")
# -------------------- FOOTER --------------------
st.markdown("---")
st.markdown("""
### 📚 Citation

```bibtex
@inproceedings{marivate2023puoberta,
    title = {PuoBERTa: Training and evaluation of a curated language model for Setswana},
    author = {Vukosi Marivate and Moseli Mots'Oehli and Valencia Wagner and Richard Lastrucci and Isheanesu Dzingirai},
    year = {2023},
    booktitle = {Artificial Intelligence Research. SACAIR 2023. Communications in Computer and Information Science},
    url = {https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17},
    keywords = {NLP},
    preprint_url = {https://arxiv.org/abs/2310.09141},
    dataset_url = {https://github.com/dsfsi/PuoBERTa},
    software_url = {https://huggingface.co/dsfsi/PuoBERTa}
}
```

**Links**: [Paper](https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17) | [Preprint/Arxiv](https://arxiv.org/abs/2310.09141) | [GitHub](https://github.com/dsfsi/PuoBERTa) | [HuggingFace](https://huggingface.co/dsfsi/PuoBERTa)
""")