vukosi commited on
Commit
8b740b2
·
1 Parent(s): 4948d8f

Updaed to multi demo

Browse files
Files changed (2) hide show
  1. app.py +318 -66
  2. logo_transparent_small.png +0 -0
app.py CHANGED
@@ -1,52 +1,84 @@
1
- # Refactored Streamlit App for Setswana NER using HuggingFace Models
2
-
3
  import streamlit as st
4
- from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
 
 
 
 
 
 
5
  import pandas as pd
6
  import spacy
 
 
7
 
8
  # -------------------- PAGE CONFIG --------------------
9
- st.set_page_config(layout="wide")
 
 
 
 
10
 
11
- # -------------------- UI HEADER --------------------
12
- st.image("logo_transparent_small.png", use_column_width="always")
13
- st.title("Demo for Setswana PuoBERTa NER Model")
 
 
 
 
14
 
15
- # -------------------- MODEL SELECTION --------------------
16
- model_list = ['dsfsi/PuoBERTa-NER']
17
- model_checkpoint = st.sidebar.radio("Select NER Model", model_list)
18
- aggregation_strategy = "simple"
 
 
 
 
19
 
20
- # -------------------- TEXT INPUT --------------------
21
- input_method = st.radio("Select Input Method", ['Example Text', 'Write Text', 'Upload CSV'])
22
 
23
- def get_input_text():
24
- if input_method == 'Example Text':
25
- examples = [
26
- "Moso ono mo dikgang tsa ura le ura, o tsoga le Oarabile Moamogwe go simolola ka 05:00 - 10:00"
27
- ]
28
- return st.selectbox("Example Sentences", examples)
29
- elif input_method == 'Write Text':
30
- return st.text_area("Enter text", height=128)
31
- elif input_method == 'Upload CSV':
32
- uploaded = st.file_uploader("Upload CSV", type="csv")
33
- if uploaded:
34
- df = pd.read_csv(uploaded)
35
- col = st.selectbox("Choose column with text", df.columns)
36
- return "\n".join(df[col].dropna().astype(str).tolist())
37
- return ""
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- input_text = get_input_text()
 
 
 
 
40
 
41
- # -------------------- MODEL LOADING --------------------
42
  @st.cache_resource
43
- def load_ner_pipeline(model_checkpoint, strategy):
44
- tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
45
- model = AutoModelForTokenClassification.from_pretrained(model_checkpoint)
46
- return pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy=strategy)
47
 
48
- # -------------------- ENTITY MERGE --------------------
49
  def merge_entities(output):
 
50
  merged = []
51
  for i, ent in enumerate(output):
52
  if i > 0 and ent["start"] == output[i-1]["end"] and ent["entity_group"] == output[i-1]["entity_group"]:
@@ -56,40 +88,256 @@ def merge_entities(output):
56
  merged.append(ent)
57
  return merged
58
 
59
- # -------------------- RUN NER --------------------
60
- if st.button("Run NER") and input_text.strip():
61
- with st.spinner("Running NER..."):
62
- ner = load_ner_pipeline(model_checkpoint, aggregation_strategy)
63
- output = ner(input_text)
64
- entities = merge_entities(output)
65
-
66
- if entities:
67
- df = pd.DataFrame(entities)[['word','entity_group','score','start','end']]
68
- st.subheader("Recognized Entities")
69
- st.dataframe(df)
70
-
71
- spacy_display = {"text": input_text, "ents": [], "title": None}
72
- for ent in entities:
73
- label = ent["entity_group"]
74
- if label == "PER":
75
- label = "PERSON"
76
- spacy_display["ents"].append({"start": ent["start"], "end": ent["end"], "label": label})
77
-
78
- html = spacy.displacy.render(spacy_display, style="ent", manual=True, minify=True)
79
- styled_html = f"<style>mark.entity {{ display: inline-block; }}</style><div style='overflow-x:auto;'>{html}</div>"
80
- st.markdown(styled_html, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  else:
82
- st.info("No entities recognized in the input.")
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- # -------------------- AUTHORS, CITATION & FEEDBACK --------------------
85
- st.markdown("""
86
- ---
87
- ### 📚 Authors & Citation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- **Authors**
90
- Vukosi Marivate, Moseli Mots'Oehli, Valencia Wagner, Richard Lastrucci, Isheanesu Dzingirai
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- **Citation**
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  ```bibtex
94
  @inproceedings{marivate2023puoberta,
95
  title = {PuoBERTa: Training and evaluation of a curated language model for Setswana},
@@ -101,4 +349,8 @@ Vukosi Marivate, Moseli Mots'Oehli, Valencia Wagner, Richard Lastrucci, Isheanes
101
  preprint_url = {https://arxiv.org/abs/2310.09141},
102
  dataset_url = {https://github.com/dsfsi/PuoBERTa},
103
  software_url = {https://huggingface.co/dsfsi/PuoBERTa}
104
- }""")
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import (
3
+ pipeline,
4
+ AutoModelForTokenClassification,
5
+ AutoTokenizer,
6
+ AutoModelForSequenceClassification,
7
+ AutoModelForMaskedLM
8
+ )
9
  import pandas as pd
10
  import spacy
11
+ import csv
12
+ from io import StringIO
13
 
14
  # -------------------- PAGE CONFIG --------------------
15
+ st.set_page_config(
16
+ page_title="PuoBERTa Multi-Task Demo",
17
+ page_icon="🔤",
18
+ layout="wide"
19
+ )
20
 
21
+ # -------------------- HEADER --------------------
22
+ col1, col2, col3 = st.columns([1, 2, 1])
23
+ with col2:
24
+ try:
25
+ st.image("logo_transparent_small.png", width=300)
26
+ except:
27
+ st.write("🔤 PuoBERTa")
28
 
29
+ st.title("PuoBERTa Multi-Task Demo")
30
+ st.markdown("""
31
+ A comprehensive demo for Setswana language models including:
32
+ - **Mask Filling**: Fill in missing words in sentences
33
+ - **POS Tagging**: Identify parts of speech
34
+ - **Named Entity Recognition**: Extract entities like people, places, organizations
35
+ - **News Classification**: Classify news articles by category
36
+ """)
37
 
38
+ st.markdown("---")
 
39
 
40
+ # -------------------- SIDEBAR --------------------
41
+ st.sidebar.header("Model Information")
42
+ st.sidebar.markdown("""
43
+ **Authors**: Vukosi Marivate, Moseli Mots'Oehli, Valencia Wagner, Richard Lastrucci, Isheanesu Dzingirai
44
+
45
+ **Paper**: [PuoBERTa: Training and evaluation of a curated language model for Setswana](https://arxiv.org/abs/2310.09141)
46
+
47
+ **Models Used**:
48
+ - dsfsi/PuoBERTa (Mask Filling)
49
+ - dsfsi/PuoBERTa-POS (POS Tagging)
50
+ - dsfsi/PuoBERTa-NER (Named Entity Recognition)
51
+ - dsfsi/PuoBERTa-News (News Classification)
52
+ """)
53
+
54
+ # -------------------- CACHING FUNCTIONS --------------------
55
+ @st.cache_resource
56
+ def load_mask_filling_model():
57
+ tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa")
58
+ model = AutoModelForMaskedLM.from_pretrained("dsfsi/PuoBERTa")
59
+ return pipeline("fill-mask", model=model, tokenizer=tokenizer, top_k=5)
60
+
61
+ @st.cache_resource
62
+ def load_pos_model():
63
+ tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-POS")
64
+ model = AutoModelForTokenClassification.from_pretrained("dsfsi/PuoBERTa-POS")
65
+ return pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
66
 
67
+ @st.cache_resource
68
+ def load_ner_model():
69
+ tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-NER")
70
+ model = AutoModelForTokenClassification.from_pretrained("dsfsi/PuoBERTa-NER")
71
+ return pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
72
 
 
73
  @st.cache_resource
74
+ def load_news_classification_model():
75
+ tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-News")
76
+ model = AutoModelForSequenceClassification.from_pretrained("dsfsi/PuoBERTa-News")
77
+ return pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
78
 
79
+ # -------------------- UTILITY FUNCTIONS --------------------
80
  def merge_entities(output):
81
+ """Merge consecutive entities of the same type"""
82
  merged = []
83
  for i, ent in enumerate(output):
84
  if i > 0 and ent["start"] == output[i-1]["end"] and ent["entity_group"] == output[i-1]["entity_group"]:
 
88
  merged.append(ent)
89
  return merged
90
 
91
+ def create_spacy_display(text, entities, task_type="ner"):
92
+ """Create spaCy-style display for entities"""
93
+ spacy_display = {"text": text, "ents": [], "title": None}
94
+
95
+ for ent in entities:
96
+ label = ent["entity_group"]
97
+ if task_type == "ner" and label == "PER":
98
+ label = "PERSON"
99
+ spacy_display["ents"].append({
100
+ "start": ent["start"],
101
+ "end": ent["end"],
102
+ "label": label
103
+ })
104
+
105
+ # Define colors for different entity types
106
+ colors = {
107
+ # POS colors
108
+ "PRON": "#FF9999",
109
+ "VERB": "#99FF99",
110
+ "DET": "#9999FF",
111
+ "PROPN": "#FFFF99",
112
+ "CCONJ": "#FFCC99",
113
+ "PUNCT": "#CCCCCC",
114
+ "NUM": "#FFCCFF",
115
+ "NOUN": "#FFB366",
116
+ "ADJ": "#B366FF",
117
+ "ADP": "#66FFB3",
118
+ # NER colors
119
+ "PERSON": "#85DCDF",
120
+ "PER": "#85DCDF",
121
+ "LOC": "#DF85DC",
122
+ "ORG": "#DCDF85",
123
+ "MISC": "#85ABDF"
124
+ }
125
+
126
+ try:
127
+ html = spacy.displacy.render(spacy_display, style="ent", manual=True, minify=True,
128
+ options={"colors": colors})
129
+ styled_html = f"""
130
+ <style>mark.entity {{ display: inline-block; }}</style>
131
+ <div style='overflow-x:auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;'>
132
+ {html}
133
+ </div>
134
+ """
135
+ return styled_html
136
+ except:
137
+ return "<p>Error rendering visualization</p>"
138
+
139
+ def get_input_text(tab_name, examples):
140
+ """Get input text based on selected method"""
141
+ input_method = st.radio(
142
+ "Select Input Method",
143
+ ['Example Text', 'Write Text', 'Upload File'],
144
+ key=f"{tab_name}_input_method"
145
+ )
146
+
147
+ if input_method == 'Example Text':
148
+ return st.selectbox("Example Sentences", examples, key=f"{tab_name}_examples")
149
+ elif input_method == 'Write Text':
150
+ return st.text_area("Enter text", height=100, key=f"{tab_name}_text_input")
151
+ elif input_method == 'Upload File':
152
+ uploaded = st.file_uploader("Upload text or CSV file", type=["txt", "csv"], key=f"{tab_name}_file")
153
+ if uploaded:
154
+ if uploaded.name.endswith('.csv'):
155
+ df = pd.read_csv(uploaded)
156
+ st.write("CSV Preview:", df.head())
157
+ col = st.selectbox("Choose column with text", df.columns, key=f"{tab_name}_csv_col")
158
+ return "\n".join(df[col].dropna().astype(str).tolist())
159
+ else:
160
+ return str(uploaded.read(), "utf-8")
161
+ return ""
162
+
163
+ # -------------------- TABS --------------------
164
+ tab1, tab2, tab3, tab4 = st.tabs(["🎭 Mask Filling", "🏷️ POS Tagging", "🔍 Named Entity Recognition", "📰 News Classification"])
165
+
166
+ # -------------------- MASK FILLING TAB --------------------
167
+ with tab1:
168
+ st.header("Mask Filling")
169
+ st.write("Fill in the blanks in Setswana sentences using `[MASK]` token.")
170
+
171
+ mask_examples = [
172
+ "Ke rata go [MASK] dijo tsa Batswana.",
173
+ "Botswana ke naga e e [MASK] mo Afrika Borwa.",
174
+ "Bana ba [MASK] sekolo ka Mosupologo.",
175
+ "Re tshwanetse go [MASK] tikologo ya rona."
176
+ ]
177
+
178
+ mask_input = get_input_text("mask", mask_examples)
179
+
180
+ if st.button("Fill Masks", key="mask_button") and mask_input.strip():
181
+ if "[MASK]" not in mask_input:
182
+ st.warning("Please include [MASK] token in your text.")
183
  else:
184
+ with st.spinner("Filling masks..."):
185
+ try:
186
+ mask_filler = load_mask_filling_model()
187
+ results = mask_filler(mask_input)
188
+
189
+ st.subheader("Predictions")
190
+ for i, result in enumerate(results, 1):
191
+ confidence = result['score'] * 100
192
+ st.write(f"**{i}.** {result['sequence']} (confidence: {confidence:.1f}%)")
193
+
194
+ except Exception as e:
195
+ st.error(f"Error: {str(e)}")
196
 
197
+ # -------------------- POS TAGGING TAB --------------------
198
+ with tab2:
199
+ st.header("Parts of Speech Tagging")
200
+ st.write("Identify grammatical parts of speech in Setswana text.")
201
+
202
+ pos_examples = [
203
+ "Moso ono mo dikgang tsa ura le ura, o tsoga le Oarabile Moamogwe go simolola ka 05:00 - 10:00",
204
+ "Batho ba le bantsi ba rata go bala dikgang tsa Setswana.",
205
+ "Ke ithutile Setswana kwa sekolong sa me.",
206
+ "Dikgomo di ja bojang mo tshimong."
207
+ ]
208
+
209
+ pos_input = get_input_text("pos", pos_examples)
210
+
211
+ if st.button("Run POS Tagging", key="pos_button") and pos_input.strip():
212
+ with st.spinner("Running POS tagging..."):
213
+ try:
214
+ pos_tagger = load_pos_model()
215
+ output = pos_tagger(pos_input)
216
+ entities = merge_entities(output)
217
+
218
+ if entities:
219
+ # Display results table
220
+ df = pd.DataFrame(entities)[['word', 'entity_group', 'score', 'start', 'end']]
221
+ df['score'] = df['score'].round(4)
222
+ st.subheader("POS Tags")
223
+ st.dataframe(df, use_container_width=True)
224
+
225
+ # Visual display
226
+ st.subheader("Visual Display")
227
+ html = create_spacy_display(pos_input, entities, "pos")
228
+ st.markdown(html, unsafe_allow_html=True)
229
+ else:
230
+ st.info("No POS tags identified.")
231
+
232
+ except Exception as e:
233
+ st.error(f"Error: {str(e)}")
234
 
235
+ # -------------------- NER TAB --------------------
236
+ with tab3:
237
+ st.header("Named Entity Recognition")
238
+ st.write("Extract named entities like people, places, and organizations from Setswana text.")
239
+
240
+ ner_examples = [
241
+ "Oarabile Moamogwe o tswa Gaborone mme o bereka kwa University of Botswana.",
242
+ "Motswana yo o tumileng Mpho Balopi o ne a kopana le Rre Khama kwa Presidential Palace.",
243
+ "Botswana Democratic Party e ne ya kopana le African National Congress.",
244
+ "Bank of Botswana e mo Gaborone e laola economy ya naga."
245
+ ]
246
+
247
+ ner_input = get_input_text("ner", ner_examples)
248
+
249
+ if st.button("Run NER", key="ner_button") and ner_input.strip():
250
+ with st.spinner("Running NER..."):
251
+ try:
252
+ ner_pipeline = load_ner_model()
253
+ output = ner_pipeline(ner_input)
254
+ entities = merge_entities(output)
255
+
256
+ if entities:
257
+ # Display results table
258
+ df = pd.DataFrame(entities)[['word', 'entity_group', 'score', 'start', 'end']]
259
+ df['score'] = df['score'].round(4)
260
+ st.subheader("Named Entities")
261
+ st.dataframe(df, use_container_width=True)
262
+
263
+ # Visual display
264
+ st.subheader("Visual Display")
265
+ html = create_spacy_display(ner_input, entities, "ner")
266
+ st.markdown(html, unsafe_allow_html=True)
267
+ else:
268
+ st.info("No named entities found.")
269
+
270
+ except Exception as e:
271
+ st.error(f"Error: {str(e)}")
272
 
273
+ # -------------------- NEWS CLASSIFICATION TAB --------------------
274
+ with tab4:
275
+ st.header("News Classification")
276
+ st.write("Classify Setswana news articles into different categories.")
277
+
278
+ # Category mapping
279
+ categories = {
280
+ "arts_culture_entertainment_and_media": "Botsweretshi, setso, boitapoloso le bobegakgang",
281
+ "crime_law_and_justice": "Bosenyi, molao le bosiamisi",
282
+ "disaster_accident_and_emergency_incident": "Masetlapelo, kotsi le tiragalo ya maemo a tshoganyetso",
283
+ "economy_business_and_finance": "Ikonomi, tsa kgwebo le tsa ditšhelete",
284
+ "education": "Thuto",
285
+ "environment": "Tikologo",
286
+ "health": "Boitekanelo",
287
+ "politics": "Dipolotiki",
288
+ "religion_and_belief": "Bodumedi le tumelo",
289
+ "society": "Setšhaba"
290
+ }
291
+
292
+ news_examples = [
293
+ "Puso ya Botswana e solofeditse gore e tla oketsa dithuso tsa thuto mo dikolong tsa poraemari.",
294
+ "Dipalo tsa bosenyi di oketsegile mo torong ya Gaborone ka pakeng tse di fetileng.",
295
+ "Setšhaba sa Botswana se keteka matsalo a Rre le Mme ba ba ratanang thata.",
296
+ "Boemelo jwa economy ya Botswana bo tsweletse sentle ka ngwaga ono."
297
+ ]
298
+
299
+ news_input = get_input_text("news", news_examples)
300
+
301
+ if st.button("Classify News", key="news_button") and news_input.strip():
302
+ with st.spinner("Classifying news..."):
303
+ try:
304
+ classifier = load_news_classification_model()
305
+ results = classifier(news_input)
306
+
307
+ # Process results
308
+ predictions = {}
309
+ for pred in results[0]:
310
+ category_en = pred['label']
311
+ category_tn = categories.get(category_en, category_en)
312
+ predictions[category_tn] = round(pred['score'], 4)
313
+
314
+ # Sort by confidence
315
+ sorted_predictions = dict(sorted(predictions.items(), key=lambda x: x[1], reverse=True))
316
+
317
+ st.subheader("Classification Results")
318
+
319
+ # Display as progress bars
320
+ for category, confidence in list(sorted_predictions.items())[:5]:
321
+ st.write(f"**{category}**")
322
+ st.progress(confidence)
323
+ st.write(f"Confidence: {confidence:.1%}")
324
+ st.write("")
325
+
326
+ # Display full results table
327
+ with st.expander("View All Categories"):
328
+ results_df = pd.DataFrame([
329
+ {"Category": cat, "Confidence": conf}
330
+ for cat, conf in sorted_predictions.items()
331
+ ])
332
+ st.dataframe(results_df, use_container_width=True)
333
+
334
+ except Exception as e:
335
+ st.error(f"Error: {str(e)}")
336
+
337
+ # -------------------- FOOTER --------------------
338
+ st.markdown("---")
339
+ st.markdown("""
340
+ ### 📚 Citation
341
  ```bibtex
342
  @inproceedings{marivate2023puoberta,
343
  title = {PuoBERTa: Training and evaluation of a curated language model for Setswana},
 
349
  preprint_url = {https://arxiv.org/abs/2310.09141},
350
  dataset_url = {https://github.com/dsfsi/PuoBERTa},
351
  software_url = {https://huggingface.co/dsfsi/PuoBERTa}
352
+ }
353
+ ```
354
+
355
+ **Links**: [Paper](https://arxiv.org/abs/2310.09141) | [GitHub](https://github.com/dsfsi/PuoBERTa) | [HuggingFace](https://huggingface.co/dsfsi/PuoBERTa)
356
+ """)
logo_transparent_small.png CHANGED