Update app.py
app.py
CHANGED
@@ -1,136 +1,84 @@
-# Turkish NER Demo for Various Models
-
-from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer, DebertaV2Tokenizer, DebertaV2Model
-import sentencepiece
 import streamlit as st
 import pandas as pd
 import spacy
 
 st.set_page_config(layout="wide")
 
-
-
-
-
-
-img1, img2, img3 = st.columns(3)
-with img2:
-    with st.container(border=False):
-        st.image("logo_transparent_small.png")
-
-
-st.title("Demo for Sestwana NER Models")
 
-
-
-
 
 model_list = ['dsfsi/PuoBERTa-NER']
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    st.subheader("Text to Run")
-    input_text = st.text_area('Write or Paste Text Below', value="", height=128, max_chars=None, key=2)
-elif input_method == "Upload CSV File":
-    st.subheader("Upload CSV File")
-    uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
-
-    if uploaded_file is not None:
-        df_csv = pd.read_csv(uploaded_file)
-        st.write(df_csv)
-        sentences = []
-        for index, row in df_csv.iterrows():
-            for col in df_csv.columns:
-                # Add each sentence from the row and columns into the list
-                sentence = row[col]
-                if pd.notna(sentence): # Ensure it is not empty or NaN
-                    sentences.append(sentence)
-
-        text_column = st.selectbox("Select the column containing text", sentences)
-        input_text = text_column
-
-
-@st.cache_resource
-def setModel(model_checkpoint, aggregation):
-    tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-NER")
-    model = AutoModelForTokenClassification.from_pretrained("dsfsi/PuoBERTa-NER")
-    return pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy=aggregation)
 
 @st.cache_resource
-def
-
-
-    return
-
-
-
-
-
-
-
-        elif output[ind]["start"] == output[ind-1]["end"] and output[ind]["entity_group"] == output[ind-1]["entity_group"]:
-            output_comb[-1]["word"] = output_comb[-1]["word"] + output[ind]["word"]
-            output_comb[-1]["end"] = output[ind]["end"]
         else:
-
-    return
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    for entity in output_comb:
-        spacy_display["ents"].append({"start": entity["start"], "end": entity["end"], "label": entity["entity_group"]})
-
-    tner_entity_list = ["person", "group", "facility", "organization", "geopolitical area", "location", "product", "event", "work of art", "law", "language", "date", "time", "percent", "money", "quantity", "ordinal number", "cardinal number"]
-    spacy_entity_list = ["PERSON", "NORP", "FAC", "ORG", "GPE", "LOC", "PRODUCT", "EVENT", "WORK_OF_ART", "LAW", "LANGUAGE", "DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL", "MISC"]
-
-    for ent in spacy_display["ents"]:
-        if model_checkpoint == "asahi417/tner-xlm-roberta-base-ontonotes5":
-            ent["label"] = spacy_entity_list[tner_entity_list.index(ent["label"])]
         else:
-
-
-    # colors = {'PER': '#85DCDF', 'LOC': '#DF85DC', 'ORG': '#DCDF85', 'MISC': '#85ABDF',}
-    html = spacy.displacy.render(spacy_display, style="ent", minify=True, manual=True, options={"ents": spacy_entity_list}) # , "colors": colors})
-    style = "<style>mark.entity { display: inline-block }</style>"
-    st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)
-
 import streamlit as st
+from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
 import pandas as pd
 import spacy
 
+# -------------------- PAGE CONFIG --------------------
 st.set_page_config(layout="wide")
 
+# -------------------- UI HEADER --------------------
+st.image("logo_transparent_small.png", use_column_width="always")
+st.title("Demo for Setswana NER Models")
+st.markdown("""
+A Setswana Language Model fine-tuned on MasakhaNER-2 for Named Entity Recognition.
+
+**Co-authors**: Vukosi Marivate (@vukosi), Moseli Mots'Oehli (@MoseliMotsoehli), Valencia Wagner, Richard Lastrucci, and Isheanesu Dzingirai
+**Model link**: [arXiv:2310.09141](https://arxiv.org/abs/2310.09141)
+""")
 
+# -------------------- MODEL SELECTION --------------------
 model_list = ['dsfsi/PuoBERTa-NER']
+model_checkpoint = st.sidebar.radio("Select NER Model", model_list)
+aggregation_strategy = "simple"
+
+# -------------------- TEXT INPUT --------------------
+input_method = st.radio("Select Input Method", ['Example Text', 'Write Text', 'Upload CSV'])
+
+def get_input_text():
+    if input_method == 'Example Text':
+        examples = [
+            "Moso ono mo dikgang tsa ura le ura, o tsoga le Oarabile Moamogwe go simolola ka 05:00 - 10:00"
+        ]
+        return st.selectbox("Example Sentences", examples)
+    elif input_method == 'Write Text':
+        return st.text_area("Enter text", height=128)
+    elif input_method == 'Upload CSV':
+        uploaded = st.file_uploader("Upload CSV", type="csv")
+        if uploaded:
+            df = pd.read_csv(uploaded)
+            col = st.selectbox("Choose column with text", df.columns)
+            return "\n".join(df[col].dropna().astype(str).tolist())
+    return ""
+
+input_text = get_input_text()
 
 @st.cache_resource
+def load_ner_pipeline(model_checkpoint, strategy):
+    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
+    model = AutoModelForTokenClassification.from_pretrained(model_checkpoint)
+    return pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy=strategy)
+
+def merge_entities(output):
+    merged = []
+    for i, ent in enumerate(output):
+        if i > 0 and ent["start"] == output[i-1]["end"] and ent["entity_group"] == output[i-1]["entity_group"]:
+            merged[-1]["word"] += ent["word"]
+            merged[-1]["end"] = ent["end"]
         else:
+            merged.append(ent)
+    return merged
+
+if st.button("Run NER") and input_text.strip():
+    with st.spinner("Running NER..."):
+        ner = load_ner_pipeline(model_checkpoint, aggregation_strategy)
+        output = ner(input_text)
+        entities = merge_entities(output)
+
+        if entities:
+            df = pd.DataFrame(entities)[['word','entity_group','score','start','end']]
+            st.subheader("Recognized Entities")
+            st.dataframe(df)
+
+            # -------------------- SPACY STYLE VISUAL --------------------
+            spacy_display = {"text": input_text, "ents": [], "title": None}
+            for ent in entities:
+                label = ent["entity_group"]
+                if label == "PER":
+                    label = "PERSON"
+                spacy_display["ents"].append({"start": ent["start"], "end": ent["end"], "label": label})
+
+            html = spacy.displacy.render(spacy_display, style="ent", manual=True, minify=True)
+            styled_html = f"<style>mark.entity {{ display: inline-block; }}</style><div style='overflow-x:auto;'>{html}</div>"
+            st.markdown(styled_html, unsafe_allow_html=True)
         else:
+            st.info("No entities recognized in the input.")
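A note on the new `merge_entities` helper: even with `aggregation_strategy="simple"`, the transformers pipeline can return one name as several abutting fragments, so the helper glues together consecutive predictions whose character spans touch and whose `entity_group` matches. A minimal sketch of the behavior, using hand-written dicts in the pipeline's output shape (the words, offsets, and scores below are invented for illustration):

```python
def merge_entities(output):
    # Same logic as in app.py above: fold an entity into its predecessor
    # when the spans abut and the entity_group matches.
    merged = []
    for i, ent in enumerate(output):
        if i > 0 and ent["start"] == output[i - 1]["end"] and ent["entity_group"] == output[i - 1]["entity_group"]:
            merged[-1]["word"] += ent["word"]
            merged[-1]["end"] = ent["end"]
        else:
            merged.append(ent)
    return merged

# Invented pipeline-style output: two abutting PER fragments, then a separate LOC.
fragments = [
    {"entity_group": "PER", "score": 0.98, "word": "Oarabile", "start": 0, "end": 8},
    {"entity_group": "PER", "score": 0.97, "word": " Moamogwe", "start": 8, "end": 17},
    {"entity_group": "LOC", "score": 0.99, "word": "Gaborone", "start": 21, "end": 29},
]
print(merge_entities(fragments))
# -> one PER entity "Oarabile Moamogwe" (start=0, end=17), and the LOC unchanged
```

Note that the helper mutates the dicts it merges into; that is fine here because the raw pipeline output is not reused afterwards.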
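The visualization step uses displacy's manual mode, which accepts a plain dict of text plus character-offset spans instead of a spaCy `Doc`, so no spaCy model needs to be downloaded. A standalone sketch of the input shape the new code builds (the sentence and spans are invented):

```python
import spacy  # displacy ships with spaCy; manual mode needs no language model

doc = {
    "text": "Oarabile Moamogwe o tsoga ka 05:00",
    "ents": [
        {"start": 0, "end": 17, "label": "PERSON"},
        {"start": 29, "end": 34, "label": "TIME"},
    ],
    "title": None,
}

# manual=True skips all NLP and simply renders the given spans.
html = spacy.displacy.render(doc, style="ent", manual=True, minify=True)
print(html)  # self-contained HTML with <mark class="entity"> wrappers, as injected via st.markdown
```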
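For a quick sanity check of the checkpoint outside Streamlit, the same pipeline can be built directly from the model name; a sketch, assuming the `transformers` package is installed and the Hugging Face Hub is reachable:

```python
from transformers import pipeline

ner = pipeline(
    "token-classification",
    model="dsfsi/PuoBERTa-NER",
    aggregation_strategy="simple",
)
print(ner("Moso ono mo dikgang tsa ura le ura, o tsoga le Oarabile Moamogwe"))
# Each result dict carries: entity_group, score, word, start, end
```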