File size: 16,442 Bytes
e6bfe5c
8b740b2
 
 
 
 
 
 
e6bfe5c
d25abcf
8b740b2
 
e6bfe5c
295300a
8b740b2
 
 
 
 
1b711d9
8b740b2
 
 
 
 
 
 
125e609
8b740b2
 
 
 
 
 
 
 
295300a
8b740b2
295300a
8b740b2
 
 
 
 
c4a32ef
ca52b6a
 
8b740b2
 
bc348b7
 
 
 
8b740b2
 
 
 
 
3096ba9
 
 
 
 
 
 
 
 
 
 
 
 
 
8b740b2
 
 
 
 
 
295300a
8b740b2
 
 
 
 
5b3d11c
f193a60
8b740b2
 
 
 
295300a
8b740b2
3096ba9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295300a
8b740b2
295300a
 
 
 
 
4874aa0
295300a
 
 
8b740b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3096ba9
8b740b2
 
3096ba9
 
 
 
8b740b2
 
 
 
 
3096ba9
 
 
 
 
 
c558c48
8b740b2
 
 
3096ba9
 
 
8b740b2
 
 
 
 
 
 
 
3096ba9
 
 
 
 
 
 
4948d8f
8b740b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4948d8f
8b740b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4948d8f
8b740b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4948d8f
 
 
 
 
 
 
 
 
 
 
8b740b2
 
 
c4a32ef
8b740b2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
import streamlit as st
from transformers import (
    pipeline, 
    AutoModelForTokenClassification, 
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForMaskedLM
)
import pandas as pd
import spacy
import csv
from io import StringIO

# -------------------- PAGE CONFIG --------------------
# Browser-tab title and icon, plus a wide layout for the whole app.
st.set_page_config(
    page_title="PuoBERTa Multi-Task Demo",
    page_icon="🔤",
    layout="wide"
)

# -------------------- HEADER --------------------
# Three columns weighted 1:2:1; only the middle one is used so the logo is centred.
col1, col2, col3 = st.columns([1, 2, 1])
with col2:
    try:
        st.image("logo_transparent_small.png", width=300)
    except Exception:
        # Fall back to a text banner when the logo cannot be displayed.
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt/SystemExit and other BaseException subclasses
        # are not silently swallowed here.
        st.write("🔤 PuoBERTa")

st.title("PuoBERTa Multi-Task Demo")
st.markdown("""
A comprehensive demo for Setswana language models including:
- **Mask Filling**: Fill in missing words in sentences
- **POS Tagging**: Identify parts of speech
- **Named Entity Recognition**: Extract entities like people, places, organizations
- **News Classification**: Classify news articles by category
""")

st.markdown("---")

# -------------------- SIDEBAR --------------------
# Static credits and links to the four model checkpoints used by the tabs below.
st.sidebar.header("Model Information")
st.sidebar.markdown("""
**Authors**: Vukosi Marivate, Moseli Mots'Oehli, Valencia Wagner, Richard Lastrucci, Isheanesu Dzingirai

**Paper**: [PuoBERTa: Training and evaluation of a curated language model for Setswana](https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17) | [Preprint/Arxiv](https://arxiv.org/abs/2310.09141)

**Huggingface Space Creators**: Vukosi Marivate, Zion Van Wyk, Unarine Netshifhefhe, Thapelo Sindane

**Models Used**:
- [dsfsi/PuoBERTa (Mask Filling - Pretrained Model)](https://huggingface.co/dsfsi/PuoBERTa)
- [dsfsi/PuoBERTa-POS (POS Tagging)](https://huggingface.co/dsfsi/PuoBERTa-POS)
- [dsfsi/PuoBERTa-NER (Named Entity Recognition)](https://huggingface.co/dsfsi/PuoBERTa-NER)
- [dsfsi/PuoBERTa-News (News Classification)](https://huggingface.co/dsfsi/PuoBERTa-News)
""")

# -------------------- CACHING FUNCTIONS --------------------
@st.cache_resource
def _load_mask_filling_pipeline():
    """Build the fill-mask pipeline for ``dsfsi/PuoBERTa``.

    Any loading error propagates to the caller. Keeping the error handling
    out of the cached function means a failed load does not leave a ``None``
    result stored in the resource cache.
    """
    tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa")
    model = AutoModelForMaskedLM.from_pretrained("dsfsi/PuoBERTa")

    # Create pipeline returning the five best candidates per mask
    pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, top_k=5)

    # Debug: print mask token for verification
    print(f"Mask token being used: {tokenizer.mask_token}")

    return pipe


def load_mask_filling_model():
    """Return the fill-mask pipeline, or ``None`` if it cannot be loaded.

    On failure the error is shown with ``st.error`` and ``None`` is returned,
    so callers must check the result before using it. The actual loading is
    delegated to a cached helper; because the ``try``/``except`` lives here,
    outside the cached function, a failed attempt can be retried on the next
    call instead of being cached as ``None``.
    """
    try:
        return _load_mask_filling_pipeline()
    except Exception as e:
        st.error(f"Failed to load mask filling model: {str(e)}")
        return None

@st.cache_resource
def load_pos_model():
    """Return the token-classification pipeline for ``dsfsi/PuoBERTa-POS``.

    The pipeline is built with ``aggregation_strategy="simple"`` and the
    function is wrapped in ``st.cache_resource``.
    """
    checkpoint = "dsfsi/PuoBERTa-POS"
    # Keyword arguments are evaluated left to right, so the tokenizer is
    # still loaded before the model.
    return pipeline(
        "token-classification",
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
        model=AutoModelForTokenClassification.from_pretrained(checkpoint),
        aggregation_strategy="simple",
    )

@st.cache_resource
def load_ner_model():
    """Return the token-classification pipeline for ``dsfsi/PuoBERTa-NER``.

    The pipeline is built with ``aggregation_strategy="simple"`` and the
    function is wrapped in ``st.cache_resource``.
    """
    checkpoint = "dsfsi/PuoBERTa-NER"
    # Keyword arguments are evaluated left to right, so the tokenizer is
    # still loaded before the model.
    return pipeline(
        "token-classification",
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
        model=AutoModelForTokenClassification.from_pretrained(checkpoint),
        aggregation_strategy="simple",
    )

@st.cache_resource
def load_news_classification_model():
    """Return the text-classification pipeline for ``dsfsi/PuoBERTa-News``.

    The pipeline is created with ``return_all_scores=True``; the news tab
    reads its output as ``results[0]``, so this flag must stay as it is.
    """
    checkpoint = "dsfsi/PuoBERTa-News"
    # Keyword arguments are evaluated left to right, so the tokenizer is
    # still loaded before the model.
    return pipeline(
        "text-classification",
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
        model=AutoModelForSequenceClassification.from_pretrained(checkpoint),
        return_all_scores=True,
    )

# -------------------- UTILITY FUNCTIONS --------------------

def get_correct_mask_token(text, tokenizer):
    """Return ``text`` with every known mask placeholder rewritten.

    The placeholders ``[MASK]``, ``<mask>`` and the HTML-escaped
    ``&lt;mask&gt;`` are each replaced, in that order, by
    ``tokenizer.mask_token``.
    """
    target = tokenizer.mask_token
    for placeholder in ("[MASK]", "<mask>", "&lt;mask&gt;"):
        text = text.replace(placeholder, target)
    return text

# Usage (see the mask filling tab below):
#     corrected_input = get_correct_mask_token(mask_input, mask_filler.tokenizer)
#     results = mask_filler(corrected_input)


def merge_entities(output):
    """Merge consecutive entities of the same type.

    An entry is fused into the one before it when it starts exactly where
    the previous one ends (``start == end``) and both share the same
    ``"entity_group"``. The fused entry concatenates ``"word"``, takes the
    later ``"end"``, and keeps every other field (e.g. ``"score"``) from the
    first entry of the run.

    Args:
        output: iterable of dicts with at least ``"word"``, ``"start"``,
            ``"end"`` and ``"entity_group"`` keys.

    Returns:
        A new list of new dicts. Neither ``output`` nor the dicts inside it
        are modified.
    """
    merged = []
    for ent in output:
        previous = merged[-1] if merged else None
        if (
            previous is not None
            and ent["start"] == previous["end"]
            and ent["entity_group"] == previous["entity_group"]
        ):
            previous["word"] += ent["word"]
            previous["end"] = ent["end"]
        else:
            # Copy so that later merges never write into the caller's dicts.
            merged.append(dict(ent))
    return merged

def create_spacy_display(text, entities, task_type="ner"):
    """Create spaCy-style (displaCy) HTML highlighting ``entities`` in ``text``.

    Args:
        text: the string the entity offsets refer to.
        entities: iterable of dicts with ``"entity_group"``, ``"start"`` and
            ``"end"`` keys.
        task_type: when ``"ner"``, the label ``"PER"`` is shown as
            ``"PERSON"``; any other value leaves labels unchanged.

    Returns:
        An HTML string wrapping the displaCy markup in a scrollable, bordered
        ``<div>``, or ``"<p>Error rendering visualization</p>"`` if rendering
        raises.
    """
    spacy_display = {"text": text, "ents": [], "title": None}

    for ent in entities:
        label = ent["entity_group"]
        if task_type == "ner" and label == "PER":
            label = "PERSON"
        spacy_display["ents"].append({
            "start": ent["start"],
            "end": ent["end"],
            "label": label
        })

    # Define colors for different entity types
    colors = {
        # POS colors
        "PRON": "#FF9999",
        "VERB": "#99FF99",
        "DET": "#9999FF",
        "PROPN": "#FFFF99",
        "CCONJ": "#FFCC99",
        "PUNCT": "#CCCCCC",
        "NUM": "#FFCCFF",
        "NOUN": "#FFB366",
        "ADJ": "#B366FF",
        "ADP": "#66FFB3",
        # NER colors
        "PERSON": "#85DCDF",
        "PER": "#85DCDF",
        "LOC": "#DF85DC",
        "ORG": "#DCDF85",
        "MISC": "#85ABDF"
    }

    try:
        html = spacy.displacy.render(spacy_display, style="ent", manual=True, minify=True,
                                   options={"colors": colors})
        styled_html = f"""
        <style>mark.entity {{ display: inline-block; }}</style>
        <div style='overflow-x:auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;'>
            {html}
        </div>
        """
        return styled_html
    except Exception:
        # Rendering is best-effort, but only ordinary errors are absorbed;
        # a bare except would also swallow KeyboardInterrupt/SystemExit.
        return "<p>Error rendering visualization</p>"

def get_input_text(tab_name, examples):
    """Render the input widgets for one tab and return the chosen text.

    Args:
        tab_name: prefix used to build a unique widget ``key`` per tab.
        examples: options offered in the "Example Text" select box.

    Returns:
        The selected example, the typed text, or the uploaded file's content.
        For a CSV upload, the non-null values of the chosen column are joined
        with newlines. Always a ``str``: the empty string is returned when
        nothing is selected or uploaded, or when the upload cannot be read,
        so callers can safely call ``.strip()`` on the result.
    """
    input_method = st.radio(
        "Select Input Method",
        ['Example Text', 'Write Text', 'Upload File'],
        key=f"{tab_name}_input_method"
    )

    if input_method == 'Example Text':
        # selectbox gives None when there is nothing to select.
        return st.selectbox("Example Sentences", examples, key=f"{tab_name}_examples") or ""

    if input_method == 'Write Text':
        return st.text_area("Enter text", height=100, key=f"{tab_name}_text_input") or ""

    if input_method == 'Upload File':
        uploaded = st.file_uploader("Upload text or CSV file", type=["txt", "csv"], key=f"{tab_name}_file")
        if uploaded:
            # Compare case-insensitively so "DATA.CSV" is parsed as CSV too.
            if uploaded.name.lower().endswith('.csv'):
                try:
                    df = pd.read_csv(uploaded)
                except (UnicodeDecodeError, pd.errors.EmptyDataError, pd.errors.ParserError) as e:
                    # This function runs outside the tabs' try blocks, so an
                    # unreadable upload would otherwise abort the whole page.
                    st.error(f"Could not read CSV file: {e}")
                    return ""
                st.write("CSV Preview:", df.head())
                col = st.selectbox("Choose column with text", df.columns, key=f"{tab_name}_csv_col")
                if col is None:
                    # A CSV without columns leaves nothing to choose from.
                    return ""
                return "\n".join(df[col].dropna().astype(str).tolist())

            try:
                return str(uploaded.read(), "utf-8")
            except UnicodeDecodeError as e:
                st.error(f"Could not decode file as UTF-8: {e}")
                return ""
    return ""

# -------------------- TABS --------------------
tab1, tab2, tab3, tab4 = st.tabs(["🎭 Mask Filling", "🏷️ POS Tagging", "🔍 Named Entity Recognition", "📰 News Classification"])

# -------------------- MASK FILLING TAB --------------------
with tab1:
    st.header("Mask Filling")
    st.write("Fill in the blanks in Setswana sentences using `<mask>` token.")

    mask_examples = [
        "Ke rata go <mask> dijo tsa Batswana.",
        "Botswana ke naga e e <mask> mo Afrika Borwa.",
        "Bana ba <mask> sekolo ka Mosupologo.",
        "Re tshwanetse go <mask> tikologo ya rona."
    ]

    mask_input = get_input_text("mask", mask_examples)

    if st.button("Fill Masks", key="mask_button") and mask_input.strip():
        # Normalise the alternative [MASK] spelling first, then fall through
        # to validation and prediction. This used to be one if/elif/else
        # chain, so converted [MASK] input was never actually predicted.
        if "[MASK]" in mask_input:
            mask_input = mask_input.replace("[MASK]", "<mask>")
            st.info("Converted [MASK] to <mask> format")

        if "<mask>" not in mask_input:
            st.warning("Please include <mask> token in your text.")
        else:
            with st.spinner("Filling masks..."):
                mask_filler = None
                try:
                    mask_filler = load_mask_filling_model()
                    if mask_filler is None:
                        # The loader returns None when the model could not be
                        # loaded; say so instead of failing on `.tokenizer`.
                        st.warning("Mask filling model is not available.")
                    else:
                        corrected_input = get_correct_mask_token(mask_input, mask_filler.tokenizer)
                        results = mask_filler(corrected_input)

                        # One mask yields a flat list of candidate dicts;
                        # several masks yield one such list per mask.
                        if results and isinstance(results[0], dict):
                            per_mask_results = [results]
                        else:
                            per_mask_results = results

                        st.subheader("Predictions")
                        for mask_number, candidates in enumerate(per_mask_results, 1):
                            if len(per_mask_results) > 1:
                                st.write(f"**Mask {mask_number}**")
                            for i, result in enumerate(candidates, 1):
                                confidence = result['score'] * 100
                                st.write(f"**{i}.** {result['sequence']} (confidence: {confidence:.1f}%)")

                except Exception as e:
                    st.error(f"Error: {str(e)}")
                    # Debug information
                    st.info(f"Input text: {mask_input}")
                    if mask_filler is not None:
                        st.info(f"Model mask token: {mask_filler.tokenizer.mask_token}")

# -------------------- POS TAGGING TAB --------------------
with tab2:
    st.header("Parts of Speech Tagging")
    st.write("Identify grammatical parts of speech in Setswana text.")

    pos_examples = [
        "Moso ono mo dikgang tsa ura le ura, o tsoga le Oarabile Moamogwe go simolola ka 05:00 - 10:00",
        "Batho ba le bantsi ba rata go bala dikgang tsa Setswana.",
        "Ke ithutile Setswana kwa sekolong sa me.",
        "Dikgomo di ja bojang mo tshimong."
    ]

    pos_input = get_input_text("pos", pos_examples)

    # The button is always rendered; the text is only inspected after a click.
    pos_clicked = st.button("Run POS Tagging", key="pos_button")
    if pos_clicked and pos_input.strip():
        with st.spinner("Running POS tagging..."):
            try:
                # Tag the text, then fuse adjacent pieces that share a tag.
                tagged_spans = merge_entities(load_pos_model()(pos_input))

                if not tagged_spans:
                    st.info("No POS tags identified.")
                else:
                    # Results table, scores rounded to four decimals
                    pos_columns = ['word', 'entity_group', 'score', 'start', 'end']
                    pos_table = pd.DataFrame(tagged_spans)[pos_columns]
                    pos_table['score'] = pos_table['score'].round(4)
                    st.subheader("POS Tags")
                    st.dataframe(pos_table, use_container_width=True)

                    # Highlighted rendering of the same spans
                    st.subheader("Visual Display")
                    st.markdown(
                        create_spacy_display(pos_input, tagged_spans, "pos"),
                        unsafe_allow_html=True,
                    )

            except Exception as e:
                st.error(f"Error: {str(e)}")

# -------------------- NER TAB --------------------
with tab3:
    st.header("Named Entity Recognition")
    st.write("Extract named entities like people, places, and organizations from Setswana text.")

    ner_examples = [
        "Oarabile Moamogwe o tswa Gaborone mme o bereka kwa University of Botswana.",
        "Motswana yo o tumileng Mpho Balopi o ne a kopana le Rre Khama kwa Presidential Palace.",
        "Botswana Democratic Party e ne ya kopana le African National Congress.",
        "Bank of Botswana e mo Gaborone e laola economy ya naga."
    ]

    ner_input = get_input_text("ner", ner_examples)

    # The button is always rendered; the text is only inspected after a click.
    ner_clicked = st.button("Run NER", key="ner_button")
    if ner_clicked and ner_input.strip():
        with st.spinner("Running NER..."):
            try:
                # Run the model, then fuse adjacent pieces of the same entity type.
                found_entities = merge_entities(load_ner_model()(ner_input))

                if not found_entities:
                    st.info("No named entities found.")
                else:
                    # Results table, scores rounded to four decimals
                    ner_columns = ['word', 'entity_group', 'score', 'start', 'end']
                    ner_table = pd.DataFrame(found_entities)[ner_columns]
                    ner_table['score'] = ner_table['score'].round(4)
                    st.subheader("Named Entities")
                    st.dataframe(ner_table, use_container_width=True)

                    # Highlighted rendering of the same spans
                    st.subheader("Visual Display")
                    st.markdown(
                        create_spacy_display(ner_input, found_entities, "ner"),
                        unsafe_allow_html=True,
                    )

            except Exception as e:
                st.error(f"Error: {str(e)}")

# -------------------- NEWS CLASSIFICATION TAB --------------------
with tab4:
    st.header("News Classification")
    st.write("Classify Setswana news articles into different categories.")

    # Model label -> display name shown in the results
    categories = {
        "arts_culture_entertainment_and_media": "Botsweretshi, setso, boitapoloso le bobegakgang",
        "crime_law_and_justice": "Bosenyi, molao le bosiamisi",
        "disaster_accident_and_emergency_incident": "Masetlapelo, kotsi le tiragalo ya maemo a tshoganyetso",
        "economy_business_and_finance": "Ikonomi, tsa kgwebo le tsa ditšhelete",
        "education": "Thuto",
        "environment": "Tikologo",
        "health": "Boitekanelo",
        "politics": "Dipolotiki",
        "religion_and_belief": "Bodumedi le tumelo",
        "society": "Setšhaba"
    }

    news_examples = [
        "Puso ya Botswana e solofeditse gore e tla oketsa dithuso tsa thuto mo dikolong tsa poraemari.",
        "Dipalo tsa bosenyi di oketsegile mo torong ya Gaborone ka pakeng tse di fetileng.",
        "Setšhaba sa Botswana se keteka matsalo a Rre le Mme ba ba ratanang thata.",
        "Boemelo jwa economy ya Botswana bo tsweletse sentle ka ngwaga ono."
    ]

    news_input = get_input_text("news", news_examples)

    # The button is always rendered; the text is only inspected after a click.
    classify_clicked = st.button("Classify News", key="news_button")
    if classify_clicked and news_input.strip():
        with st.spinner("Classifying news..."):
            try:
                classifier = load_news_classification_model()
                results = classifier(news_input)

                # Display name -> score rounded to four decimals; labels
                # missing from the mapping are shown unchanged.
                predictions = {
                    categories.get(pred['label'], pred['label']): round(pred['score'], 4)
                    for pred in results[0]
                }

                # Highest confidence first
                ranked = sorted(predictions.items(), key=lambda item: item[1], reverse=True)

                st.subheader("Classification Results")

                # Progress bars for the five best categories
                for category, confidence in ranked[:5]:
                    st.write(f"**{category}**")
                    st.progress(confidence)
                    st.write(f"Confidence: {confidence:.1%}")
                    st.write("")

                # Every category, in the same order, inside a collapsible table
                with st.expander("View All Categories"):
                    table_rows = [
                        {"Category": cat, "Confidence": conf}
                        for cat, conf in ranked
                    ]
                    results_df = pd.DataFrame(table_rows)
                    st.dataframe(results_df, use_container_width=True)

            except Exception as e:
                st.error(f"Error: {str(e)}")

# -------------------- FOOTER --------------------
# Horizontal rule, then the BibTeX citation and project links.
st.markdown("---")
footer_markdown = """
### 📚 Citation
```bibtex
@inproceedings{marivate2023puoberta,
  title   = {PuoBERTa: Training and evaluation of a curated language model for Setswana},
  author  = {Vukosi Marivate and Moseli Mots'Oehli and Valencia Wagner and Richard Lastrucci and Isheanesu Dzingirai},
  year    = {2023},
  booktitle= {Artificial Intelligence Research. SACAIR 2023. Communications in Computer and Information Science},
  url= {https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17},
  keywords = {NLP},
  preprint_url = {https://arxiv.org/abs/2310.09141},
  dataset_url = {https://github.com/dsfsi/PuoBERTa},
  software_url = {https://huggingface.co/dsfsi/PuoBERTa}
}
```

**Links**: [Paper](https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17) | [Preprint/Arxiv](https://arxiv.org/abs/2310.09141) | [GitHub](https://github.com/dsfsi/PuoBERTa) | [HuggingFace](https://huggingface.co/dsfsi/PuoBERTa)
"""
st.markdown(footer_markdown)