|
from typing import Any, Dict, List, Optional |
|
import dataclasses |
|
import glob |
|
import os |
|
import sys |
|
import json |
|
|
|
import spacy |
|
from spacy.language import Language |
|
|
|
from sftp import SpanPredictor |
|
|
|
|
|
@dataclasses.dataclass
class FrameAnnotation:
    """Base record holding the token-level columns of one annotated sentence.

    Every field defaults to its own fresh empty list, so instances created
    without arguments never share list objects.
    """

    # Token strings of the sentence, in order.
    tokens: List[str] = dataclasses.field(default_factory=list)

    # One part-of-speech tag per entry in ``tokens``.
    pos: List[str] = dataclasses.field(default_factory=list)
|
|
|
|
|
@dataclasses.dataclass
class MultiLabelAnnotation(FrameAnnotation):
    """Sentence annotation that allows several frame labels on each token.

    Extends ``FrameAnnotation`` with two more per-token columns.
    """

    # For each token, the list of label strings attached to it (may be empty).
    frame_list: List[List[str]] = dataclasses.field(default_factory=list)

    # For each token, the ``lu`` column value, or ``None`` when there is none.
    lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)

    def to_txt(self):
        """Yield one text line per token, with four space-separated columns.

        The columns are: token, POS tag, the token's frame labels joined
        with ``|``, and the token's ``lu_list`` entry. An empty label list
        and a falsy ``lu_list`` entry are both rendered as ``_``.

        All per-token lists are looked up by the token's index, so a list
        shorter than ``tokens`` raises ``IndexError`` during iteration.
        """
        for idx, token in enumerate(self.tokens):
            pos_tag = self.pos[idx]
            frame_column = "|".join(self.frame_list[idx]) or "_"
            lu_column = self.lu_list[idx] or "_"
            yield f"{token} {pos_tag} {frame_column} {lu_column}"
|
|
|
|
|
def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, Any]]) -> List[List[str]]:
    """Flatten frame structures into one list of label strings per token.

    Args:
        sentence: The sentence tokens; only its length is used.
        structures: Maps a structure id to a dict with the keys
            ``"target"`` (inclusive ``[start, end]`` token span),
            ``"frame"`` (frame name) and ``"roles"`` (list of dicts with an
            inclusive ``"boundary"`` span and a ``"label"``).

    Returns:
        A list parallel to ``sentence``. Target tokens receive
        ``T:<frame>@<id>``; role tokens receive ``B:<frame>:<role>@<id>`` on
        the first token of the role span and ``I:<frame>:<role>@<id>`` on the
        remaining ones. ``<id>`` is the structure id zero-padded to two digits.

    Raises:
        IndexError: If a span points outside ``sentence``.
    """
    per_token = [[] for _ in sentence]

    for frame_id, frame_info in structures.items():
        target = frame_info["target"]
        frame_name = frame_info["frame"]

        # Span ends are inclusive, hence the ``+ 1``.
        for position in range(target[0], target[1] + 1):
            per_token[position].append(f"T:{frame_name}@{frame_id:02}")

        for role in frame_info["roles"]:
            bounds = role["boundary"]
            role_name = role["label"]
            # The first position of the span opens the role (B), the rest continue it (I).
            for offset, position in enumerate(range(bounds[0], bounds[1] + 1)):
                bio = "I" if offset else "B"
                per_token[position].append(f"{bio}:{frame_name}:{role_name}@{frame_id:02}")

    return per_token
|
|
|
|
|
def predict_combined(
    spacy_model: Language,
    sentences: List[str],
    tgt_predictor: SpanPredictor,
    frm_predictor: SpanPredictor,
    bnd_predictor: SpanPredictor,
    arg_predictor: SpanPredictor,
) -> List[MultiLabelAnnotation]:
    """Annotate sentences by chaining four (possibly different) span predictors.

    For each sentence the pipeline is:

    1. tokenize and POS-tag with ``spacy_model``;
    2. ``tgt_predictor`` proposes target spans;
    3. ``frm_predictor`` assigns a frame label to each target span;
    4. ``bnd_predictor`` proposes role boundaries under that target/frame;
    5. ``arg_predictor`` labels those boundaries.

    Args:
        spacy_model: spaCy pipeline used for tokenization and POS tags.
        sentences: Raw sentence strings; each is stripped before processing.
        tgt_predictor: Predictor used to find target spans.
        frm_predictor: Predictor used to label a target span with a frame.
        bnd_predictor: Predictor used to find role spans for a target/frame.
        arg_predictor: Predictor used to label the role spans.

    Returns:
        One ``MultiLabelAnnotation`` per input sentence, in input order.
        ``lu_list`` is filled with ``None`` for every token.
    """
    # Frame labels meaning "this target got no real frame"; such targets are skipped.
    # The original check only matched the three-'@' spelling.
    # NOTE(review): the sftp library appears to spell its root placeholder with
    # two trailing '@' ("@@VIRTUAL_ROOT@@") -- confirm. Both are accepted here,
    # so the original behaviour is kept either way.
    virtual_root_labels = ("@@VIRTUAL_ROOT@@", "@@VIRTUAL_ROOT@@@")

    annotations_out = []

    for sent_idx, sent in enumerate(sentences):
        sent = sent.strip()
        print(f"Processing sent with idx={sent_idx}: {sent}")

        doc = spacy_model(sent)
        sent_tokens = [t.text for t in doc]

        if not sent_tokens:
            # Blank input line: there is nothing to decode, so do not hand an
            # empty token list to the predictors. An empty annotation is still
            # emitted so the output stays aligned with ``sentences``.
            annotations_out.append(MultiLabelAnnotation(tokens=[], pos=[], frame_list=[], lu_list=[]))
            continue

        # force_decode returns a 3-tuple; here only the spans are used.
        tgt_spans, _, _ = tgt_predictor.force_decode(sent_tokens)

        # Keyed by the index of the target span within ``tgt_spans``; that
        # index becomes the ``@NN`` suffix of the sequence labels.
        frame_structures = {}

        for i, span in enumerate(tgt_spans):
            span = tuple(span)
            _, fr_labels, _ = frm_predictor.force_decode(sent_tokens, child_spans=[span])
            frame = fr_labels[0]
            if frame in virtual_root_labels:
                continue

            boundaries, _, _ = bnd_predictor.force_decode(
                sent_tokens, parent_span=span, parent_label=frame
            )
            _, arg_labels, _ = arg_predictor.force_decode(
                sent_tokens, parent_span=span, parent_label=frame, child_spans=boundaries
            )

            frame_structures[i] = {
                "target": span,
                "frame": frame,
                # "Target" children are dropped: the target span is already
                # recorded under the "target" key.
                "roles": [
                    {"boundary": bnd, "label": label}
                    for bnd, label in zip(boundaries, arg_labels)
                    if label != "Target"
                ],
            }

        annotations_out.append(MultiLabelAnnotation(
            tokens=sent_tokens,
            pos=[t.pos_ for t in doc],
            frame_list=convert_to_seq_labels(sent_tokens, frame_structures),
            lu_list=[None for _ in sent_tokens],
        ))

    return annotations_out
|
|
|
|
|
def main(
    input_folder: str,
    output_folder: str = "../../data-out",
    zs_model_path: str = "/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz",
    ev_model_path: str = "/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz",
    cuda_device: int = 0,
) -> None:
    """Annotate every ``*.txt`` file in ``input_folder`` and write the results.

    Each input file is read as one sentence per line. For every input file
    ``<name>.txt`` two files are written to ``output_folder``:

    * ``<name>.combined_zs_ev.tc_bilstm.txt`` -- one line per token (see
      ``MultiLabelAnnotation.to_txt``), with a blank line after each sentence;
    * ``<name>.combined_zs_ev.tc_bilstm.json`` -- the annotations as a JSON
      list of dicts.

    Args:
        input_folder: Directory containing the ``*.txt`` input files.
        output_folder: Directory for the output files; created if missing.
        zs_model_path: Model archive for the predictor used to find targets.
        ev_model_path: Model archive for the predictor used for frames, role
            boundaries and role labels.
        cuda_device: Device index passed to ``SpanPredictor.from_path``.
    """
    # Create the output directory up front so a missing folder does not
    # surface only after the models are loaded and the first file is predicted.
    os.makedirs(output_folder, exist_ok=True)

    print("Loading spaCy model ...")
    nlp = spacy.load("it_core_news_md")

    print("Loading predictors ...")
    zs_predictor = SpanPredictor.from_path(zs_model_path, cuda_device=cuda_device)
    ev_predictor = SpanPredictor.from_path(ev_model_path, cuda_device=cuda_device)

    print("Reading input files ...")
    # glob returns files in arbitrary order; sort for reproducible runs.
    for file in sorted(glob.glob(os.path.join(input_folder, "*.txt"))):
        print(file)
        with open(file, encoding="utf-8") as f:
            sentences = list(f)

        # Targets come from the first predictor; frames, boundaries and
        # argument labels all come from the second one.
        annotations = predict_combined(nlp, sentences, zs_predictor, ev_predictor, ev_predictor, ev_predictor)

        out_name = os.path.splitext(os.path.basename(file))[0]
        out_base = os.path.join(output_folder, f"{out_name}.combined_zs_ev.tc_bilstm")

        with open(f"{out_base}.txt", "w", encoding="utf-8") as f_out:
            for ann in annotations:
                for line in ann.to_txt():
                    # Text mode already translates "\n" to the platform line
                    # ending; writing os.linesep would give "\r\r\n" on Windows.
                    f_out.write(line + "\n")
                f_out.write("\n")

        with open(f"{out_base}.json", "w", encoding="utf-8") as f_out:
            json.dump([dataclasses.asdict(ann) for ann in annotations], f_out)
|
|
|
|
|
if __name__ == "__main__":
    # The input folder is the first command-line argument; any further
    # arguments are ignored. Exit with a usage message instead of an
    # IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} INPUT_FOLDER")
    main(sys.argv[1])
|
|