|
from sftp import SpanPredictor |
|
import spacy |
|
|
|
from flask import Flask, request, render_template, jsonify, redirect, abort, session |
|
|
|
import sys |
|
import dataclasses |
|
from typing import List, Optional, Dict, Any |
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class FrameAnnotation:
    """Base record for a token-level annotation of one sentence.

    Both fields default to a fresh empty list per instance, so instances
    never share list state.
    """

    # Surface strings of the tokens, in sentence order.
    tokens: List[str] = dataclasses.field(default_factory=list)
    # One part-of-speech tag per token (expected to be parallel to ``tokens``).
    pos: List[str] = dataclasses.field(default_factory=list)
|
|
|
|
|
@dataclasses.dataclass
class MultiLabelAnnotation(FrameAnnotation):
    """Annotation in which every token may carry several frame labels.

    ``frame_list`` and ``lu_list`` are indexed by token position, in
    parallel with the inherited ``tokens`` and ``pos`` lists.
    """

    # Per token: the list of label strings attached to that token.
    frame_list: List[List[str]] = dataclasses.field(default_factory=list)
    # Per token: an optional string, or None when there is none.
    lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)

    def to_txt(self):
        """Yield one text line per token.

        Each line holds four space-separated columns: the token, its POS
        tag, its labels joined with ``|``, and its ``lu_list`` entry.
        An empty label list or a falsy ``lu_list`` entry is written as ``_``.

        Raises:
            IndexError: if ``pos``, ``frame_list`` or ``lu_list`` is shorter
                than ``tokens``.
        """
        for idx, token in enumerate(self.tokens):
            pos_tag = self.pos[idx]
            frame_column = "|".join(self.frame_list[idx]) or "_"
            lu_column = self.lu_list[idx] or "_"
            yield f"{token} {pos_tag} {frame_column} {lu_column}"
|
|
|
|
|
|
|
def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, Any]]) -> List[List[str]]:
    """Flatten frame structures into per-token label lists.

    Args:
        sentence: The tokens of the sentence; only its length is used.
        structures: Mapping from a structure id to a dict with the keys
            ``"target"`` (an inclusive ``(start, end)`` token span),
            ``"frame"`` (the frame name) and ``"roles"`` (an iterable of
            dicts with an inclusive ``"boundary"`` span and a ``"label"``).

    Returns:
        One list of label strings per token. Target tokens receive
        ``T:<frame>@<id>``; role tokens receive ``B:<frame>:<role>@<id>``
        on the first token of the role span and ``I:...`` on the rest.
        The id is rendered with at least two digits.
    """
    token_labels = [[] for _ in sentence]
    n_tokens = len(token_labels)

    for struct_id, structure in structures.items():
        frame_name = structure["frame"]
        target_start, target_end = structure["target"][0], structure["target"][1]

        # Clamping the upper bound drops positions past the end of the
        # sentence instead of indexing out of range.
        for pos in range(target_start, min(target_end + 1, n_tokens)):
            token_labels[pos].append(f"T:{frame_name}@{struct_id:02}")

        for role in structure["roles"]:
            boundary = role["boundary"]
            role_name = role["label"]
            role_start = boundary[0]
            for pos in range(role_start, min(boundary[1] + 1, n_tokens)):
                if pos == role_start:
                    bio_tag = "B"
                else:
                    bio_tag = "I"
                token_labels[pos].append(f"{bio_tag}:{frame_name}:{role_name}@{struct_id:02}")

    return token_labels
|
|
|
def make_prediction(sentence, spacy_model, predictor):
    """Analyse one sentence and return its frame annotation.

    The sentence is tokenized with ``spacy_model``. ``predictor.force_decode``
    is first called on the tokens alone to obtain candidate target spans with
    frame labels, then once per accepted candidate (with ``parent_span`` and
    ``parent_label``) to obtain its role spans.

    Args:
        sentence: Raw sentence text.
        spacy_model: Callable returning an iterable of tokens that expose
            ``.text`` and ``.pos_``.
        predictor: Object exposing ``force_decode`` as used above.

    Returns:
        A ``MultiLabelAnnotation`` with tokens, POS tags, per-token labels
        and an ``lu_list`` filled with ``None``.
    """
    doc = spacy_model(sentence)
    tokens = [tok.text for tok in doc]

    target_spans, frame_labels, frame_probas = predictor.force_decode(tokens)
    # Candidates are visited in order of the start index of their target span.
    candidates = sorted(
        zip(target_spans, frame_labels, frame_probas),
        key=lambda cand: cand[0][0],
    )

    frame_structures = {}
    # The enumeration index doubles as the structure id, so ids can have
    # gaps whenever a candidate is rejected.
    for struct_id, (target, frame, frame_proba) in enumerate(candidates):
        # Reject labels starting with "@@", labels equal to their own
        # upper-cased form, and candidates whose highest probability is not
        # exactly 1.0. NOTE(review): presumably these filter out special /
        # non-frame labels and non-certain predictions - confirm.
        rejected = (
            frame.startswith("@@")
            or frame.upper() == frame
            or frame_proba.max() != 1.0
        )
        if rejected:
            continue

        role_spans, role_labels, role_probas = predictor.force_decode(
            tokens, parent_span=target, parent_label=frame
        )

        roles = []
        for boundary, label, probas in zip(role_spans, role_labels, role_probas):
            # The "Target" label is skipped here, as are roles whose highest
            # probability is not exactly 1.0.
            if label == "Target" or max(probas) != 1.0:
                continue
            roles.append({"boundary": boundary, "label": label})

        frame_structures[struct_id] = {
            "target": target,
            "frame": frame,
            "roles": roles,
        }

    pos_tags = [tok.pos_ for tok in doc]
    return MultiLabelAnnotation(
        tokens=tokens,
        pos=pos_tags,
        frame_list=convert_to_seq_labels(tokens, frame_structures),
        lu_list=[None] * len(tokens),
    )
|
|
|
|
|
|
|
# Module-level singletons: these run at import time, so both models are
# loaded once and then shared by every request handled by ``app``.

# Span predictor restored from an archive path relative to the working directory.
predictor = SpanPredictor.from_path("model.mod.tar.gz")

# spaCy pipeline used for tokenization and POS tagging.
nlp = spacy.load("it_core_news_md")


# Flask application object; routes are registered on it below.
app = Flask(__name__)
|
|
|
|
|
@app.route("/analyze")
def analyze():
    """Analyse the text passed in the ``text`` query parameter.

    The text is split on ``"\n"`` and every resulting line (including empty
    ones) is analysed separately with ``make_prediction``, so the returned
    ``analyses`` list has one entry per line, in order.

    Returns:
        A JSON response ``{"result": "OK", "analyses": [...]}`` in which each
        analysis is the dict form of the annotation dataclass.

    Raises:
        HTTP 400 (via ``abort``) when the ``text`` parameter is absent.
    """
    text = request.args.get("text")
    if text is None:
        # Without this guard a request lacking ``text`` reached
        # ``None.split`` and failed with AttributeError, i.e. an HTTP 500
        # for what is really a client error.
        abort(400, description="Missing required query parameter 'text'.")

    analyses = [
        make_prediction(sentence, nlp, predictor)
        for sentence in text.split("\n")
    ]

    return jsonify({
        "result": "OK",
        "analyses": [dataclasses.asdict(an) for an in analyses],
    })
|
|
|
|
|
|
|
if __name__ == "__main__":
    # An optional first command-line argument overrides the bind address;
    # without it the server listens on the loopback interface only.
    host = sys.argv[1] if len(sys.argv) > 1 else "127.0.0.1"

    app.run(host=host, debug=False, port=9090)
|
|