# Source: PuoBERTaSpace / app.py
# Commit c4a32ef by vukosi: "Fixed some formatting." (16.4 kB)
import streamlit as st
from transformers import (
pipeline,
AutoModelForTokenClassification,
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModelForMaskedLM
)
import pandas as pd
import spacy
import csv
from io import StringIO
# -------------------- PAGE CONFIG --------------------
# Configure the page title, icon and wide layout for the whole app.
st.set_page_config(
    layout="wide",
    page_icon="🔤",
    page_title="PuoBERTa Multi-Task Demo",
)
# -------------------- HEADER --------------------
# Three columns with a wider middle one; the logo goes in the middle column.
col1, col2, col3 = st.columns([1, 2, 1])
with col2:
    try:
        st.image("logo_transparent_small.png", width=300)
    except Exception:
        # Best-effort logo: if the image cannot be shown, fall back to a
        # text banner. `Exception` (not a bare except) so that
        # KeyboardInterrupt/SystemExit are never swallowed here.
        st.write("🔤 PuoBERTa")

st.title("PuoBERTa Multi-Task Demo")
st.markdown("""
A comprehensive demo for Setswana language models including:
- **Mask Filling**: Fill in missing words in sentences
- **POS Tagging**: Identify parts of speech
- **Named Entity Recognition**: Extract entities like people, places, organizations
- **News Classification**: Classify news articles by category
""")
st.markdown("---")
# -------------------- SIDEBAR --------------------
st.sidebar.header("Model Information")
# Blank lines separate the fields: consecutive Markdown lines without a blank
# line between them are rendered as one run-on paragraph.
st.sidebar.markdown("""
**Authors**: Vukosi Marivate, Moseli Mots'Oehli, Valencia Wagner, Richard Lastrucci, Isheanesu Dzingirai

**Paper**: [PuoBERTa: Training and evaluation of a curated language model for Setswana](https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17) | [Preprint/Arxiv](https://arxiv.org/abs/2310.09141)

**Huggingface Space Creators**: Vukosi Marivate, Zion Van Wyk, Unarine Netshifhefhe, Thapelo Sindane

**Models Used**:
- [dsfsi/PuoBERTa (Mask Filling - Pretrained Model)](https://huggingface.co/dsfsi/PuoBERTa)
- [dsfsi/PuoBERTa-POS (POS Tagging)](https://huggingface.co/dsfsi/PuoBERTa-POS)
- [dsfsi/PuoBERTa-NER (Named Entity Recognition)](https://huggingface.co/dsfsi/PuoBERTa-NER)
- [dsfsi/PuoBERTa-News (News Classification)](https://huggingface.co/dsfsi/PuoBERTa-News)
""")
# -------------------- CACHING FUNCTIONS --------------------
@st.cache_resource
def _build_mask_filling_pipeline():
    """Load dsfsi/PuoBERTa and wrap it in a fill-mask pipeline (top_k=5).

    Kept separate from the error handling below so that a failed load raises
    out of the cached function instead of being stored as a cached ``None``.
    """
    tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa")
    model = AutoModelForMaskedLM.from_pretrained("dsfsi/PuoBERTa")
    pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, top_k=5)
    # Debug: print mask token for verification
    print(f"Mask token being used: {tokenizer.mask_token}")
    return pipe


def load_mask_filling_model():
    """Return the fill-mask pipeline, or ``None`` if it could not be loaded.

    On failure the error is reported with ``st.error`` and ``None`` is
    returned. Because only the builder above is cached, a failed attempt is
    retried on the next call rather than being remembered.
    """
    try:
        return _build_mask_filling_pipeline()
    except Exception as e:
        st.error(f"Failed to load mask filling model: {str(e)}")
        return None
@st.cache_resource
def load_pos_model():
    """Build the token-classification pipeline for dsfsi/PuoBERTa-POS."""
    checkpoint = "dsfsi/PuoBERTa-POS"
    return pipeline(
        "token-classification",
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
        model=AutoModelForTokenClassification.from_pretrained(checkpoint),
        aggregation_strategy="simple",
    )
@st.cache_resource
def load_ner_model():
    """Build the token-classification pipeline for dsfsi/PuoBERTa-NER."""
    checkpoint = "dsfsi/PuoBERTa-NER"
    ner_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)
    pipeline_kwargs = {
        "model": ner_model,
        "tokenizer": ner_tokenizer,
        "aggregation_strategy": "simple",
    }
    return pipeline("token-classification", **pipeline_kwargs)
@st.cache_resource
def load_news_classification_model():
    """Build the text-classification pipeline for dsfsi/PuoBERTa-News.

    The pipeline is created with ``return_all_scores=True`` so every label's
    score is returned, not just the best one.
    """
    checkpoint = "dsfsi/PuoBERTa-News"
    news_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    news_model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
    # NOTE(review): `return_all_scores` appears to be deprecated in recent
    # transformers releases in favour of `top_k=None`; switching may change
    # the nesting of the output the News tab indexes — confirm before migrating.
    return pipeline(
        "text-classification",
        model=news_model,
        tokenizer=news_tokenizer,
        return_all_scores=True,
    )
# -------------------- UTILITY FUNCTIONS --------------------
def get_correct_mask_token(text, tokenizer):
    """Rewrite common mask placeholders in *text* to the tokenizer's own.

    Every occurrence of ``[MASK]``, ``<mask>`` and the HTML-escaped
    ``&lt;mask&gt;`` is replaced, in that order, with
    ``tokenizer.mask_token``. The rewritten string is returned.
    """
    target = tokenizer.mask_token
    for alias in ("[MASK]", "<mask>", "&lt;mask&gt;"):
        text = text.replace(alias, target)
    return text
# Usage in the mask filling tab: normalise the input before calling the pipeline.
#     corrected_input = get_correct_mask_token(mask_input, mask_filler.tokenizer)
#     results = mask_filler(corrected_input)
def merge_entities(output):
    """Merge consecutive entities of the same type.

    An entity is folded into the one before it when it starts exactly where
    the previous entity ends and both share the same ``entity_group``. The
    merged entry concatenates the ``word`` values and extends ``end``; all
    other fields (e.g. ``score``) keep the value of the first entity in the
    run.

    Args:
        output: sequence of dicts with at least ``start``, ``end``, ``word``
            and ``entity_group`` keys.

    Returns:
        A new list of new dicts; neither ``output`` nor its dicts are modified.
    """
    merged = []
    previous = None
    for ent in output:
        if (
            previous is not None
            and ent["start"] == previous["end"]
            and ent["entity_group"] == previous["entity_group"]
        ):
            merged[-1]["word"] += ent["word"]
            merged[-1]["end"] = ent["end"]
        else:
            # Copy so that extending a merged span never mutates the
            # caller's dicts.
            merged.append(dict(ent))
        # Adjacency is always checked against the raw previous entity.
        previous = ent
    return merged
def create_spacy_display(text, entities, task_type="ner"):
    """Create spaCy-style (displaCy) HTML highlighting *entities* in *text*.

    Args:
        text: the string the entity offsets refer to.
        entities: iterable of dicts with ``start``, ``end`` and
            ``entity_group`` keys.
        task_type: when ``"ner"``, the ``PER`` label is shown as ``PERSON``;
            any other value leaves labels untouched.

    Returns:
        An HTML string: the displaCy markup wrapped in a bordered, horizontally
        scrollable ``<div>``, or a short error paragraph if rendering fails.
    """
    ents = []
    for ent in entities:
        label = ent["entity_group"]
        if task_type == "ner" and label == "PER":
            label = "PERSON"
        ents.append({
            "start": ent["start"],
            "end": ent["end"],
            "label": label,
        })
    spacy_display = {"text": text, "ents": ents, "title": None}

    # Define colors for different entity types
    colors = {
        # POS colors
        "PRON": "#FF9999",
        "VERB": "#99FF99",
        "DET": "#9999FF",
        "PROPN": "#FFFF99",
        "CCONJ": "#FFCC99",
        "PUNCT": "#CCCCCC",
        "NUM": "#FFCCFF",
        "NOUN": "#FFB366",
        "ADJ": "#B366FF",
        "ADP": "#66FFB3",
        # NER colors
        "PERSON": "#85DCDF",
        "PER": "#85DCDF",
        "LOC": "#DF85DC",
        "ORG": "#DCDF85",
        "MISC": "#85ABDF",
    }

    try:
        # Import the submodule explicitly rather than relying on
        # `spacy.displacy` having been loaded as a side effect of `import spacy`.
        from spacy import displacy

        html = displacy.render(
            spacy_display,
            style="ent",
            manual=True,
            minify=True,
            options={"colors": colors},
        )
    except Exception:
        # Rendering is best-effort; the caller still shows the results table.
        # `Exception` rather than a bare except so KeyboardInterrupt and
        # SystemExit propagate.
        return "<p>Error rendering visualization</p>"

    # Built line by line with no leading whitespace so the wrapper text is
    # independent of this function's indentation.
    return (
        "\n<style>mark.entity { display: inline-block; }</style>\n"
        "<div style='overflow-x:auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;'>\n"
        f"{html}\n"
        "</div>\n"
    )
def get_input_text(tab_name, examples):
    """Render the input widgets for one tab and return the chosen text.

    The user picks between a canned example, free text, or an uploaded
    ``.txt``/``.csv`` file. For a CSV the user also picks a column; its
    non-null values are joined with newlines.

    Args:
        tab_name: prefix used to build unique widget keys for this tab.
        examples: options offered in the "Example Sentences" select box.

    Returns:
        The selected text, or ``""`` when nothing is available (no upload yet,
        unreadable file, or no example to select).
    """
    input_method = st.radio(
        "Select Input Method",
        ['Example Text', 'Write Text', 'Upload File'],
        key=f"{tab_name}_input_method"
    )

    if input_method == 'Example Text':
        # Callers invoke .strip() on the result, so never hand back None.
        return st.selectbox("Example Sentences", examples, key=f"{tab_name}_examples") or ""

    if input_method == 'Write Text':
        return st.text_area("Enter text", height=100, key=f"{tab_name}_text_input")

    if input_method == 'Upload File':
        uploaded = st.file_uploader("Upload text or CSV file", type=["txt", "csv"], key=f"{tab_name}_file")
        if not uploaded:
            return ""

        # Compare case-insensitively so "DATA.CSV" is parsed as CSV too.
        if uploaded.name.lower().endswith('.csv'):
            try:
                df = pd.read_csv(uploaded)
            except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
                st.error(f"Could not read CSV file: {e}")
                return ""
            st.write("CSV Preview:", df.head())
            col = st.selectbox("Choose column with text", df.columns, key=f"{tab_name}_csv_col")
            return "\n".join(df[col].dropna().astype(str).tolist())

        try:
            return str(uploaded.read(), "utf-8")
        except UnicodeDecodeError:
            st.error("Could not decode the uploaded file as UTF-8 text.")
            return ""

    return ""
# -------------------- TABS --------------------
# One tab per task, in display order.
TAB_LABELS = [
    "🎭 Mask Filling",
    "🏷️ POS Tagging",
    "🔍 Named Entity Recognition",
    "📰 News Classification",
]
tab1, tab2, tab3, tab4 = st.tabs(TAB_LABELS)
# -------------------- MASK FILLING TAB --------------------
with tab1:
    st.header("Mask Filling")
    st.write("Fill in the blanks in Setswana sentences using `<mask>` token.")
    mask_examples = [
        "Ke rata go <mask> dijo tsa Batswana.",
        "Botswana ke naga e e <mask> mo Afrika Borwa.",
        "Bana ba <mask> sekolo ka Mosupologo.",
        "Re tshwanetse go <mask> tikologo ya rona."
    ]
    mask_input = get_input_text("mask", mask_examples)

    if st.button("Fill Masks", key="mask_button") and mask_input.strip():
        # Normalise the alternative [MASK] spelling first, then validate.
        # These are two separate checks: chaining them with elif/else would
        # convert a [MASK] input and then skip the prediction entirely.
        if "[MASK]" in mask_input:
            mask_input = mask_input.replace("[MASK]", "<mask>")
            st.info("Converted [MASK] to <mask> format")

        if "<mask>" not in mask_input:
            st.warning("Please include <mask> token in your text.")
        else:
            with st.spinner("Filling masks..."):
                mask_filler = load_mask_filling_model()
                if mask_filler is None:
                    # The loader returns None when the model could not be
                    # loaded; stop here instead of failing on `.tokenizer`.
                    st.warning("Mask filling model is not available.")
                else:
                    try:
                        corrected_input = get_correct_mask_token(mask_input, mask_filler.tokenizer)
                        results = mask_filler(corrected_input)
                        # With one mask the pipeline returns a flat list of
                        # candidates; with several masks it returns one list
                        # per mask. Normalise to a list of lists.
                        if results and isinstance(results[0], list):
                            per_mask = results
                        else:
                            per_mask = [results]

                        st.subheader("Predictions")
                        for mask_no, candidates in enumerate(per_mask, 1):
                            if len(per_mask) > 1:
                                st.write(f"Mask {mask_no}")
                            for i, result in enumerate(candidates, 1):
                                confidence = result['score'] * 100
                                st.write(f"**{i}.** {result['sequence']} (confidence: {confidence:.1f}%)")
                    except Exception as e:
                        st.error(f"Error: {str(e)}")
                        # Debug information
                        st.info(f"Input text: {mask_input}")
                        st.info(f"Model mask token: {mask_filler.tokenizer.mask_token}")
# -------------------- POS TAGGING TAB --------------------
with tab2:
    st.header("Parts of Speech Tagging")
    st.write("Identify grammatical parts of speech in Setswana text.")
    pos_examples = [
        "Moso ono mo dikgang tsa ura le ura, o tsoga le Oarabile Moamogwe go simolola ka 05:00 - 10:00",
        "Batho ba le bantsi ba rata go bala dikgang tsa Setswana.",
        "Ke ithutile Setswana kwa sekolong sa me.",
        "Dikgomo di ja bojang mo tshimong."
    ]
    pos_input = get_input_text("pos", pos_examples)

    # The button is evaluated first; the text is only inspected after a click.
    pos_clicked = st.button("Run POS Tagging", key="pos_button")
    if pos_clicked and pos_input.strip():
        with st.spinner("Running POS tagging..."):
            try:
                pos_tagger = load_pos_model()
                tagged_spans = merge_entities(pos_tagger(pos_input))
                if not tagged_spans:
                    st.info("No POS tags identified.")
                else:
                    # Results table, scores rounded to four decimals.
                    pos_columns = ['word', 'entity_group', 'score', 'start', 'end']
                    pos_table = pd.DataFrame(tagged_spans)[pos_columns]
                    pos_table['score'] = pos_table['score'].round(4)
                    st.subheader("POS Tags")
                    st.dataframe(pos_table, use_container_width=True)

                    # Highlighted rendering of the tagged text.
                    st.subheader("Visual Display")
                    pos_html = create_spacy_display(pos_input, tagged_spans, "pos")
                    st.markdown(pos_html, unsafe_allow_html=True)
            except Exception as e:
                st.error(f"Error: {str(e)}")
# -------------------- NER TAB --------------------
with tab3:
    st.header("Named Entity Recognition")
    st.write("Extract named entities like people, places, and organizations from Setswana text.")
    ner_examples = [
        "Oarabile Moamogwe o tswa Gaborone mme o bereka kwa University of Botswana.",
        "Motswana yo o tumileng Mpho Balopi o ne a kopana le Rre Khama kwa Presidential Palace.",
        "Botswana Democratic Party e ne ya kopana le African National Congress.",
        "Bank of Botswana e mo Gaborone e laola economy ya naga."
    ]
    ner_input = get_input_text("ner", ner_examples)

    # The button is evaluated first; the text is only inspected after a click.
    ner_clicked = st.button("Run NER", key="ner_button")
    if ner_clicked and ner_input.strip():
        with st.spinner("Running NER..."):
            try:
                ner_pipeline = load_ner_model()
                raw_spans = ner_pipeline(ner_input)
                found_entities = merge_entities(raw_spans)
                if not found_entities:
                    st.info("No named entities found.")
                else:
                    # Results table, scores rounded to four decimals.
                    ner_table = pd.DataFrame(found_entities)
                    ner_table = ner_table[['word', 'entity_group', 'score', 'start', 'end']]
                    ner_table['score'] = ner_table['score'].round(4)
                    st.subheader("Named Entities")
                    st.dataframe(ner_table, use_container_width=True)

                    # Highlighted rendering of the entities in context.
                    st.subheader("Visual Display")
                    ner_html = create_spacy_display(ner_input, found_entities, "ner")
                    st.markdown(ner_html, unsafe_allow_html=True)
            except Exception as e:
                st.error(f"Error: {str(e)}")
# -------------------- NEWS CLASSIFICATION TAB --------------------
with tab4:
    st.header("News Classification")
    st.write("Classify Setswana news articles into different categories.")

    # Category mapping: model label -> Setswana display name
    categories = {
        "arts_culture_entertainment_and_media": "Botsweretshi, setso, boitapoloso le bobegakgang",
        "crime_law_and_justice": "Bosenyi, molao le bosiamisi",
        "disaster_accident_and_emergency_incident": "Masetlapelo, kotsi le tiragalo ya maemo a tshoganyetso",
        "economy_business_and_finance": "Ikonomi, tsa kgwebo le tsa ditšhelete",
        "education": "Thuto",
        "environment": "Tikologo",
        "health": "Boitekanelo",
        "politics": "Dipolotiki",
        "religion_and_belief": "Bodumedi le tumelo",
        "society": "Setšhaba"
    }
    news_examples = [
        "Puso ya Botswana e solofeditse gore e tla oketsa dithuso tsa thuto mo dikolong tsa poraemari.",
        "Dipalo tsa bosenyi di oketsegile mo torong ya Gaborone ka pakeng tse di fetileng.",
        "Setšhaba sa Botswana se keteka matsalo a Rre le Mme ba ba ratanang thata.",
        "Boemelo jwa economy ya Botswana bo tsweletse sentle ka ngwaga ono."
    ]
    news_input = get_input_text("news", news_examples)

    if st.button("Classify News", key="news_button") and news_input.strip():
        with st.spinner("Classifying news..."):
            try:
                classifier = load_news_classification_model()
                # Ask the tokenizer to truncate so long articles are classified
                # on their leading tokens instead of failing.
                # NOTE(review): relies on the tokenizer's configured maximum
                # length — confirm it is set for this checkpoint.
                results = classifier(news_input, truncation=True)

                # The per-label scores may arrive nested ([[...]]) or flat
                # ([...]) depending on pipeline options; accept both shapes.
                if results and isinstance(results[0], list):
                    label_scores = results[0]
                else:
                    label_scores = results

                # Process results: translate labels, round scores
                predictions = {}
                for pred in label_scores:
                    category_en = pred['label']
                    category_tn = categories.get(category_en, category_en)
                    predictions[category_tn] = round(float(pred['score']), 4)

                # Sort by confidence, highest first
                sorted_predictions = dict(sorted(predictions.items(), key=lambda x: x[1], reverse=True))

                st.subheader("Classification Results")
                # Display the top five as progress bars
                for category, confidence in list(sorted_predictions.items())[:5]:
                    st.write(f"**{category}**")
                    # Clamp to the 0.0-1.0 range before drawing the bar.
                    st.progress(min(max(confidence, 0.0), 1.0))
                    st.write(f"Confidence: {confidence:.1%}")
                    st.write("")

                # Display full results table
                with st.expander("View All Categories"):
                    results_df = pd.DataFrame([
                        {"Category": cat, "Confidence": conf}
                        for cat, conf in sorted_predictions.items()
                    ])
                    st.dataframe(results_df, use_container_width=True)
            except Exception as e:
                st.error(f"Error: {str(e)}")
# -------------------- FOOTER --------------------
# Citation block and project links shown under a horizontal rule.
FOOTER_MARKDOWN = """
### 📚 Citation
```bibtex
@inproceedings{marivate2023puoberta,
title = {PuoBERTa: Training and evaluation of a curated language model for Setswana},
author = {Vukosi Marivate and Moseli Mots'Oehli and Valencia Wagner and Richard Lastrucci and Isheanesu Dzingirai},
year = {2023},
booktitle= {Artificial Intelligence Research. SACAIR 2023. Communications in Computer and Information Science},
url= {https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17},
keywords = {NLP},
preprint_url = {https://arxiv.org/abs/2310.09141},
dataset_url = {https://github.com/dsfsi/PuoBERTa},
software_url = {https://huggingface.co/dsfsi/PuoBERTa}
}
```
**Links**: [Paper](https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17) | [Preprint/Arxiv](https://arxiv.org/abs/2310.09141) | [GitHub](https://github.com/dsfsi/PuoBERTa) | [HuggingFace](https://huggingface.co/dsfsi/PuoBERTa)
"""
st.markdown("---")
st.markdown(FOOTER_MARKDOWN)