|
import json |
|
from typing import List, Tuple |
|
|
|
import pandas as pd |
|
|
|
from sftp import SpanPredictor |
|
|
|
|
|
def main():
    """Evaluate each configured model on frame detection (FD), boundary
    detection (BD), and argument classification (AC) over the Evalita
    test set."""
    data_file = "/home/p289731/cloned/lome/preproc/evalita_jsonl/evalita_test.jsonl"

    # (display name, serialized model archive) pairs to evaluate.
    models = [
        (
            "lome-en",
            "/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz",
        ),
        (
            "lome-it-best",
            "/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz",
        ),
    ]

    for model_name, model_path in models:
        print("testing model: ", model_name)
        predictor = SpanPredictor.from_path(model_path)

        # FD has a single evaluation setting.
        print("=== FD (run 1) ===")
        eval_frame_detection(data_file, predictor, model_name=model_name)

        # BD: run 1 = predicted frame, run 2 = gold frame.
        for run in (1, 2):
            print(f"=== BD (run {run}) ===")
            eval_boundary_detection(data_file, predictor, run=run)

        # AC: run 1 = all predicted, run 3 = all gold inputs.
        for run in (1, 2, 3):
            print(f"=== AC (run {run}) ===")
            eval_argument_classification(data_file, predictor, run=run)
|
|
|
|
|
def predict_frame(
    predictor: SpanPredictor, tokens: List[str], predicate_span: Tuple[int, int]
):
    """Force-decode and return the frame label for a single predicate span."""
    decoded = predictor.force_decode(tokens, child_spans=[predicate_span])
    # force_decode returns (spans, labels, extra); the single requested
    # child span yields exactly one label.
    return decoded[1][0]
|
|
|
|
|
def eval_frame_detection(data_file, predictor, verbose=False, model_name="_"):
    """Evaluate frame detection accuracy on a JSONL data file.

    For each sentence, predicts a frame for the first annotated predicate
    and compares it against the gold label. Prints the accuracy and writes
    a per-sentence report to
    ``frame_prediction_output_<model_name>_<data_sect>.csv``.

    :param data_file: path to a JSONL file; each line must hold ``tokens``
        and ``annotations`` (only the first annotation is used).
    :param predictor: ``SpanPredictor`` used to force-decode the frame.
    :param verbose: when True, also print per-sentence predictions.
    :param model_name: tag embedded in the output CSV file name.
    """

    true_pos = 0
    # NOTE(review): this counts every mismatch (accuracy denominator),
    # not false positives in the IR sense.
    false_pos = 0

    out = []

    with open(data_file, encoding="utf-8") as f:
        for sent_id, sent in enumerate(f):
            sent_data = json.loads(sent)

            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]

            predicate_span = tuple(annotation["span"])
            # Span bounds are inclusive, hence the +1 on the slice stop.
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]

            frame_gold = annotation["label"]
            frame_pred = predict_frame(predictor, tokens, predicate_span)

            if frame_pred == frame_gold:
                true_pos += 1
            else:
                false_pos += 1

            out.append({
                "sentence": " ".join(tokens),
                "predicate": predicate,
                "frame_gold": frame_gold,
                "frame_pred": frame_pred
            })

            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t gold: {frame_gold}")
                print(f"\tpredicted: {frame_pred}")
                print()

    # Guard against an empty input file: report 0.0 instead of raising
    # ZeroDivisionError.
    total = true_pos + false_pos
    acc_score = true_pos / total if total else 0.0
    print("ACC =", acc_score)

    data_sect = "rai" if "svm_challenge" in data_file else "dev" if "dev" in data_file else "test"

    df_out = pd.DataFrame(out)
    df_out.to_csv(f"frame_prediction_output_{model_name}_{data_sect}.csv")
|
|
|
|
|
def predict_boundaries(predictor: SpanPredictor, tokens, predicate_span, frame):
    """Force-decode argument boundary spans for a predicate/frame pair.

    The predicate's own span is excluded when it was decoded as "Target";
    remaining spans are returned as tuples, in decoding order.
    """
    spans, span_labels, _ = predictor.force_decode(
        tokens, parent_span=predicate_span, parent_label=frame
    )
    return [
        tuple(span)
        for span, label in zip(spans, span_labels)
        if not (tuple(span) == predicate_span and label == "Target")
    ]
|
|
|
|
|
def get_gold_boundaries(annotation, predicate_span):
    """Return the set of gold argument spans (as tuples) for *annotation*,
    excluding the predicate's own span when labeled "Target"."""
    gold = set()
    for child in annotation["children"]:
        span = tuple(child["span"])
        # Skip the predicate itself; it is not an argument boundary.
        if span == predicate_span and child["label"] == "Target":
            continue
        gold.add(span)
    return gold
|
|
|
|
|
def eval_boundary_detection(data_file, predictor, run=1, verbose=False):
    """Evaluate argument boundary detection with span- and token-level P/R/F1.

    Run 1 conditions decoding on the model's *predicted* frame; run 2 uses
    the gold frame. Prints both metric triples.

    :param data_file: JSONL file with ``tokens`` and ``annotations``.
    :param predictor: ``SpanPredictor`` used for force-decoding.
    :param run: 1 (predicted frame) or 2 (gold frame).
    :param verbose: when True, print per-sentence details.
    :raises ValueError: if ``run`` is not 1 or 2.
    """

    def _prf(tp, fp, fn):
        # Zero-safe precision/recall/F1: yields 0.0 components instead of
        # raising ZeroDivisionError when any denominator is empty.
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * ((prec * rec) / (prec + rec)) if prec + rec else 0.0
        return prec, rec, f1

    # Validate eagerly with a real exception (assert is stripped under -O).
    if run not in (1, 2):
        raise ValueError(f"run must be 1 or 2, got {run!r}")

    # Span-level (exact-match) counts.
    true_pos = 0
    false_pos = 0
    false_neg = 0

    # Token-level counts.
    true_pos_tok = 0
    false_pos_tok = 0
    false_neg_tok = 0

    with open(data_file, encoding="utf-8") as f:
        for sent_id, sent in enumerate(f):
            sent_data = json.loads(sent)

            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]

            predicate_span = tuple(annotation["span"])
            # Span bounds are inclusive, hence the +1 on the slice stop.
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]

            if run == 1:
                frame = predict_frame(predictor, tokens, predicate_span)
            else:
                frame = annotation["label"]

            boundaries_gold = get_gold_boundaries(annotation, predicate_span)
            boundaries_pred = set(
                predict_boundaries(predictor, tokens, predicate_span, frame)
            )

            sent_true_pos = len(boundaries_gold & boundaries_pred)
            sent_false_pos = len(boundaries_pred - boundaries_gold)
            sent_false_neg = len(boundaries_gold - boundaries_pred)
            true_pos += sent_true_pos
            false_pos += sent_false_pos
            false_neg += sent_false_neg

            # Expand each span into its token indices (inclusive bounds).
            boundary_toks_gold = {
                tok_idx
                for (start, stop) in boundaries_gold
                for tok_idx in range(start, stop + 1)
            }
            boundary_toks_pred = {
                tok_idx
                for (start, stop) in boundaries_pred
                for tok_idx in range(start, stop + 1)
            }
            sent_tok_true_pos = len(boundary_toks_gold & boundary_toks_pred)
            sent_tok_false_pos = len(boundary_toks_pred - boundary_toks_gold)
            sent_tok_false_neg = len(boundary_toks_gold - boundary_toks_pred)
            true_pos_tok += sent_tok_true_pos
            false_pos_tok += sent_tok_false_pos
            false_neg_tok += sent_tok_false_neg

            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t frame: {frame}")
                print(f"\t gold: {boundaries_gold}")
                print(f"\tpredicted: {boundaries_pred}")
                print(f"\ttp={sent_true_pos}\tfp={sent_false_pos}\tfn={sent_false_neg}")
                print(
                    f"\ttp_t={sent_tok_true_pos}\tfp_t={sent_tok_false_pos}\tfn_t={sent_tok_false_neg}"
                )
                print()

    prec, rec, f1_score = _prf(true_pos, false_pos, false_neg)
    print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}")

    tok_prec, tok_rec, tok_f1 = _prf(true_pos_tok, false_pos_tok, false_neg_tok)
    print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}")
|
|
|
|
|
def predict_arguments(
    predictor: SpanPredictor, tokens, predicate_span, frame, boundaries
):
    """Force-decode a role label for each boundary span and return
    (span, label) pairs, skipping the predicate's own "Target" span.

    Boundaries are sorted by start index before decoding so output order
    is deterministic.
    """
    ordered = sorted(boundaries, key=lambda span: span[0])
    _, role_labels, _ = predictor.force_decode(
        tokens, parent_span=predicate_span, parent_label=frame, child_spans=ordered
    )
    return [
        (span, label)
        for span, label in zip(ordered, role_labels)
        if not (span == predicate_span and label == "Target")
    ]
|
|
|
|
|
def eval_argument_classification(data_file, predictor, run=1, verbose=False):
    """Evaluate argument classification: (span, role) pairs at span and
    token level, reported as P/R/F1.

    Run 1: predicted frame + predicted boundaries; run 2: gold frame +
    predicted boundaries; run 3: gold frame + gold boundaries.

    :param data_file: JSONL file with ``tokens`` and ``annotations``.
    :param predictor: ``SpanPredictor`` used for force-decoding.
    :param run: evaluation setting, 1, 2, or 3 (see above).
    :param verbose: when True, print per-sentence details.
    :raises ValueError: if ``run`` is not 1, 2, or 3.
    """

    def _prf(tp, fp, fn):
        # Zero-safe precision/recall/F1: yields 0.0 components instead of
        # raising ZeroDivisionError when any denominator is empty.
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * ((prec * rec) / (prec + rec)) if prec + rec else 0.0
        return prec, rec, f1

    # Validate eagerly with a real exception (assert is stripped under -O).
    if run not in (1, 2, 3):
        raise ValueError(f"run must be 1, 2, or 3, got {run!r}")

    # Span-level counts.
    true_pos = 0
    false_pos = 0
    false_neg = 0

    # Token-level counts.
    true_pos_tok = 0
    false_pos_tok = 0
    false_neg_tok = 0

    with open(data_file, encoding="utf-8") as f:
        for sent_id, sent in enumerate(f):
            sent_data = json.loads(sent)

            tokens = sent_data["tokens"]
            annotation = sent_data["annotations"][0]

            predicate_span = tuple(annotation["span"])
            # Span bounds are inclusive, hence the +1 on the slice stop.
            predicate = tokens[predicate_span[0] : predicate_span[1] + 1]

            if run == 1:
                frame = predict_frame(predictor, tokens, predicate_span)
            else:
                frame = annotation["label"]

            if run in (1, 2):
                boundaries = set(
                    predict_boundaries(predictor, tokens, predicate_span, frame)
                )
            else:
                boundaries = get_gold_boundaries(annotation, predicate_span)

            pred_arguments = predict_arguments(
                predictor, tokens, predicate_span, frame, boundaries
            )
            gold_arguments = {
                (tuple(c["span"]), c["label"])
                for c in annotation["children"]
                if not (tuple(c["span"]) == predicate_span and c["label"] == "Target")
            }

            if verbose:
                print(f"Sentence #{sent_id:03}: {' '.join(tokens)}")
                print(f"\tpredicate: {predicate}")
                print(f"\t frame: {frame}")
                print(f"\t gold: {gold_arguments}")
                print(f"\tpredicted: {pred_arguments}")
                print()

            # Set arithmetic replaces the original O(n^2) membership loops;
            # spans come from a set, so pred pairs are unique.
            pred_argument_set = set(pred_arguments)
            true_pos += len(gold_arguments & pred_argument_set)
            false_neg += len(gold_arguments - pred_argument_set)
            false_pos += len(pred_argument_set - gold_arguments)

            # Expand each (span, label) into per-token labels (inclusive).
            tok_gold_labels = {
                (token, label)
                for ((bnd_start, bnd_end), label) in gold_arguments
                for token in range(bnd_start, bnd_end + 1)
            }
            tok_pred_labels = {
                (token, label)
                for ((bnd_start, bnd_end), label) in pred_arguments
                for token in range(bnd_start, bnd_end + 1)
            }
            true_pos_tok += len(tok_gold_labels & tok_pred_labels)
            false_neg_tok += len(tok_gold_labels - tok_pred_labels)
            false_pos_tok += len(tok_pred_labels - tok_gold_labels)

    prec, rec, f1_score = _prf(true_pos, false_pos, false_neg)
    print(f"P/R/F=\n{prec}\t{rec}\t{f1_score}")

    tok_prec, tok_rec, tok_f1 = _prf(true_pos_tok, false_pos_tok, false_neg_tok)
    print(f"Pt/Rt/Ft=\n{tok_prec}\t{tok_rec}\t{tok_f1}")
|
|
|
|
|
# Script entry point: run the full evaluation suite when executed directly.
if __name__ == "__main__":

    main()
|
|