# GliLem demo (demo.py) — Gradio app combining Vabamorf lemma candidates
# with a GLiNER-based transformation-rule disambiguator for Estonian.
from gliner import GLiNER
import gradio as gr
import nltk
from rule_processor import RuleProcessor
from vabamorf_lemmatizer import Lemmatizer
from utils import sentence_to_spans
# Fetch the Punkt tokenizer data that nltk-based helpers rely on.
nltk.download("punkt_tab")

# Sample Estonian sentences surfaced in the Gradio Examples widget.
examples = [
    "4. koha tõenäsus on täpselt 0, seda sõltumata lisakoha tulekust või mittetulekust.",
    "WordPressi puhul tasub see sokutada oma kujundusteema kataloogi ning kui lisada functions.php-sse järgmised kaks rida peakski kõik toimima:",
]

# Converts (form, lemma) pairs into edit-script transformation rules and applies them back.
rule_processor = RuleProcessor()
# GLiNER model fine-tuned to select the correct transformation rule for each token span.
model = GLiNER.from_pretrained("tartuNLP/glilem-vabamorf-disambiguator")
# Vabamorf analyzer producing all lemma candidates per token; disambiguate=False
# because disambiguation is delegated to the GLiNER model above.
lemmatizer = Lemmatizer(
    disambiguate=False, use_context=False, proper_name=True, separate_punctuation=True
)
def process_text(text):
    """Lemmatize Estonian ``text`` with Vabamorf, disambiguated by GLiNER.

    Returns a pair shaped for two ``gr.HighlightedText`` components:

    - a dict with the whitespace-joined token text and the predicted
      transformation-rule entities (label, span, score),
    - a list of ``(lemma, changed)`` tuples where ``changed`` flags tokens
      whose predicted lemma differs from the surface form.
    """
    lemmas, tokens = lemmatizer(text, return_tokens=True)
    # Deduplicate lemma candidates per token.
    lemmas = [list(set(el)) for el in lemmas]
    tokens = [el[0] for el in tokens]
    # Serves as input for GLiNER to remain consistent with Vabamorf tokenization.
    processed_text = " ".join(tokens)
    # Maps "start-end" character spans of processed_text to token indices.
    span_to_token_id = sentence_to_spans(tokens)
    # Produce a transformation rule for each lemma candidate; the unique rules
    # form the label inventory GLiNER chooses from.
    labels = list(
        {
            rule_processor.gen_lemma_rule(form=token, lemma=lemma, allow_copy=True)
            for token, lemma_list in zip(tokens, lemmas)
            for lemma in lemma_list
        }
    )
    predicted_entities = model.predict_entities(
        text=processed_text, labels=labels, flat_ner=True, threshold=0.5
    )
    # Default every prediction to the surface form; overwrite where a span matched.
    predictions = tokens.copy()
    for entity in predicted_entities:
        span_key = f"{entity['start']}-{entity['end']}"
        # Skip predicted spans that do not align with a Vabamorf token boundary.
        if span_key not in span_to_token_id:
            continue
        token_id = span_to_token_id[span_key]
        token = tokens[token_id]
        if len(lemmas[token_id]) > 1:
            # Multiple lemma candidates: apply the highest-scoring predicted rule.
            predictions[token_id] = rule_processor.apply_lemma_rule(
                token, entity["label"]
            )
        else:
            # Single candidate: trust the Vabamorf lemma directly.
            predictions[token_id] = lemmas[token_id][0]
    # Flags used to highlight word forms that changed during lemmatization.
    lemma_labels = [pred != token for pred, token in zip(predictions, tokens)]
    # Expected input format for the HighlightedText component.
    processed_entities = {
        "text": processed_text,
        "entities": [
            {
                "entity": entity["label"],
                "word": entity["text"],
                "start": entity["start"],
                "end": entity["end"],
                "score": entity["score"],
            }
            for entity in predicted_entities
        ],
    }
    processed_lemmas = list(zip(predictions, lemma_labels))
    return processed_entities, processed_lemmas
if __name__ == "__main__":
    # Wire the originally-unused Base theme into the Blocks layout.
    with gr.Blocks(theme=gr.themes.Base()) as demo:
        input_text = gr.Textbox(
            label="Text input", placeholder="Enter your text in Estonian here"
        )
        label_output = gr.HighlightedText(label="Predicted Transformation Rules")
        lemma_output = gr.HighlightedText(label="Predicted Lemmas")
        submit_btn = gr.Button("Submit")
        # Run the same pipeline from both Enter in the textbox and the button,
        # without duplicating the wiring.
        for trigger in (input_text.submit, submit_btn.click):
            trigger(
                fn=process_text,
                inputs=input_text,
                outputs=[label_output, lemma_output],
            )
        # Note: not assigned to a name, so the module-level `examples` list
        # is no longer shadowed inside this block.
        gr.Examples(
            examples,
            fn=process_text,
            inputs=input_text,
            outputs=[label_output, lemma_output],
            cache_examples=False,
        )
    demo.launch()