# RAG_Eval/scripts/analysis.py
"""
Runs evaluation (RQ1–RQ4, statistical tests, plots) on previously annotated
pipeline outputs that include `human_correct` and `human_faithful`.
Assumes outputs were generated using `separate_for_annotation.py` and
subsequently annotated.
"""
import argparse
import json
import logging
import itertools
from pathlib import Path
import numpy as np
import yaml
import matplotlib.pyplot as plt
from evaluation.stats import (
corr_ci,
wilcoxon_signed_rank,
holm_bonferroni,
conditional_failure_rate,
chi2_error_propagation,
delta_metric,
)
from evaluation.utils.logger import init_logging
def read_jsonl(path: Path):
with path.open() as f:
return [json.loads(line) for line in f]
def save_yaml(path: Path, obj: dict):
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(yaml.safe_dump(obj, sort_keys=False))
def agg_mean(rows: list[dict]) -> dict:
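    """Mean of every metric across rows (assumes all rows share the same metric keys)."""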
keys = rows[0]["metrics"].keys()
return {k: float(np.mean([r["metrics"][k] for r in rows])) for k in keys}
def rq1_correlation(rows):
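    """RQ1: do retrieval metrics (MRR/MAP/P@10) track human correctness?

    Reports Pearson r with a bootstrap CI per metric; empty if unannotated.
    """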
if "human_correct" not in rows[0] or rows[0]["human_correct"] is None:
return {}
retrieval_keys = [k for k in rows[0]["metrics"] if k in {"mrr", "map", "precision@10"}]
gold = [1.0 if r["human_correct"] else 0.0 for r in rows]
out = {}
for k in retrieval_keys:
vec = [r["metrics"][k] for r in rows]
r, (lo, hi), p = corr_ci(vec, gold, method="pearson", n_boot=1000, ci=0.95)
out[k] = dict(r=r, ci=[lo, hi], p=p)
return out
def rq2_faithfulness(rows):
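    """RQ2: do automatic faithfulness scores (faith*/qags*/fact*/ragas*) track
    human faithfulness judgments? Same Pearson + bootstrap-CI setup as RQ1.
    """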
if "human_faithful" not in rows[0] or rows[0]["human_faithful"] is None:
return {}
faith_keys = [k for k in rows[0]["metrics"] if k.lower().startswith(("faith", "qags", "fact", "ragas"))]
    gold = [float(r["human_faithful"]) for r in rows]  # accepts bool or graded labels
out = {}
for k in faith_keys:
vec = [r["metrics"][k] for r in rows]
r, (lo, hi), p = corr_ci(vec, gold, method="pearson", n_boot=1000, ci=0.95)
out[k] = dict(r=r, ci=[lo, hi], p=p)
return out
def rq3_error_propagation(rows):
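    """RQ3: does a retrieval error make a downstream hallucination more likely?

    Reports conditional hallucination rates plus a chi-squared test of
    association; empty if the error annotations are missing.
    """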
if "retrieval_error" not in rows[0] or "hallucination" not in rows[0]:
return {}
ret_err = [r["retrieval_error"] for r in rows]
halluc = [r["hallucination"] for r in rows]
return {
"conditional": conditional_failure_rate(ret_err, halluc),
"chi2": chi2_error_propagation(ret_err, halluc),
}
def rq4_robustness(orig_rows, pert_rows):
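    """RQ4: robustness under input perturbation.

    Per metric, the original-vs-perturbed delta and Cohen's d effect size;
    rows are assumed paired (same examples, same order) across both files.
    """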
    if not pert_rows:
return {}
metrics = orig_rows[0]["metrics"].keys()
out = {}
for m in metrics:
d, eff = delta_metric(
[r["metrics"][m] for r in orig_rows],
[r["metrics"][m] for r in pert_rows],
)
out[m] = dict(delta=d, cohen_d=eff)
return out
def scatter_mrr_vs_correct(rows, path: Path):
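    """Scatter MRR against the binary human-correctness label and save as PNG."""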
x = [r["metrics"].get("mrr", np.nan) for r in rows]
y = [1 if r.get("human_correct") else 0 for r in rows]
plt.figure()
plt.scatter(x, y, alpha=0.5)
    plt.xlabel("MRR")
    plt.ylabel("Human correct (0/1)")
    plt.title("MRR vs. Human Correctness")
    plt.tight_layout()
    plt.savefig(path)
    plt.close()
def main(argv=None):
ap = argparse.ArgumentParser()
ap.add_argument("--results", nargs="+", type=Path, required=True,
help="One or more annotated results.jsonl files.")
ap.add_argument("--outdir", type=Path, default=Path("outputs/grid"))
ap.add_argument("--perturbed-suffix", default="_pert.jsonl",
help="Looks for this perturbed variant for RQ4.")
ap.add_argument("--plots", action="store_true")
args = ap.parse_args(argv)
init_logging(log_dir=args.outdir / "logs", level="INFO")
    log = logging.getLogger("analysis")
    rows_by_cfg = {}
for res_path in args.results:
cfg_name = res_path.parent.name
dataset_name = res_path.parent.parent.name
log.info("Processing %s on %s", cfg_name, dataset_name)
rows = read_jsonl(res_path)
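        # RQ4 looks for a perturbed-run sibling of this file: the stem's
        # "unlabeled" marker (if any) is swapped for "pert" and the suffix
        # appended (e.g. results.jsonl -> results_pert.jsonl).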
pert_path = res_path.with_name(res_path.stem.replace("unlabeled", "pert") + args.perturbed_suffix)
pert_rows = read_jsonl(pert_path) if pert_path.exists() else None
run_dir = args.outdir / dataset_name / cfg_name
run_dir.mkdir(parents=True, exist_ok=True)
save_yaml(run_dir / "aggregates.yaml", agg_mean(rows))
save_yaml(run_dir / "rq1.yaml", rq1_correlation(rows))
save_yaml(run_dir / "rq2.yaml", rq2_faithfulness(rows))
save_yaml(run_dir / "rq3.yaml", rq3_error_propagation(rows))
if pert_rows:
save_yaml(run_dir / "rq4.yaml", rq4_robustness(rows, pert_rows))
if args.plots:
scatter_mrr_vs_correct(rows, run_dir / "mrr_vs_correct.png")
        rows_by_cfg[cfg_name] = rows
# Pairwise Wilcoxon + Holm correction
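    # Assumes every config was run on the same examples in the same order, so
    # per-example rag_score values form matched pairs; Holm-Bonferroni then
    # controls the family-wise error rate across all config pairs.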
    if len(rows_by_cfg) > 1:
        names = list(rows_by_cfg)
        pairs = {}
        for a, b in itertools.combinations(names, 2):
            x = [r["metrics"]["rag_score"] for r in rows_by_cfg[a]]
            y = [r["metrics"]["rag_score"] for r in rows_by_cfg[b]]
_, p = wilcoxon_signed_rank(x, y)
pairs[f"{a}~{b}"] = p
dataset_name = args.results[0].parent.parent.name
save_yaml(args.outdir / dataset_name / "wilcoxon_rag_raw.yaml", pairs)
save_yaml(args.outdir / dataset_name / "wilcoxon_rag_holm.yaml", holm_bonferroni(pairs))
log.info("Pairwise significance testing complete (rag_score).")
if __name__ == "__main__":
main()
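# Example invocation (paths are illustrative; the script derives the config
# name from each file's parent directory and the dataset from its grandparent):
#   python scripts/analysis.py \
#       --results outputs/annotated/nq/bm25_gpt4/results.jsonl \
#                 outputs/annotated/nq/dense_gpt4/results.jsonl \
#       --outdir outputs/grid --plots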