# Spaces:
# Sleeping
# Sleeping
""" | |
Runs evaluation (RQ1–RQ4, statistical tests, plots) on previously annotated | |
pipeline outputs that include `human_correct` and `human_faithful`. | |
Assumes outputs were generated using `separate_for_annotation.py` and | |
subsequently annotated. | |
""" | |
import argparse
import itertools
import json
import logging
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import yaml

from evaluation.stats import (
    chi2_error_propagation,
    conditional_failure_rate,
    corr_ci,
    delta_metric,
    holm_bonferroni,
    wilcoxon_signed_rank,
)
from evaluation.utils.logger import init_logging
def read_jsonl(path: Path) -> list[dict]:
    """Load a JSON Lines file into a list of records.

    Blank (whitespace-only) lines are skipped, so a trailing newline or a stray
    empty line does not raise ``json.JSONDecodeError``.

    Args:
        path: Path to a ``.jsonl`` file with one JSON value per line.

    Returns:
        The decoded values, in file order (``[]`` for an empty file).
    """
    # Explicit encoding so results do not depend on the platform locale.
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
def save_yaml(path: Path, obj: dict): | |
path.parent.mkdir(parents=True, exist_ok=True) | |
path.write_text(yaml.safe_dump(obj, sort_keys=False)) | |
def agg_mean(rows: list[dict]) -> dict:
    """Average every metric across ``rows``.

    Metric names are taken from the first row's ``metrics`` mapping; every row
    is expected to carry the same keys.

    Args:
        rows: Records that each hold a ``metrics`` dict of numeric values.

    Returns:
        ``{metric_name: mean}`` with plain ``float`` values (so the result is
        YAML-serialisable), or ``{}`` when ``rows`` is empty.
    """
    if not rows:
        return {}
    keys = rows[0]["metrics"].keys()
    return {k: float(np.mean([r["metrics"][k] for r in rows])) for k in keys}
def rq1_correlation(rows):
    """RQ1: correlate retrieval metrics with human correctness labels.

    Only rows whose ``human_correct`` label is present and not ``None`` are
    used, so unannotated rows are not silently counted as incorrect. A truthy
    label is coded as 1.0 and a falsy one as 0.0.

    Args:
        rows: Result records with a ``metrics`` dict and an optional
            ``human_correct`` label.

    Returns:
        ``{metric: {"r": r, "ci": [lo, hi], "p": p}}`` for each of ``mrr``,
        ``map`` and ``precision@10`` found in the first annotated row's
        metrics. The values come from ``corr_ci`` with ``method="pearson"``,
        ``n_boot=1000`` and ``ci=0.95``. Returns ``{}`` when fewer than two
        rows are annotated (no correlation is defined).
    """
    annotated = [row for row in rows if row.get("human_correct") is not None]
    if len(annotated) < 2:
        return {}
    retrieval_keys = [
        k for k in annotated[0]["metrics"] if k in {"mrr", "map", "precision@10"}
    ]
    gold = [1.0 if row["human_correct"] else 0.0 for row in annotated]
    out = {}
    for k in retrieval_keys:
        vec = [row["metrics"][k] for row in annotated]
        r, (lo, hi), p = corr_ci(vec, gold, method="pearson", n_boot=1000, ci=0.95)
        out[k] = dict(r=r, ci=[lo, hi], p=p)
    return out
def rq2_faithfulness(rows):
    """RQ2: correlate faithfulness-style metrics with human faithfulness labels.

    Only rows whose ``human_faithful`` label is present and not ``None`` are
    used; otherwise a single unannotated row would inject ``None`` into the
    correlation. The label value is passed to ``corr_ci`` as-is. It is
    presumably a bool or a numeric score (confirm against the annotation
    format).

    Args:
        rows: Result records with a ``metrics`` dict and an optional
            ``human_faithful`` label.

    Returns:
        ``{metric: {"r": r, "ci": [lo, hi], "p": p}}`` for every metric whose
        lower-cased name starts with ``faith``, ``qags``, ``fact`` or
        ``ragas``. The values come from ``corr_ci`` with ``method="pearson"``,
        ``n_boot=1000`` and ``ci=0.95``. Returns ``{}`` when fewer than two
        rows are annotated.
    """
    annotated = [row for row in rows if row.get("human_faithful") is not None]
    if len(annotated) < 2:
        return {}
    prefixes = ("faith", "qags", "fact", "ragas")
    faith_keys = [k for k in annotated[0]["metrics"] if k.lower().startswith(prefixes)]
    gold = [row["human_faithful"] for row in annotated]
    out = {}
    for k in faith_keys:
        vec = [row["metrics"][k] for row in annotated]
        r, (lo, hi), p = corr_ci(vec, gold, method="pearson", n_boot=1000, ci=0.95)
        out[k] = dict(r=r, ci=[lo, hi], p=p)
    return out
def rq3_error_propagation(rows):
    """RQ3: relate retrieval errors to hallucinations.

    Args:
        rows: Result records; the first row decides whether the
            ``retrieval_error`` and ``hallucination`` fields are available.

    Returns:
        ``{"conditional": ..., "chi2": ...}`` holding the outputs of
        ``conditional_failure_rate`` and ``chi2_error_propagation`` on the
        per-row ``retrieval_error`` / ``hallucination`` values, or ``{}`` when
        ``rows`` is empty or the first row lacks either field.
    """
    if not rows or "retrieval_error" not in rows[0] or "hallucination" not in rows[0]:
        return {}
    ret_err = [r["retrieval_error"] for r in rows]
    halluc = [r["hallucination"] for r in rows]
    return {
        "conditional": conditional_failure_rate(ret_err, halluc),
        "chi2": chi2_error_propagation(ret_err, halluc),
    }
def rq4_robustness(orig_rows, pert_rows):
    """RQ4: measure how each metric shifts under input perturbation.

    Args:
        orig_rows: Records from the unperturbed run; the first row's
            ``metrics`` keys select which metrics are compared.
        pert_rows: Records from the perturbed run, or ``None`` if no perturbed
            file exists.

    Returns:
        ``{metric: {"delta": d, "cohen_d": eff}}`` built from ``delta_metric``
        on the per-row values of each metric. Returns ``{}`` when either side
        is ``None`` or empty, since there is nothing to compare.
    """
    if not orig_rows or not pert_rows:
        return {}
    out = {}
    for m in orig_rows[0]["metrics"]:
        d, eff = delta_metric(
            [r["metrics"][m] for r in orig_rows],
            [r["metrics"][m] for r in pert_rows],
        )
        out[m] = dict(delta=d, cohen_d=eff)
    return out
def scatter_mrr_vs_correct(rows, path: Path):
    """Save a scatter plot of per-row MRR against human correctness to ``path``.

    Rows without a ``human_correct`` label (missing or ``None``) are left out
    instead of being drawn as incorrect. Rows lacking an ``mrr`` metric get a
    NaN x-value.

    Args:
        rows: Result records with a ``metrics`` dict and an optional
            ``human_correct`` label.
        path: Output image file; parent directories are created if needed.
    """
    labelled = [row for row in rows if row.get("human_correct") is not None]
    x = [row["metrics"].get("mrr", np.nan) for row in labelled]
    y = [1 if row["human_correct"] else 0 for row in labelled]
    path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots()
    try:
        ax.scatter(x, y, alpha=0.5)
        ax.set_xlabel("MRR")
        ax.set_ylabel("Correct (1)")
        ax.set_title("MRR vs. Human Correctness")
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Always release the figure, even if saving fails, so repeated calls
        # in one process do not accumulate open figures.
        plt.close(fig)
def _evaluate_run(res_path: Path, args: argparse.Namespace, log: logging.Logger) -> list:
    """Evaluate one annotated results file and write its per-run outputs.

    The config name is the file's parent directory name and the dataset name
    is its grandparent directory name. Outputs are written to
    ``<outdir>/<dataset>/<config>/``.

    Returns:
        The loaded rows. If the file has no records, the run is skipped with a
        warning and ``[]`` is returned without writing anything.
    """
    cfg_name = res_path.parent.name
    dataset_name = res_path.parent.parent.name
    log.info("Processing %s on %s", cfg_name, dataset_name)
    rows = read_jsonl(res_path)
    if not rows:
        log.warning("No records in %s; skipping.", res_path)
        return []
    # NOTE(review): the perturbed file name both replaces "unlabeled" with
    # "pert" in the stem AND appends --perturbed-suffix; confirm this matches
    # the naming used by the perturbation step.
    pert_path = res_path.with_name(
        res_path.stem.replace("unlabeled", "pert") + args.perturbed_suffix
    )
    pert_rows = read_jsonl(pert_path) if pert_path.exists() else None
    run_dir = args.outdir / dataset_name / cfg_name
    run_dir.mkdir(parents=True, exist_ok=True)
    save_yaml(run_dir / "aggregates.yaml", agg_mean(rows))
    save_yaml(run_dir / "rq1.yaml", rq1_correlation(rows))
    save_yaml(run_dir / "rq2.yaml", rq2_faithfulness(rows))
    save_yaml(run_dir / "rq3.yaml", rq3_error_propagation(rows))
    if pert_rows:
        save_yaml(run_dir / "rq4.yaml", rq4_robustness(rows, pert_rows))
    if args.plots:
        scatter_mrr_vs_correct(rows, run_dir / "mrr_vs_correct.png")
    return rows


def _pairwise_wilcoxon(dataset_name: str, runs: dict, outdir: Path, log: logging.Logger) -> None:
    """Compare every pair of configs of one dataset on per-row ``rag_score``.

    Each pair is tested with ``wilcoxon_signed_rank``. The raw p-values go to
    ``wilcoxon_rag_raw.yaml`` and the ``holm_bonferroni``-corrected ones to
    ``wilcoxon_rag_holm.yaml``, both under ``<outdir>/<dataset>/``.

    The signed-rank test is paired, so row i of each config is assumed to be
    the same query (TODO confirm results files share ordering). Pairs whose
    row counts differ are skipped with a warning instead of crashing.
    """
    if len(runs) < 2:
        return
    pairs = {}
    for a, b in itertools.combinations(runs, 2):
        rows_a, rows_b = runs[a], runs[b]
        if len(rows_a) != len(rows_b):
            log.warning(
                "Skipping %s~%s on %s: row counts differ (%d vs %d).",
                a, b, dataset_name, len(rows_a), len(rows_b),
            )
            continue
        x = [r["metrics"]["rag_score"] for r in rows_a]
        y = [r["metrics"]["rag_score"] for r in rows_b]
        _, p = wilcoxon_signed_rank(x, y)
        pairs[f"{a}~{b}"] = p
    if not pairs:
        return
    save_yaml(outdir / dataset_name / "wilcoxon_rag_raw.yaml", pairs)
    save_yaml(outdir / dataset_name / "wilcoxon_rag_holm.yaml", holm_bonferroni(pairs))
    log.info("Pairwise significance testing complete (rag_score).")


def main(argv=None):
    """CLI entry point.

    Evaluates each annotated results file (RQ1–RQ4, optional plots), then
    runs pairwise Wilcoxon tests with Holm correction between the configs of
    each dataset.

    Args:
        argv: Argument list to parse; ``None`` makes argparse use the process
            arguments.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--results", nargs="+", type=Path, required=True,
                    help="One or more annotated results.jsonl files.")
    ap.add_argument("--outdir", type=Path, default=Path("outputs/grid"))
    ap.add_argument("--perturbed-suffix", default="_pert.jsonl",
                    help="Looks for this perturbed variant for RQ4.")
    ap.add_argument("--plots", action="store_true")
    args = ap.parse_args(argv)
    init_logging(log_dir=args.outdir / "logs", level="INFO")
    log = logging.getLogger("resume")

    # Group runs by dataset so configs are compared only on the same data.
    # Keying by config name alone let same-named configs from different
    # datasets overwrite each other, mixed datasets in the paired tests, and
    # saved everything under the first dataset's directory.
    by_dataset = defaultdict(dict)
    for res_path in args.results:
        rows = _evaluate_run(res_path, args, log)
        if rows:
            by_dataset[res_path.parent.parent.name][res_path.parent.name] = rows

    for dataset_name, runs in by_dataset.items():
        _pairwise_wilcoxon(dataset_name, runs, args.outdir, log)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    raise SystemExit(main())