"""Streamlit demo for the PuoBERTa family of Setswana language models.

The app exposes four tasks, one per tab: mask filling, part-of-speech
tagging, named entity recognition and news classification. Each task loads a
Hugging Face pipeline once (cached with ``st.cache_resource``) and renders its
predictions with Streamlit widgets.

NOTE(review): this file was reconstructed from a copy in which line breaks and
HTML-like tokens (``<mask>``, ``<div ...>``) had been stripped. Strings that
had to be rebuilt are flagged with a NOTE(review) comment where they appear.
"""

import csv
from io import StringIO

import pandas as pd
import spacy
import streamlit as st
from spacy import displacy
from transformers import (
    AutoModelForMaskedLM,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)

# -------------------- PAGE CONFIG --------------------
# Must stay the first Streamlit call in the script.
st.set_page_config(
    page_title="PuoBERTa Multi-Task Demo",
    page_icon="🔤",
    layout="wide",
)

# -------------------- CONSTANTS --------------------
# Mask spellings a user may type; all are rewritten to the tokenizer's own
# mask token before inference.
MASK_PLACEHOLDERS = ("[MASK]", "<mask>")

# Columns shown in the POS / NER result tables.
ENTITY_TABLE_COLUMNS = ["word", "entity_group", "score", "start", "end"]

# Highlight colours passed to displaCy, keyed by label.
ENTITY_COLORS = {
    # POS colors
    "PRON": "#FF9999",
    "VERB": "#99FF99",
    "DET": "#9999FF",
    "PROPN": "#FFFF99",
    "CCONJ": "#FFCC99",
    "PUNCT": "#CCCCCC",
    "NUM": "#FFCCFF",
    "NOUN": "#FFB366",
    "ADJ": "#B366FF",
    "ADP": "#66FFB3",
    # NER colors
    "PERSON": "#85DCDF",
    "PER": "#85DCDF",
    "LOC": "#DF85DC",
    "ORG": "#DCDF85",
    "MISC": "#85ABDF",
}

# -------------------- HEADER --------------------
col1, col2, col3 = st.columns([1, 2, 1])
with col2:
    try:
        st.image("logo_transparent_small.png", width=300)
    except Exception:
        # The logo is decorative: if it cannot be loaded, show a text banner
        # instead of failing the whole page.
        st.write("🔤 PuoBERTa")

st.title("PuoBERTa Multi-Task Demo")
st.markdown("""
A comprehensive demo for Setswana language models including:
- **Mask Filling**: Fill in missing words in sentences
- **POS Tagging**: Identify parts of speech
- **Named Entity Recognition**: Extract entities like people, places, organizations
- **News Classification**: Classify news articles by category
""")
st.markdown("---")

# -------------------- SIDEBAR --------------------
st.sidebar.header("Model Information")
st.sidebar.markdown("""
**Authors**: Vukosi Marivate, Moseli Mots'Oehli, Valencia Wagner, Richard Lastrucci, Isheanesu Dzingirai

**Paper**: [PuoBERTa: Training and evaluation of a curated language model for Setswana](https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17) | [Preprint/Arxiv](https://arxiv.org/abs/2310.09141)

**Huggingface Space Creators**: Vukosi Marivate, Zion Van Wyk, Unarine Netshifhefhe, Thapelo Sindane

**Models Used**:
- [dsfsi/PuoBERTa (Mask Filling - Pretrained Model)](https://huggingface.co/dsfsi/PuoBERTa)
- [dsfsi/PuoBERTa-POS (POS Tagging)](https://huggingface.co/dsfsi/PuoBERTa-POS)
- [dsfsi/PuoBERTa-NER (Named Entity Recognition)](https://huggingface.co/dsfsi/PuoBERTa-NER)
- [dsfsi/PuoBERTa-News (News Classification)](https://huggingface.co/dsfsi/PuoBERTa-News)
""")


# -------------------- CACHING FUNCTIONS --------------------
@st.cache_resource
def load_mask_filling_model():
    """Build the fill-mask pipeline for ``dsfsi/PuoBERTa`` (top 5 candidates).

    Loading errors propagate to the caller, like the other loaders in this
    file. Catching them here and returning ``None`` would make
    ``st.cache_resource`` store that ``None``, so one transient download
    failure would disable the tab until the cache is cleared.
    """
    tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa")
    model = AutoModelForMaskedLM.from_pretrained("dsfsi/PuoBERTa")
    return pipeline("fill-mask", model=model, tokenizer=tokenizer, top_k=5)


@st.cache_resource
def load_pos_model():
    """Build the token-classification pipeline for ``dsfsi/PuoBERTa-POS``."""
    tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-POS")
    model = AutoModelForTokenClassification.from_pretrained("dsfsi/PuoBERTa-POS")
    return pipeline(
        "token-classification",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
    )


@st.cache_resource
def load_ner_model():
    """Build the token-classification pipeline for ``dsfsi/PuoBERTa-NER``."""
    tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-NER")
    model = AutoModelForTokenClassification.from_pretrained("dsfsi/PuoBERTa-NER")
    return pipeline(
        "token-classification",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
    )


@st.cache_resource
def load_news_classification_model():
    """Build the text-classification pipeline for ``dsfsi/PuoBERTa-News``.

    ``top_k=None`` asks for the score of every label; it replaces the
    deprecated ``return_all_scores=True`` flag.
    """
    tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-News")
    model = AutoModelForSequenceClassification.from_pretrained("dsfsi/PuoBERTa-News")
    return pipeline(
        "text-classification",
        model=model,
        tokenizer=tokenizer,
        top_k=None,
    )


# -------------------- UTILITY FUNCTIONS --------------------
def get_correct_mask_token(text, tokenizer):
    """Rewrite every supported mask spelling in *text* to the tokenizer's own.

    Args:
        text: User-supplied sentence that may contain ``[MASK]`` or ``<mask>``.
        tokenizer: Object exposing a ``mask_token`` attribute.

    Returns:
        *text* with each placeholder replaced by ``tokenizer.mask_token``, or
        *text* unchanged when the tokenizer defines no mask token.
    """
    mask_token = tokenizer.mask_token
    if not mask_token:
        return text
    for placeholder in MASK_PLACEHOLDERS:
        # Skip the spelling the tokenizer already uses; never replace "",
        # which would insert the mask token between every character.
        if placeholder != mask_token:
            text = text.replace(placeholder, mask_token)
    return text


def merge_entities(output):
    """Merge directly adjacent entities that share the same entity group.

    Two entities are merged when the second one starts exactly where the
    previous one ends and both carry the same ``entity_group``. The merged
    entry keeps the first entity's remaining fields (including ``score``).

    Args:
        output: Sequence of entity dicts with at least ``word``, ``start``,
            ``end`` and ``entity_group`` keys.

    Returns:
        A new list of new dicts; the dicts in *output* are left untouched.
    """
    merged = []
    for ent in output:
        previous = merged[-1] if merged else None
        if (
            previous is not None
            and ent["start"] == previous["end"]
            and ent["entity_group"] == previous["entity_group"]
        ):
            previous["word"] += ent["word"]
            previous["end"] = ent["end"]
        else:
            # Copy so extending a merged span never edits the caller's data.
            merged.append(dict(ent))
    return merged


def create_spacy_display(text, entities, task_type="ner"):
    """Render *entities* over *text* as displaCy "ent"-style HTML.

    Args:
        text: The text the entity offsets refer to.
        entities: Entity dicts with ``start``, ``end`` and ``entity_group``.
        task_type: ``"ner"`` relabels ``PER`` as ``PERSON``; any other value
            leaves labels unchanged.

    Returns:
        An HTML string: the rendered markup wrapped in a ``<div>``, or a short
        error ``<div>`` if displaCy fails to render.
    """
    ents = []
    for ent in entities:
        label = ent["entity_group"]
        if task_type == "ner" and label == "PER":
            label = "PERSON"
        ents.append({"start": ent["start"], "end": ent["end"], "label": label})

    spacy_display = {"text": text, "ents": ents, "title": None}

    try:
        html = displacy.render(
            spacy_display,
            style="ent",
            manual=True,
            minify=True,
            options={"colors": ENTITY_COLORS},
        )
    except Exception:
        # Visualisation is an extra on top of the results table; degrade to a
        # message rather than failing the whole tab.
        # NOTE(review): original error markup was stripped — confirm styling.
        return "<div style='color: red;'>Error rendering visualization</div>"

    # NOTE(review): original wrapper markup was stripped — confirm styling.
    return (
        "<div style='overflow-x: auto; border: 1px solid #e6e9ef; "
        "border-radius: 0.25rem; padding: 1rem;'>"
        f"{html}"
        "</div>"
    )


def get_input_text(tab_name, examples):
    """Collect the text to analyse from the input method the user selects.

    Args:
        tab_name: Prefix that keeps widget keys unique per tab.
        examples: Example sentences offered in "Example Text" mode.

    Returns:
        The chosen, typed or uploaded text; ``""`` when nothing is available
        (no selection, no file yet, or an unreadable file).
    """
    input_method = st.radio(
        "Select Input Method",
        ['Example Text', 'Write Text', 'Upload File'],
        key=f"{tab_name}_input_method",
    )

    if input_method == 'Example Text':
        choice = st.selectbox("Example Sentences", examples, key=f"{tab_name}_examples")
        # selectbox yields None when there are no options to pick from.
        return choice or ""

    if input_method == 'Write Text':
        return st.text_area("Enter text", height=100, key=f"{tab_name}_text_input") or ""

    if input_method == 'Upload File':
        uploaded = st.file_uploader(
            "Upload text or CSV file",
            type=["txt", "csv"],
            key=f"{tab_name}_file",
        )
        if uploaded:
            if uploaded.name.lower().endswith('.csv'):
                try:
                    df = pd.read_csv(uploaded)
                except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
                    st.error(f"Could not read CSV file: {e}")
                    return ""
                st.write("CSV Preview:", df.head())
                col = st.selectbox(
                    "Choose column with text",
                    df.columns,
                    key=f"{tab_name}_csv_col",
                )
                if col is None:
                    return ""
                return "\n".join(df[col].dropna().astype(str).tolist())
            # getvalue() does not depend on the stream position; "utf-8-sig"
            # drops a leading BOM and errors="replace" avoids a crash on
            # stray non-UTF-8 bytes.
            return uploaded.getvalue().decode("utf-8-sig", errors="replace")

    return ""


def show_mask_predictions(results):
    """Write fill-mask candidates to the page.

    The fill-mask pipeline returns a flat list of candidates for one mask and
    a list of such lists when the input holds several masks; both shapes are
    handled here.
    """
    groups = results if results and isinstance(results[0], list) else [results]
    for position, candidates in enumerate(groups, 1):
        if len(groups) > 1:
            st.markdown(f"**Mask {position}**")
        for rank, candidate in enumerate(candidates, 1):
            confidence = candidate['score'] * 100
            st.write(f"**{rank}.** {candidate['sequence']} (confidence: {confidence:.1f}%)")


def run_token_classification(text, load_model, task_type, spinner_text,
                             table_title, empty_message):
    """Run a token-classification pipeline on *text* and render the results.

    Shared by the POS and NER tabs, which differ only in the loader, the
    label handling (*task_type*) and the captions shown.
    """
    with st.spinner(spinner_text):
        try:
            tagger = load_model()
            entities = merge_entities(tagger(text))

            if entities:
                # Display results table
                df = pd.DataFrame(entities, columns=ENTITY_TABLE_COLUMNS)
                df['score'] = df['score'].round(4)
                st.subheader(table_title)
                st.dataframe(df, use_container_width=True)

                # Visual display
                st.subheader("Visual Display")
                html = create_spacy_display(text, entities, task_type)
                st.markdown(html, unsafe_allow_html=True)
            else:
                st.info(empty_message)
        except Exception as e:
            # UI boundary: report the failure on the page instead of crashing.
            st.error(f"Error: {str(e)}")


# -------------------- TABS --------------------
tab1, tab2, tab3, tab4 = st.tabs([
    "🎭 Mask Filling",
    "🏷️ POS Tagging",
    "🔍 Named Entity Recognition",
    "📰 News Classification",
])

# -------------------- MASK FILLING TAB --------------------
with tab1:
    st.header("Mask Filling")
    st.write("Fill in the blanks in Setswana sentences using `<mask>` token.")

    # NOTE(review): the <mask> tokens were stripped from the copy this file
    # was rebuilt from; their positions below are a best guess — confirm
    # against the original examples.
    mask_examples = [
        "Ke rata go <mask> dijo tsa Batswana.",
        "Botswana ke naga e e <mask> mo Afrika Borwa.",
        "Bana ba <mask> sekolo ka Mosupologo.",
        "Re tshwanetse go <mask> tikologo ya rona.",
    ]

    mask_input = get_input_text("mask", mask_examples)

    if st.button("Fill Masks", key="mask_button") and mask_input.strip():
        if not any(placeholder in mask_input for placeholder in MASK_PLACEHOLDERS):
            st.warning("Please include `<mask>` token in your text.")
        else:
            with st.spinner("Filling masks..."):
                try:
                    mask_filler = load_mask_filling_model()
                except Exception as e:
                    mask_filler = None
                    st.error(f"Failed to load mask filling model: {str(e)}")

                if mask_filler is not None:
                    model_mask_token = mask_filler.tokenizer.mask_token
                    try:
                        corrected_input = get_correct_mask_token(mask_input, mask_filler.tokenizer)
                        if corrected_input != mask_input:
                            st.info(f"Converted mask placeholder to `{model_mask_token}` format")
                        results = mask_filler(corrected_input)

                        st.subheader("Predictions")
                        show_mask_predictions(results)
                    except Exception as e:
                        st.error(f"Error: {str(e)}")
                        # Debug information
                        st.info(f"Input text: {mask_input}")
                        st.info(f"Model mask token: {model_mask_token}")

# -------------------- POS TAGGING TAB --------------------
with tab2:
    st.header("Parts of Speech Tagging")
    st.write("Identify grammatical parts of speech in Setswana text.")

    pos_examples = [
        "Moso ono mo dikgang tsa ura le ura, o tsoga le Oarabile Moamogwe go simolola ka 05:00 - 10:00",
        "Batho ba le bantsi ba rata go bala dikgang tsa Setswana.",
        "Ke ithutile Setswana kwa sekolong sa me.",
        "Dikgomo di ja bojang mo tshimong.",
    ]

    pos_input = get_input_text("pos", pos_examples)

    if st.button("Run POS Tagging", key="pos_button") and pos_input.strip():
        run_token_classification(
            pos_input,
            load_pos_model,
            task_type="pos",
            spinner_text="Running POS tagging...",
            table_title="POS Tags",
            empty_message="No POS tags identified.",
        )

# -------------------- NER TAB --------------------
with tab3:
    st.header("Named Entity Recognition")
    st.write("Extract named entities like people, places, and organizations from Setswana text.")

    ner_examples = [
        "Oarabile Moamogwe o tswa Gaborone mme o bereka kwa University of Botswana.",
        "Motswana yo o tumileng Mpho Balopi o ne a kopana le Rre Khama kwa Presidential Palace.",
        "Botswana Democratic Party e ne ya kopana le African National Congress.",
        "Bank of Botswana e mo Gaborone e laola economy ya naga.",
    ]

    ner_input = get_input_text("ner", ner_examples)

    if st.button("Run NER", key="ner_button") and ner_input.strip():
        run_token_classification(
            ner_input,
            load_ner_model,
            task_type="ner",
            spinner_text="Running NER...",
            table_title="Named Entities",
            empty_message="No named entities found.",
        )

# -------------------- NEWS CLASSIFICATION TAB --------------------
with tab4:
    st.header("News Classification")
    st.write("Classify Setswana news articles into different categories.")

    # Category mapping: model label -> Setswana display name
    categories = {
        "arts_culture_entertainment_and_media": "Botsweretshi, setso, boitapoloso le bobegakgang",
        "crime_law_and_justice": "Bosenyi, molao le bosiamisi",
        "disaster_accident_and_emergency_incident": "Masetlapelo, kotsi le tiragalo ya maemo a tshoganyetso",
        "economy_business_and_finance": "Ikonomi, tsa kgwebo le tsa ditšhelete",
        "education": "Thuto",
        "environment": "Tikologo",
        "health": "Boitekanelo",
        "politics": "Dipolotiki",
        "religion_and_belief": "Bodumedi le tumelo",
        "society": "Setšhaba",
    }

    news_examples = [
        "Puso ya Botswana e solofeditse gore e tla oketsa dithuso tsa thuto mo dikolong tsa poraemari.",
        "Dipalo tsa bosenyi di oketsegile mo torong ya Gaborone ka pakeng tse di fetileng.",
        "Setšhaba sa Botswana se keteka matsalo a Rre le Mme ba ba ratanang thata.",
        "Boemelo jwa economy ya Botswana bo tsweletse sentle ka ngwaga ono.",
    ]

    news_input = get_input_text("news", news_examples)

    if st.button("Classify News", key="news_button") and news_input.strip():
        with st.spinner("Classifying news..."):
            try:
                classifier = load_news_classification_model()
                # truncation=True keeps articles longer than the model's
                # maximum sequence length from raising.
                results = classifier(news_input, truncation=True)

                # Depending on the transformers version, a single input yields
                # either [[{label, score}, ...]] or [{label, score}, ...].
                label_scores = results[0] if results and isinstance(results[0], list) else results

                # Process results
                predictions = {
                    categories.get(pred['label'], pred['label']): round(float(pred['score']), 4)
                    for pred in label_scores
                }

                # Sort by confidence
                sorted_predictions = sorted(
                    predictions.items(), key=lambda item: item[1], reverse=True
                )

                st.subheader("Classification Results")

                # Display the five most likely categories as progress bars
                for category, confidence in sorted_predictions[:5]:
                    st.write(f"**{category}**")
                    # st.progress only accepts floats within [0.0, 1.0].
                    st.progress(min(max(confidence, 0.0), 1.0))
                    st.write(f"Confidence: {confidence:.1%}")
                    st.write("")

                # Display full results table
                with st.expander("View All Categories"):
                    results_df = pd.DataFrame([
                        {"Category": cat, "Confidence": conf}
                        for cat, conf in sorted_predictions
                    ])
                    st.dataframe(results_df, use_container_width=True)

            except Exception as e:
                # UI boundary: report the failure on the page instead of crashing.
                st.error(f"Error: {str(e)}")

# -------------------- FOOTER --------------------
st.markdown("---")
st.markdown("""
### 📚 Citation

```bibtex
@inproceedings{marivate2023puoberta,
  title = {PuoBERTa: Training and evaluation of a curated language model for Setswana},
  author = {Vukosi Marivate and Moseli Mots'Oehli and Valencia Wagner and Richard Lastrucci and Isheanesu Dzingirai},
  year = {2023},
  booktitle= {Artificial Intelligence Research. SACAIR 2023. Communications in Computer and Information Science},
  url= {https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17},
  keywords = {NLP},
  preprint_url = {https://arxiv.org/abs/2310.09141},
  dataset_url = {https://github.com/dsfsi/PuoBERTa},
  software_url = {https://huggingface.co/dsfsi/PuoBERTa}
}
```

**Links**: [Paper](https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17) | [Preprint/Arxiv](https://arxiv.org/abs/2310.09141) | [GitHub](https://github.com/dsfsi/PuoBERTa) | [HuggingFace](https://huggingface.co/dsfsi/PuoBERTa)
""")