vukosi committed on
Commit
295300a
·
verified ·
1 Parent(s): 2157d2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -123
app.py CHANGED
@@ -1,136 +1,84 @@
1
- # Turkish NER Demo for Various Models
2
-
3
- from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer, DebertaV2Tokenizer, DebertaV2Model
4
- import sentencepiece
5
  import streamlit as st
 
6
  import pandas as pd
7
  import spacy
8
 
 
9
  st.set_page_config(layout="wide")
10
 
11
- example_list = [
12
- "Moso ono mo dikgang tsa ura le ura, o tsoga le Oarabile Moamogwe go simolola ka 05:00 - 10:00"
13
- ]
14
-
15
- #logo
16
- img1, img2, img3 = st.columns(3)
17
- with img2:
18
- with st.container(border=False):
19
- st.image("logo_transparent_small.png")
20
-
21
-
22
- st.title("Demo for Sestwana NER Models")
23
 
24
- st.write("A Setswana Langage Model Finetuned on MasakhaNER-2 for Named Entity Recognition")
25
- st.write("Co authors : Vukosi Marivate (@vukosi), Moseli Mots'Oehli (@MoseliMotsoehli) , Valencia Wagner, Richard Lastrucci and Isheanesu Dzingirai")
26
- st.write("Link to model: https://arxiv.org/abs/2310.09141")
27
 
 
28
  model_list = ['dsfsi/PuoBERTa-NER']
29
-
30
- st.sidebar.header("Select NER Model")
31
- model_checkpoint = st.sidebar.radio("", model_list)
32
-
33
-
34
- if model_checkpoint == "akdeniz27/xlm-roberta-base-turkish-ner":
35
- aggregation = "simple"
36
- elif model_checkpoint == "dsfsi/PuoBERTa-NER":
37
- aggregation = "simple"
38
- elif model_checkpoint == "xlm-roberta-large-finetuned-conll03-english" or model_checkpoint == "asahi417/tner-xlm-roberta-base-ontonotes5":
39
- aggregation = "simple"
40
- st.sidebar.write("")
41
- st.sidebar.write("The selected NER model is included just to show the zero-shot transfer learning capability of XLM-Roberta pretrained language model.")
42
- else:
43
- aggregation = "first"
44
-
45
- st.subheader("Select Text Input Method")
46
- input_method = st.radio("", ('Select from Examples', 'Write or Paste New Text','Upload CSV File'))
47
- if input_method == 'Select from Examples':
48
- selected_text = st.selectbox('Select Text from List', example_list, index=0, key=1)
49
- st.subheader("Text to Run")
50
- input_text = st.text_area("Selected Text", selected_text, height=128, max_chars=None, key=2)
51
- elif input_method == "Write or Paste New Text":
52
- st.subheader("Text to Run")
53
- input_text = st.text_area('Write or Paste Text Below', value="", height=128, max_chars=None, key=2)
54
- elif input_method == "Upload CSV File":
55
- st.subheader("Upload CSV File")
56
- uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
57
-
58
- if uploaded_file is not None:
59
- df_csv = pd.read_csv(uploaded_file)
60
- st.write(df_csv)
61
- sentences = []
62
- for index, row in df_csv.iterrows():
63
- for col in df_csv.columns:
64
- # Add each sentence from the row and columns into the list
65
- sentence = row[col]
66
- if pd.notna(sentence): # Ensure it is not empty or NaN
67
- sentences.append(sentence)
68
-
69
- text_column = st.selectbox("Select the column containing text", sentences)
70
- input_text = text_column
71
-
72
-
73
- @st.cache_resource
74
- def setModel(model_checkpoint, aggregation):
75
- tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-NER")
76
- model = AutoModelForTokenClassification.from_pretrained("dsfsi/PuoBERTa-NER")
77
- return pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy=aggregation)
78
 
79
  @st.cache_resource
80
- def get_html(html: str):
81
- WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""
82
- html = html.replace("\n", " ")
83
- return WRAPPER.format(html)
84
-
85
- @st.cache_resource
86
- def entity_comb(output):
87
- output_comb = []
88
- for ind, entity in enumerate(output):
89
- if ind == 0:
90
- output_comb.append(entity)
91
- elif output[ind]["start"] == output[ind-1]["end"] and output[ind]["entity_group"] == output[ind-1]["entity_group"]:
92
- output_comb[-1]["word"] = output_comb[-1]["word"] + output[ind]["word"]
93
- output_comb[-1]["end"] = output[ind]["end"]
94
  else:
95
- output_comb.append(entity)
96
- return output_comb
97
-
98
- Run_Button = st.button("Run", key=None)
99
-
100
- if Run_Button and input_text != "":
101
-
102
- ner_pipeline = setModel(model_checkpoint, aggregation)
103
- output = ner_pipeline(input_text)
104
-
105
- output_comb = entity_comb(output)
106
-
107
- df = pd.DataFrame.from_dict(output_comb)
108
- cols_to_keep = ['word','entity_group','score','start','end']
109
- df_final = df[cols_to_keep]
110
-
111
- st.subheader("Recognized Entities")
112
- st.dataframe(df_final)
113
-
114
- st.subheader("Spacy Style Display")
115
- spacy_display = {}
116
- spacy_display["ents"] = []
117
- spacy_display["text"] = input_text
118
- spacy_display["title"] = None
119
-
120
- for entity in output_comb:
121
- spacy_display["ents"].append({"start": entity["start"], "end": entity["end"], "label": entity["entity_group"]})
122
-
123
- tner_entity_list = ["person", "group", "facility", "organization", "geopolitical area", "location", "product", "event", "work of art", "law", "language", "date", "time", "percent", "money", "quantity", "ordinal number", "cardinal number"]
124
- spacy_entity_list = ["PERSON", "NORP", "FAC", "ORG", "GPE", "LOC", "PRODUCT", "EVENT", "WORK_OF_ART", "LAW", "LANGUAGE", "DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL", "MISC"]
125
-
126
- for ent in spacy_display["ents"]:
127
- if model_checkpoint == "asahi417/tner-xlm-roberta-base-ontonotes5":
128
- ent["label"] = spacy_entity_list[tner_entity_list.index(ent["label"])]
129
  else:
130
- if ent["label"] == "PER": ent["label"] = "PERSON"
131
-
132
- # colors = {'PER': '#85DCDF', 'LOC': '#DF85DC', 'ORG': '#DCDF85', 'MISC': '#85ABDF',}
133
- html = spacy.displacy.render(spacy_display, style="ent", minify=True, manual=True, options={"ents": spacy_entity_list}) # , "colors": colors})
134
- style = "<style>mark.entity { display: inline-block }</style>"
135
- st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)
136
-
 
 
 
 
 
1
import streamlit as st
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
import pandas as pd
import spacy

# -------------------- PAGE CONFIG --------------------
st.set_page_config(layout="wide")

# -------------------- UI HEADER --------------------
# NOTE(review): `use_column_width` is deprecated in newer Streamlit releases in
# favour of `use_container_width` — confirm the Streamlit version pinned for
# this Space before switching.
st.image("logo_transparent_small.png", use_column_width="always")
st.title("Demo for Setswana NER Models")
st.markdown("""
A Setswana Language Model fine-tuned on MasakhaNER-2 for Named Entity Recognition.

**Co-authors**: Vukosi Marivate (@vukosi), Moseli Mots'Oehli (@MoseliMotsoehli), Valencia Wagner, Richard Lastrucci, and Isheanesu Dzingirai
**Model link**: [arXiv:2310.09141](https://arxiv.org/abs/2310.09141)
""")

# -------------------- MODEL SELECTION --------------------
# Only one checkpoint is offered today; the sidebar radio keeps the UI ready
# for additional models to be appended to this list.
model_list = ['dsfsi/PuoBERTa-NER']
model_checkpoint = st.sidebar.radio("Select NER Model", model_list)
# Aggregation strategy passed to the token-classification pipeline (groups
# sub-word tokens into whole-entity spans).
aggregation_strategy = "simple"

# -------------------- TEXT INPUT --------------------
input_method = st.radio("Select Input Method", ['Example Text', 'Write Text', 'Upload CSV'])
27
def get_input_text():
    """Collect the text to analyse, according to the selected input method.

    Returns an empty string when nothing usable has been provided yet
    (e.g. no CSV uploaded), so callers can simply test truthiness.
    """
    if input_method == 'Example Text':
        example_sentences = [
            "Moso ono mo dikgang tsa ura le ura, o tsoga le Oarabile Moamogwe go simolola ka 05:00 - 10:00"
        ]
        return st.selectbox("Example Sentences", example_sentences)

    if input_method == 'Write Text':
        return st.text_area("Enter text", height=128)

    if input_method == 'Upload CSV':
        uploaded = st.file_uploader("Upload CSV", type="csv")
        if uploaded:
            frame = pd.read_csv(uploaded)
            chosen_col = st.selectbox("Choose column with text", frame.columns)
            # One sentence per non-empty cell, joined into a single document.
            return "\n".join(frame[chosen_col].dropna().astype(str).tolist())

    return ""

input_text = get_input_text()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
@st.cache_resource
def load_ner_pipeline(model_checkpoint, strategy):
    """Build — and cache across reruns — a token-classification pipeline
    for the given checkpoint with the given sub-word aggregation strategy."""
    tok = AutoTokenizer.from_pretrained(model_checkpoint)
    mdl = AutoModelForTokenClassification.from_pretrained(model_checkpoint)
    return pipeline(
        "token-classification",
        model=mdl,
        tokenizer=tok,
        aggregation_strategy=strategy,
    )
50
+
51
def merge_entities(output):
    """Fuse consecutive pipeline entities into single spans.

    Two neighbouring entities are merged when the second starts at exactly
    the character offset where the first ended AND both share the same
    entity_group; the surviving entry keeps the first entity's score.
    Note: entries in the returned list are the same dicts as in `output`
    (merging mutates them in place).
    """
    combined = []
    previous = None
    for entity in output:
        adjacent = previous is not None and entity["start"] == previous["end"]
        same_group = previous is not None and entity["entity_group"] == previous["entity_group"]
        if adjacent and same_group:
            combined[-1]["word"] += entity["word"]
            combined[-1]["end"] = entity["end"]
        else:
            combined.append(entity)
        previous = entity
    return combined
60
+
61
+ if st.button("Run NER") and input_text.strip():
62
+ with st.spinner("Running NER..."):
63
+ ner = load_ner_pipeline(model_checkpoint, aggregation_strategy)
64
+ output = ner(input_text)
65
+ entities = merge_entities(output)
66
+
67
+ if entities:
68
+ df = pd.DataFrame(entities)[['word','entity_group','score','start','end']]
69
+ st.subheader("Recognized Entities")
70
+ st.dataframe(df)
71
+
72
+ # -------------------- SPACY STYLE VISUAL --------------------
73
+ spacy_display = {"text": input_text, "ents": [], "title": None}
74
+ for ent in entities:
75
+ label = ent["entity_group"]
76
+ if label == "PER":
77
+ label = "PERSON"
78
+ spacy_display["ents"].append({"start": ent["start"], "end": ent["end"], "label": label})
79
+
80
+ html = spacy.displacy.render(spacy_display, style="ent", manual=True, minify=True)
81
+ styled_html = f"<style>mark.entity {{ display: inline-block; }}</style><div style='overflow-x:auto;'>{html}</div>"
82
+ st.markdown(styled_html, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
83
  else:
84
+ st.info("No entities recognized in the input.")