|
from gliner import GLiNER |
|
import gradio as gr |
|
import nltk |
|
|
|
from rule_processor import RuleProcessor |
|
from vabamorf_lemmatizer import Lemmatizer |
|
from utils import sentence_to_spans |
|
|
|
|
|
# Fetch NLTK's punkt_tab sentence-tokenizer data at startup — presumably
# required by the tokenization pipeline (Lemmatizer / sentence_to_spans);
# TODO confirm which downstream component consumes it.
nltk.download("punkt_tab")
|
|
|
# Example Estonian sentences shown as clickable demo inputs in the UI.
examples = [
    "4. koha tõenäsus on täpselt 0, seda sõltumata lisakoha tulekust või mittetulekust.",
    "WordPressi puhul tasub see sokutada oma kujundusteema kataloogi ning kui lisada functions.php-sse järgmised kaks rida peakski kõik toimima:",
]

# Converts (form, lemma) pairs into edit-script rules and applies them back.
rule_processor = RuleProcessor()
# GLiNER model fine-tuned to pick the correct lemma-transformation rule
# among Vabamorf's candidates (downloaded from the Hugging Face Hub).
model = GLiNER.from_pretrained("tartuNLP/glilem-vabamorf-disambiguator")
# Vabamorf-based lemmatizer configured to emit ALL candidate lemmas
# (disambiguate=False) so the GLiNER model can do the disambiguation.
lemmatizer = Lemmatizer(
    disambiguate=False, use_context=False, proper_name=True, separate_punctuation=True
)
|
|
|
|
|
def process_text(text):
    """Lemmatize Estonian *text* using GLiNER to disambiguate Vabamorf output.

    Pipeline:
      1. Produce candidate lemmas per token with the Vabamorf lemmatizer.
      2. Encode each (token, lemma) pair as an edit-script label.
      3. Let the GLiNER model select the best label per token span.
      4. Apply the selected rule only where Vabamorf was ambiguous; otherwise
         keep the single Vabamorf candidate.

    Args:
        text: Raw input string in Estonian.

    Returns:
        tuple:
            - dict with ``text`` and ``entities`` keys, shaped for
              ``gr.HighlightedText`` (predicted transformation rules);
            - list of ``(lemma, changed)`` pairs, where ``changed`` is True
              when the predicted lemma differs from the surface token.
    """
    lemmas, tokens = lemmatizer(text, return_tokens=True)
    # Deduplicate candidate lemmas per token; keep only the surface form of
    # each token tuple.
    lemmas = [list(set(el)) for el in lemmas]
    tokens = [el[0] for el in tokens]

    # The model operates on a space-joined rendering of the token sequence.
    processed_text = " ".join(tokens)

    # Maps "start-end" character spans in processed_text to token indices.
    span_to_token_id = sentence_to_spans(tokens)

    # Candidate label set: one edit-script rule per (token, lemma) pair,
    # deduplicated (set) before being handed to the model.
    labels = list({
        rule_processor.gen_lemma_rule(form=token, lemma=lemma, allow_copy=True)
        for token, lemma_list in zip(tokens, lemmas)
        for lemma in lemma_list
    })

    predicted_entities = model.predict_entities(
        text=processed_text, labels=labels, flat_ner=True, threshold=0.5
    )

    # Default prediction is the token itself; overwrite where the model
    # produced a span that aligns with a token.
    predictions = tokens.copy()
    for entity in predicted_entities:
        span_key = f"{entity['start']}-{entity['end']}"
        token_id = span_to_token_id.get(span_key)
        if token_id is None:
            # Entity span does not align with a token boundary; ignore it.
            continue
        if len(lemmas[token_id]) > 1:
            # Ambiguous token: apply the model-selected transformation rule.
            predictions[token_id] = rule_processor.apply_lemma_rule(
                tokens[token_id], entity["label"]
            )
        else:
            # Unambiguous token: trust the single Vabamorf candidate.
            predictions[token_id] = lemmas[token_id][0]

    # Highlight flag per token: True where the lemma differs from the token.
    lemma_labels = [pred != token for pred, token in zip(predictions, tokens)]

    processed_entities = {
        "text": processed_text,
        "entities": [
            {
                "entity": entity["label"],
                "word": entity["text"],
                "start": entity["start"],
                "end": entity["end"],
                "score": entity["score"],
            }
            for entity in predicted_entities
        ],
    }
    processed_lemmas = list(zip(predictions, lemma_labels))

    return processed_entities, processed_lemmas
|
|
|
|
|
if __name__ == "__main__":
    # Minimal Gradio demo: one text input, two highlighted-text outputs
    # (the predicted transformation rules and the resulting lemmas).
    with gr.Blocks() as demo:
        input_text = gr.Textbox(
            label="Text input", placeholder="Enter your text in Estonian here"
        )
        label_output = gr.HighlightedText(label="Predicted Transformation Rules")
        lemma_output = gr.HighlightedText(label="Predicted Lemmas")
        submit_btn = gr.Button("Submit")

        # Both pressing Enter in the textbox and clicking the button run the model.
        input_text.submit(
            fn=process_text, inputs=input_text, outputs=[label_output, lemma_output]
        )
        submit_btn.click(
            fn=process_text, inputs=input_text, outputs=[label_output, lemma_output]
        )
        # Clickable example sentences; not cached, so the model runs live.
        # The component handle is unused, and assigning it to `examples` would
        # shadow the module-level example list — so it is not bound.
        gr.Examples(
            examples,
            fn=process_text,
            inputs=input_text,
            outputs=[label_output, lemma_output],
            cache_examples=False,
        )
        # Removed dead `theme = gr.themes.Base()` — the theme object was
        # created but never passed to gr.Blocks, so it had no effect.

    demo.launch()
|
|