Rom89823974978 committed on
Commit
e8c3964
·
1 Parent(s): bdb49ae

Added dashboard and experiments

README.md CHANGED
@@ -1,5 +1,168 @@
1
- # RAG Evaluation Framework for Regulated Domains - Master's Thesis
 
2
 
3
- This repository contains a modular implementation of an evaluation framework for Retrieval‑Augmented Generation (RAG) systems.
4
 
5
- See `evaluation/` for library code and `tests/` for smoke tests.
 
7
+ # Retrieval-Augmented Generation Evaluation Framework
8
+ *(Legal & Financial domains, with full regulatory-grade metrics)*
9
+
10
+ > **Project context** – This code implements the software artefacts promised in the research proposal
11
+ > “**Toward Comprehensive Evaluation of Retrieval-Augmented Generation Systems in Regulated Domains**.”
12
+ > Each folder corresponds to a work-package from the proposal: retrieval pipelines, the metric library,
13
+ > robustness & statistical analysis, plus automation for Docker / CI.
14
+
15
+ ---
16
+
17
+ ## 1. Quick start
18
+
19
+ ```bash
20
+ # Clone and bootstrap
21
+ git clone https://github.com/<your-org>/rag-eval-framework.git
22
+ cd rag-eval-framework
23
+ python -m venv .venv && source .venv/bin/activate
24
+ pip install -r requirements.txt
25
+ pre-commit install # optional: local lint hooks
26
+
27
+ # Download / prepare a small corpus (makes ~200 docs)
28
+ bash scripts/download_data.sh
29
+
30
+ # Build sparse & dense indexes automatically on first run
31
+ python scripts/run_experiments.py \
32
+ --config configs/pipeline_hybrid_ce.yaml \
33
+ --queries data/sample_queries.jsonl
34
+ ```
35
+
36
+ The first invocation embeds the documents and builds both a **FAISS** dense index and a **Pyserini** (Lucene) sparse index. Subsequent runs reuse them.
37
+
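+ The `--queries` file is plain JSONL: one object per line with at least a `question` field, and any extra fields (such as the optional gold labels described in Section 6) are carried through into the results. A minimal sketch of how you might write your own query file (the path and questions below are placeholders):
+
+ ```python
+ # sketch: create a tiny JSONL query file; adjust the path and fields to your corpus
+ import json
+ from pathlib import Path
+
+ queries = [
+     {"question": "Which disclosures does the prospectus regulation require?", "human_correct": True},
+     {"question": "Summarise the termination clause of the sample loan agreement."},
+ ]
+
+ out = Path("data/my_queries.jsonl")  # placeholder path
+ out.parent.mkdir(parents=True, exist_ok=True)
+ with out.open("w") as f:
+     for q in queries:
+         f.write(json.dumps(q) + "\n")
+ ```
+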
38
+ ---
39
+
40
+ ## 2. Repository layout
41
+
42
+ ```
43
+ evaluation/ ← ⚙️ Core library
44
+ ├── config.py ⇒ Typed dataclasses (retriever, generator, stats, reranker)
45
+ ├── pipeline.py ⇒ Orchestrates retrieval → (optional) re-ranking → generation
46
+ │   └── … logs every stage to dict → downstream eval
47
+ ├── retrievers/ ⇒ BM25, Dense (Sentence-Transformers + FAISS), Hybrid
48
+ ├── rerankers/ ⇒ Cross-encoder re-ranker (optional second stage)
49
+ ├── generators/ ⇒ Hugging Face generator wrapper (T5/Flan/BART…)
50
+ ├── metrics/ ⇒ Retrieval, generation, composite RAG score
51
+ └── stats/ ⇒ Correlation, significance, robustness utilities
52
+ configs/ ← YAML templates (pipeline & stats settings)
53
+ scripts/ ← CLI helpers: run_experiments.py, download_data.sh …
54
+ tests/ ← PyTest smoke tests cover every public module
55
+ .github/workflows/ci.yml ← Lint + tests on push / PR
56
+ Dockerfile ← Slim runtime ready for reproducibility
57
+ ```
58
+
59
+ ---
60
+
61
+ ## 3. How each module maps to proposal tasks
62
+
63
+ | Proposal section | Code artefact | Purpose |
64
+ | -------------------------------------- | ----------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------- |
65
+ | **Retrievers** (BM25, dense, hybrid) | `evaluation/retrievers/` | Implements **RQ1** experiments on classic vs. dense retrieval. Auto-builds indexes to ease replication. |
66
+ | **Generator** (Fixed seq2seq backbone) | `evaluation/generators/` | Holds the controlled decoding backend so retrieval changes are isolated. |
67
+ | **Cross-encoder re-ranker** | `evaluation/rerankers/` | Optional “advanced RAG” stage from Fig. 2 of the proposal; improves evidence precision. |
68
+ | **Metric taxonomy** | `evaluation/metrics/` | Classical IR metrics, semantic generation scores, and composite `rag_score` per WP3. |
69
+ | **Statistical tests & sensitivity** | `evaluation/stats/` + `StatsConfig` | Spearman/Kendall correlations (**RQ1, RQ2**), Wilcoxon + Holm-Bonferroni (**RQ2**), error-propagation χ² and robustness deltas (**RQ3, RQ4**; see the sketch after this table). |
70
+ | **Reproducibility** | Dockerfile, CI, pre-commit | Supports the EU AI Act’s technical-documentation and record-keeping requirements (Articles 11–12). |
71
+
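+ The error-propagation entry above (RQ3) boils down to a 2×2 contingency test on per-query flags. A minimal sketch of that idea with SciPy (the repository’s own `chi2_error_propagation` and `conditional_failure_rate` may differ in detail; the flags below are toy values):
+
+ ```python
+ # sketch: retrieval-error -> hallucination propagation as a 2x2 contingency test
+ import numpy as np
+ from scipy.stats import chi2_contingency
+
+ retrieval_error = np.array([0, 1, 1, 0, 1, 0, 0, 1], dtype=bool)  # toy flags
+ hallucination   = np.array([0, 1, 0, 0, 1, 0, 1, 1], dtype=bool)  # toy flags
+
+ table = np.array([
+     [np.sum(~retrieval_error & ~hallucination), np.sum(~retrieval_error & hallucination)],
+     [np.sum( retrieval_error & ~hallucination), np.sum( retrieval_error & hallucination)],
+ ])
+ chi2, p, dof, _ = chi2_contingency(table)
+ p_halluc_given_error = table[1, 1] / table[1].sum()  # conditional failure rate
+ print(f"chi2={chi2:.2f}  p={p:.3g}  P(hallucination | retrieval error)={p_halluc_given_error:.2f}")
+ ```
+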
72
+ ---
73
+
74
+ ## 4. Configuration at a glance
75
+
76
+ ```yaml
77
+ # configs/pipeline_hybrid_ce.yaml
78
+ retriever:
79
+ name: hybrid # bm25 | dense | hybrid
80
+ bm25_index: indexes/legal_bm25
81
+ faiss_index: indexes/legal_dense.faiss
82
+ doc_store: data/legal_docs.jsonl
83
+ top_k: 10
84
+ alpha: 0.6
85
+
86
+ reranker:
87
+ enable: true # cross-encoder stage
88
+ model_name: cross-encoder/ms-marco-MiniLM-L-6-v2
89
+ first_stage_k: 50
90
+ final_k: 10
91
+ device: cuda:0
92
+
93
+ generator:
94
+ model_name: google/flan-t5-base
95
+ device: cuda:0
96
+ max_new_tokens: 256
97
+ temperature: 0.0
98
+
99
+ stats:
100
+ correlation_method: spearman
101
+ n_boot: 5000
102
+ ci: 0.95
103
+ wilcoxon_alternative: two-sided
104
+ multiple_correction: holm-bonferroni
105
+ alpha: 0.05
106
+ ```
107
+
108
+ All fields are documented in `evaluation/config.py`. You can override any field from the CLI (e.g. `--retriever.top_k 20`) if you parse the config with Hydra or OmegaConf, as sketched below.
109
+
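+ For instance, a minimal override sketch with OmegaConf (the bundled scripts use plain `yaml.safe_load` plus dataclass merging, so this wiring is an optional alternative rather than the shipped behaviour):
+
+ ```python
+ # sketch: dotlist overrides on top of the YAML config, using OmegaConf
+ from omegaconf import OmegaConf
+
+ base = OmegaConf.load("configs/pipeline_hybrid_ce.yaml")
+ overrides = OmegaConf.from_dotlist(["retriever.top_k=20", "generator.device=cpu"])
+ cfg = OmegaConf.merge(base, overrides)  # right-most argument wins
+ print(OmegaConf.to_yaml(cfg))
+ ```
+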
110
+ ---
111
+
112
+ ## 5. Index generation details
113
+
114
+ * **Sparse (BM25 / Lucene)**
115
+ If the `bm25_index` directory is absent, the `BM25Retriever` calls *Pyserini’s* CLI to build it from `doc_store` (JSONL with `{"id", "text"}`).
116
+ * **Dense (FAISS)**
117
+ Likewise, `DenseRetriever` embeds every document with the Sentence-Transformers model named in the config, normalises the vectors, and builds an inner-product (IP) FAISS index.
118
+
119
+ Both steps cache artefacts, so future runs start instantly.
120
+
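+ As a rough sketch of what the dense step amounts to (the real logic lives in `evaluation/retrievers/`; the embedding model name below is a placeholder, not necessarily the one in your config):
+
+ ```python
+ # sketch: embed a JSONL doc store and build an inner-product FAISS index
+ import json
+ import faiss
+ import numpy as np
+ from pathlib import Path
+ from sentence_transformers import SentenceTransformer
+
+ docs = [json.loads(line) for line in open("data/legal_docs.jsonl")]  # {"id", "text"} per line
+ model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")  # placeholder model
+ emb = model.encode([d["text"] for d in docs], normalize_embeddings=True)
+
+ index = faiss.IndexFlatIP(emb.shape[1])  # normalised vectors + inner product = cosine
+ index.add(np.asarray(emb, dtype="float32"))
+ Path("indexes").mkdir(parents=True, exist_ok=True)
+ faiss.write_index(index, "indexes/legal_dense.faiss")
+ ```
+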
121
+ ---
122
+
123
+ ## 6. Running the statistical evaluation
124
+
125
+ Each experiment run dumps a JSONL (`results.jsonl`) with per-query fields:
126
+
127
+ ```jsonc
128
+ {
129
+ "question": "...",
130
+ "answer": "...",
131
+ "contexts": ["..."],
132
+ "metrics": {
133
+ "precision@10": 0.9,
134
+ "rag_score": 0.71,
135
+ ...
136
+ },
137
+ "human_correct": true, // optional gold labels
138
+ "human_faithful": 0.8 // optional expert rating 0-1
139
+ }
140
+ ```
141
+
142
+ You can feed that into a notebook or CLI script:
143
+
144
+ ```python
145
+ from evaluation.stats import (
146
+ corr_ci, wilcoxon_signed_rank, holm_bonferroni,
147
+ delta_metric, conditional_failure_rate
148
+ )
149
+ from evaluation import StatsConfig
+ import json
150
+
151
+ cfg = StatsConfig(n_boot=5000)
152
+ # example: correlation of MRR vs. human correctness
+ rows = [json.loads(line) for line in open("outputs/results.jsonl")]  # per-query results
153
+ mrr = [r["metrics"]["mrr"] for r in rows]
154
+ gold = [1.0 if r["human_correct"] else 0.0 for r in rows]
155
+ rho, (lo, hi), p = corr_ci(mrr, gold, method=cfg.correlation_method, n_boot=cfg.n_boot)
156
+ print(f"Spearman ρ={rho:.2f} 95% CI=({lo:.2f},{hi:.2f}) p={p:.3g}")
157
+ ```
158
+
159
+ All statistical primitives are implemented in pure NumPy+SciPy, ensuring compatibility with lightweight Docker images.
160
+
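+ For orientation, here is a self-contained sketch of the kind of percentile-bootstrap interval `corr_ci` reports (not the library’s exact implementation):
+
+ ```python
+ # sketch: percentile-bootstrap CI for a Spearman correlation, NumPy + SciPy only
+ import numpy as np
+ from scipy.stats import spearmanr
+
+ def bootstrap_spearman(x, y, n_boot=5000, ci=0.95, seed=0):
+     x, y = np.asarray(x, float), np.asarray(y, float)
+     rho, p = spearmanr(x, y)
+     rng = np.random.default_rng(seed)
+     idx = rng.integers(0, len(x), size=(n_boot, len(x)))  # resample query indices
+     boots = np.array([spearmanr(x[i], y[i])[0] for i in idx])
+     lo, hi = np.quantile(boots, [(1 - ci) / 2, 1 - (1 - ci) / 2])
+     return rho, (lo, hi), p
+ ```
+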
161
+ ---
162
+
163
+ ### Happy evaluating!
164
+
165
+ Questions or suggestions? Open an issue or discussion on the GitHub repo.
166
+
evaluation/__init__.py CHANGED
@@ -13,3 +13,4 @@ The public API re‑exports :class:`evaluation.pipeline.RAGPipeline`.
13
  """
14
 
15
  from .pipeline import RAGPipeline, PipelineConfig # noqa: F401
 
 
13
  """
14
 
15
  from .pipeline import RAGPipeline, PipelineConfig # noqa: F401
16
+ from .config import LoggingConfig
evaluation/config.py CHANGED
@@ -4,6 +4,13 @@ from dataclasses import dataclass
4
  from pathlib import Path
5
  from typing import Optional, Literal
6
 
 
 
 
 
 
 
 
7
  @dataclass
8
  class CrossEncoderConfig:
9
  enable: bool = False # master switch
@@ -64,6 +71,7 @@ class StatsConfig:
64
  @dataclass
65
  class PipelineConfig:
66
  """Top‑level pipeline configuration."""
 
67
  reranker: CrossEncoderConfig = CrossEncoderConfig()
68
  retriever: RetrieverConfig = RetrieverConfig()
69
  generator: GeneratorConfig = GeneratorConfig()
 
4
  from pathlib import Path
5
  from typing import Optional, Literal
6
 
7
+ @dataclass
8
+ class LoggingConfig:
9
+ log_dir: Path = Path("logs")
10
+ level: str = "INFO" # DEBUG | INFO | WARNING | ERROR | CRITICAL
11
+ max_mb: int = 5 # per-file size before rotation
12
+ backups: int = 5 # number of rotated files to keep
13
+
14
  @dataclass
15
  class CrossEncoderConfig:
16
  enable: bool = False # master switch
 
71
  @dataclass
72
  class PipelineConfig:
73
  """Top‑level pipeline configuration."""
74
+ logging: LoggingConfig = LoggingConfig()
75
  reranker: CrossEncoderConfig = CrossEncoderConfig()
76
  retriever: RetrieverConfig = RetrieverConfig()
77
  generator: GeneratorConfig = GeneratorConfig()
evaluation/stats/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
  """Statistical utilities for analysis scripts."""
2
 
 
3
  from .correlation import corr_ci
4
  from .significance import wilcoxon_signed_rank, holm_bonferroni
5
  from .robustness import (
 
1
  """Statistical utilities for analysis scripts."""
2
 
3
+ from ..config import StatsConfig
4
  from .correlation import corr_ci
5
  from .significance import wilcoxon_signed_rank, holm_bonferroni
6
  from .robustness import (
evaluation/utils/logger.py ADDED
@@ -0,0 +1,56 @@
1
+ """Centralised logging initialisation (console + rotating file)."""
2
+
3
+ from __future__ import annotations
4
+ import logging
5
+ import logging.handlers
6
+ import os
7
+ import sys
8
+ from datetime import datetime, timezone
9
+ from pathlib import Path
10
+ from typing import Optional
11
+
12
+ __all__ = ["init_logging"]
13
+
14
+
15
+ def init_logging(
16
+ *,
17
+ log_dir: str | os.PathLike = "logs",
18
+ level: str | int = "INFO",
19
+ fmt: str = "%(asctime)s | %(levelname)s | %(name)s | %(message)s",
20
+ max_mb: int = 5,
21
+ backups: int = 5,
22
+ ) -> Path:
23
+ """Configure root logger for both console *and* rotating-file output.
24
+
25
+ Returns
26
+ -------
27
+ Path to the log file.
28
+ """
29
+ log_dir = Path(log_dir)
30
+ log_dir.mkdir(parents=True, exist_ok=True)
31
+ logfile = log_dir / f"{datetime.now(timezone.utc):%Y%m%d_%H%M%S}.log"
32
+
33
+ if isinstance(level, str):
34
+ level = logging._nameToLevel.get(level.upper(), logging.INFO)
35
+ formatter = logging.Formatter(fmt)
36
+
37
+ root = logging.getLogger()
38
+ root.setLevel(level)
39
+ root.handlers.clear() # avoid duplicate handlers on re-init
40
+
41
+ # Console
42
+ ch = logging.StreamHandler(sys.stderr)
43
+ ch.setLevel(level)
44
+ ch.setFormatter(formatter)
45
+ root.addHandler(ch)
46
+
47
+ # Rotating file
48
+ fh = logging.handlers.RotatingFileHandler(
49
+ logfile, maxBytes=max_mb * 1024 * 1024, backupCount=backups
50
+ )
51
+ fh.setLevel(level)
52
+ fh.setFormatter(formatter)
53
+ root.addHandler(fh)
54
+
55
+ root.info("Logging initialised. File=%s Level=%s", logfile, logging.getLevelName(level))
56
+ return logfile
scripts/dashboard.py ADDED
@@ -0,0 +1,104 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ dashboard.py
4
+ ============
5
+
6
+ Launch with:
7
+ streamlit run scripts/dashboard.py
8
+
9
+ Relies on the directory structure produced by run_grid_experiments.py:
10
+ outputs/grid/<dataset>/<config>/{aggregates.yaml, rq1.yaml, ...}
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ import yaml
16
+ from pathlib import Path
17
+
18
+ import pandas as pd
19
+ import streamlit as st
20
+ import matplotlib.pyplot as plt
21
+
22
+ BASE_DIR = Path("outputs/grid") # change if you store runs elsewhere
23
+ METRIC_KEY = "rag_score" # bar/box plots focus on this
24
+
25
+ # --------------------------------------------------------------------- Sidebar
26
+ st.sidebar.title("RAG-Eval Dashboard")
27
+
28
+ if not BASE_DIR.exists():
29
+ st.sidebar.error(f"Folder {BASE_DIR} not found – run experiments first.")
30
+ st.stop()
31
+
32
+ datasets = sorted([p.name for p in BASE_DIR.iterdir() if p.is_dir()])
33
+ dataset = st.sidebar.selectbox("Dataset", datasets)
34
+ conf_dir = BASE_DIR / dataset
35
+ configs = sorted([p.name for p in conf_dir.iterdir() if p.is_dir()])
36
+ sel_cfgs = st.sidebar.multiselect("Configurations", configs, default=configs)
37
+
38
+ if not sel_cfgs:
39
+ st.warning("Select at least one configuration.")
40
+ st.stop()
41
+
42
+ # ---------------------------------------------------------------- Load helpers
43
+ def _yaml(path: Path): return yaml.safe_load(path.read_text())
44
+ def _jsonl(path: Path): return [json.loads(l) for l in path.read_text().splitlines()]
45
+
46
+ # ---------------------------------------------------------------- Main view
47
+ st.title(f"Dataset: {dataset}")
48
+
49
+ # ── Aggregated metrics table ────────────────────────────────────────────────
50
+ agg = {c: _yaml(conf_dir / c / "aggregates.yaml") for c in sel_cfgs}
51
+ agg_df = pd.DataFrame(agg).T
52
+ st.subheader("Aggregated metrics")
53
+ st.dataframe(agg_df, use_container_width=True)
54
+
55
+ # ── Bar chart of rag_score means ────────────────────────────────────────────
56
+ st.subheader(f"Mean {METRIC_KEY}")
57
+ fig, ax = plt.subplots()
58
+ agg_df[METRIC_KEY].plot.bar(ax=ax)
59
+ ax.set_ylabel(METRIC_KEY)
60
+ ax.set_ylim(0, 1)
61
+ st.pyplot(fig)
62
+
63
+ # ── Scatter MRR vs Correctness per config ───────────────────────────────────
64
+ st.subheader("MRR vs Human Correctness")
65
+ cols = st.columns(len(sel_cfgs))
66
+ for col, cfg in zip(cols, sel_cfgs):
67
+ rows = _jsonl(conf_dir / cfg / "results.jsonl")
68
+ x = [r["metrics"].get("mrr", float("nan")) for r in rows]
69
+ y = [1 if r.get("human_correct") else 0 for r in rows]
70
+ fig, ax = plt.subplots()
71
+ ax.scatter(x, y, alpha=0.5)
72
+ ax.set(title=cfg, xlabel="MRR", ylabel="Correct?")
73
+ col.pyplot(fig)
74
+
75
+ # ── Pairwise Wilcoxon-Holm table (rag_score) ────────────────────────────────
76
+ wh_path = conf_dir / "wilcoxon_rag_holm.yaml"
77
+ if wh_path.exists():
78
+ st.subheader("Pairwise Wilcoxon-Holm (rag_score)")
79
+ wh_df = pd.Series(_yaml(wh_path), name="p_adj").to_frame()
80
+ st.dataframe(wh_df)
81
+ else:
82
+ st.info("Wilcoxon table not found – run_grid_experiments.py computes it.")
83
+
84
+ # ── Research-question YAMLs ─────────────────────────────────────────────────
85
+ rq_tabs = st.tabs([f"{cfg}" for cfg in sel_cfgs])
86
+ for tab, cfg in zip(rq_tabs, sel_cfgs):
87
+ with tab:
88
+ for rq in ("rq1", "rq2", "rq3", "rq4"):
89
+ path = conf_dir / cfg / f"{rq}.yaml"
90
+ if path.exists():
91
+ st.markdown(f"**{rq.upper()}**")
92
+ st.json(_yaml(path))
93
+ else:
94
+ st.markdown(f"*{rq.upper()} – not available*")
95
+
96
+ # ── Raw results download ────────────────────────────────────────────────────
97
+ st.sidebar.subheader("Download")
98
+ for cfg in sel_cfgs:
99
+ st.sidebar.download_button(
100
+ label=f"{cfg} results.jsonl",
101
+ data=(conf_dir / cfg / "results.jsonl").read_bytes(),
102
+ file_name=f"{dataset}_{cfg}_results.jsonl",
103
+ mime="application/jsonl",
104
+ )
scripts/run_experiments.py ADDED
@@ -0,0 +1,251 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ run_experiments.py
4
+ ==================
5
+
6
+ High-level driver that wires together:
7
+
8
+ 1. YAML / CLI β†’ `PipelineConfig` + `LoggingConfig`
9
+ 2. Initialises dual-sink logging (console + rotating file)
10
+ 3. Builds a `RAGPipeline`
11
+ 4. Streams a list of questions through the pipeline
12
+ 5. Logs progress, writes per-query JSONL results, and
13
+ (optionally) prints aggregate statistics.
14
+
15
+ You can keep it minimal – or expand the marked TODO sections to:
16
+ * compute metrics immediately
17
+ * push results to a tracker (W&B, MLflow, etc.)
18
+ * spawn multiple configs in parallel.
19
+ """
20
+ from __future__ import annotations
21
+
22
+ import argparse
23
+ import json
24
+ import sys
25
+ from pathlib import Path
26
+ from typing import Any, Dict, Iterable, List, Mapping
27
+
28
+ import yaml
29
+
30
+ from evaluation import (
31
+ PipelineConfig,
32
+ RetrieverConfig,
33
+ GeneratorConfig,
34
+ CrossEncoderConfig,
35
+ StatsConfig,
36
+ LoggingConfig,
37
+ RAGPipeline,
38
+ )
39
+ from evaluation.utils.logger import init_logging
40
+
41
+ from evaluation.stats import (
42
+ corr_ci,
43
+ wilcoxon_signed_rank,
44
+ holm_bonferroni,
45
+ )
46
+
47
+ import matplotlib.pyplot as plt
48
+
49
+ # ──────────────────────────────────────────────────────────────────────────────
50
+ # Helpers
51
+ # ──────────────────────────────────────────────────────────────────────────────
52
+
53
+
54
+ def _merge_dataclass(dc_cls, default, override: Mapping[str, Any]):
55
+ """Return a new *dc_cls* where fields from *override* overwrite *default*."""
56
+ from dataclasses import asdict
57
+
58
+ merged = asdict(default)
59
+ merged.update({k: v for k, v in override.items() if v is not None})
60
+ return dc_cls(**merged)
61
+
62
+
63
+ def _load_pipeline_config(yaml_path: Path | None) -> PipelineConfig:
64
+ """Parse YAML into nested dataclasses; fall back to defaults."""
65
+ if yaml_path is None:
66
+ return PipelineConfig() # all defaults
67
+
68
+ data = yaml.safe_load(yaml_path.read_text())
69
+
70
+ # pass the dataclass *class* first so `dc_cls(**merged)` can construct a new instance
+ retr_cfg = _merge_dataclass(
71
+ RetrieverConfig, RetrieverConfig(), data.get("retriever", {})
72
+ )
73
+ gen_cfg = _merge_dataclass(
74
+ GeneratorConfig, GeneratorConfig(), data.get("generator", {})
75
+ )
76
+ rr_cfg = _merge_dataclass(
77
+ CrossEncoderConfig, CrossEncoderConfig(), data.get("reranker", {})
78
+ )
79
+ stats_cfg = _merge_dataclass(StatsConfig, StatsConfig(), data.get("stats", {}))
80
+ log_cfg = _merge_dataclass(LoggingConfig, LoggingConfig(), data.get("logging", {}))
81
+
82
+ return PipelineConfig(
83
+ retriever=retr_cfg,
84
+ generator=gen_cfg,
85
+ reranker=rr_cfg,
86
+ stats=stats_cfg,
87
+ logging=log_cfg,
88
+ )
89
+
90
+
91
+ def _read_jsonl(path: Path) -> List[Dict[str, Any]]:
92
+ with path.open() as f:
93
+ return [json.loads(line) for line in f]
94
+
95
+
96
+ def _write_jsonl(path: Path, rows: Iterable[Mapping[str, Any]]):
97
+ path.parent.mkdir(parents=True, exist_ok=True)
98
+ with path.open("w") as f:
99
+ for row in rows:
100
+ f.write(json.dumps(row) + "\n")
101
+
102
+ # Stats Helper
103
+ def aggregate_metrics(rows: list[dict[str, Any]]) -> dict[str, float]:
104
+ """Return mean of every numeric metric found under row['metrics']."""
105
+ import numpy as np
106
+ keys = rows[0]["metrics"].keys()
107
+ return {k: float(np.mean([r["metrics"][k] for r in rows])) for k in keys}
108
+
109
+
110
+ def correlation_with_gold(rows: list[dict[str, Any]], cfg: StatsConfig):
111
+ """Spearman/Kendall correlation between retrieval scores and correctness flag."""
112
+ if "human_correct" not in rows[0]:
113
+ return None # nothing to correlate
114
+ mrr = [r["metrics"].get("mrr", float("nan")) for r in rows]
115
+ gold = [1.0 if r["human_correct"] else 0.0 for r in rows]
116
+ r, (lo, hi), p = corr_ci(
117
+ mrr, gold, method=cfg.correlation_method, n_boot=cfg.n_boot, ci=cfg.ci
118
+ )
119
+ return dict(r=r, ci_low=lo, ci_high=hi, p=p)
120
+
121
+
122
+ def wilcoxon_against_baseline(
123
+ cur: list[dict[str, Any]],
124
+ base: list[dict[str, Any]],
125
+ cfg: StatsConfig,
126
+ ):
127
+ """Paired Wilcoxon + Holm-Bonferroni across all metric keys."""
128
+ from evaluation.stats import wilcoxon_signed_rank, holm_bonferroni
129
+
130
+ assert len(cur) == len(base), "Runs must have same #queries"
131
+ metrics = cur[0]["metrics"].keys()
132
+ p_raw = {}
133
+ for m in metrics:
134
+ cur_m = [r["metrics"][m] for r in cur]
135
+ base_m = [r["metrics"][m] for r in base]
136
+ _, p = wilcoxon_signed_rank(cur_m, base_m, alternative=cfg.wilcoxon_alternative)
137
+ p_raw[m] = p
138
+ return holm_bonferroni(p_raw)
139
+
140
+ # Plot helper
141
+ def save_scatter(rows, out_dir: Path):
142
+ out_dir.mkdir(parents=True, exist_ok=True)
143
+ x = [r["metrics"].get("mrr", float("nan")) for r in rows]  # keep x and y aligned per query
144
+ y = [1.0 if r.get("human_correct") else 0.0 for r in rows]
145
+ plt.figure()
146
+ plt.scatter(x, y, alpha=0.6)
147
+ plt.xlabel("MRR")
148
+ plt.ylabel("Correct (1=yes)")
149
+ plt.title("MRR vs. Human Correctness")
150
+ path = out_dir / "mrr_vs_correct.png"
151
+ plt.savefig(path, bbox_inches="tight")
152
+ plt.close()
153
+ return path
154
+
155
+ # ──────────────────────────────────────────────────────────────────────────────
156
+ # Main
157
+ # ──────────────────────────────────────────────────────────────────────────────
158
+ def main(argv: list[str] | None = None) -> None:
159
+ ap = argparse.ArgumentParser(description="Run RAG evaluation experiments.")
160
+ ap.add_argument("--config", type=Path, help="YAML config with pipeline settings")
161
+ ap.add_argument(
162
+ "--queries",
163
+ type=Path,
164
+ required=True,
165
+ help="JSONL file – each line must contain at least {'question': ...}",
166
+ )
167
+ ap.add_argument(
168
+ "--output",
169
+ type=Path,
170
+ default=Path("outputs/results.jsonl"),
171
+ help="Where to write JSONL results",
172
+ )
173
+ ap.add_argument("--dry-run", action="store_true", help="Do not execute pipeline")
174
+ ap.add_argument(
175
+ "--baseline",
176
+ type=Path,
177
+ help="Optional: JSONL with baseline run for significance tests",
178
+ )
179
+ ap.add_argument(
180
+ "--plots",
181
+ action="store_true",
182
+ help="Save diagnostic plots (PNG) alongside results",
183
+ )
184
+ args = ap.parse_args(argv)
185
+
186
+ # 1. Parse configuration
187
+ cfg = _load_pipeline_config(args.config)
188
+
189
+ # 2. Initialise logging (file + stderr)
190
+ init_logging(
191
+ log_dir=cfg.logging.log_dir,
192
+ level=cfg.logging.level,
193
+ max_mb=cfg.logging.max_mb,
194
+ backups=cfg.logging.backups,
195
+ )
196
+
197
+ import logging
198
+
199
+ logger = logging.getLogger(__name__)
200
+ logger.info("Loaded PipelineConfig:\n%s", cfg)
201
+
202
+ # 3. Build pipeline (retrieval β†’ (rerank) β†’ generation)
203
+ pipeline = RAGPipeline(cfg)
204
+
205
+ # 4. Load queries
206
+ rows = _read_jsonl(args.queries)
207
+ logger.info("Loaded %d queries from %s", len(rows), args.queries)
208
+
209
+ if args.dry_run:
210
+ logger.warning("Dry-run flag active – exiting before execution.")
211
+ sys.exit(0)
212
+
213
+ # 5. Execute pipeline
214
+ results: List[Dict[str, Any]] = []
215
+ for i, row in enumerate(rows, 1):
216
+ q = row["question"]
217
+ logger.info("[%d/%d] Q: %s", i, len(rows), q)
218
+ out = pipeline.run(q)
219
+ merged = {**row, **out} # keep any gold labels or metadata
220
+ results.append(merged)
221
+
222
+ # 6. Persist results
223
+ _write_jsonl(args.output, results)
224
+ logger.info("Wrote %d results to %s", len(results), args.output)
225
+
226
+ # 7. Aggregate statistics, significance tests, plots
227
+ agg = aggregate_metrics(results)
228
+ logger.info("Mean metrics: %s", json.dumps(agg, indent=2))
229
+
230
+ corr = correlation_with_gold(results, cfg.stats)
231
+ if corr:
232
+ logger.info(
233
+ "Correlation MRR↔gold %s=%.3f 95%%CI=[%.3f, %.3f] p=%.3g",
234
+ cfg.stats.correlation_method,
235
+ corr["r"],
236
+ corr["ci_low"],
237
+ corr["ci_high"],
238
+ corr["p"],
239
+ )
240
+
241
+ if args.baseline:
242
+ baseline_rows = _read_jsonl(args.baseline)
243
+ p_adj = wilcoxon_against_baseline(results, baseline_rows, cfg.stats)
244
+ logger.info("Wilcoxon vs baseline (Holm-Bonferroni Ξ±=%s): %s", cfg.stats.alpha, p_adj)
245
+
246
+ if args.plots:
247
+ plot_path = save_scatter(results, args.output.parent)
248
+ logger.info("Saved plot β†’ %s", plot_path)
249
+
250
+ if __name__ == "__main__":
251
+ main()
scripts/run_grid_experiments.py ADDED
@@ -0,0 +1,239 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ run_grid_experiments.py
4
+ =======================
5
+ Batch driver for *config Γ— dataset* evaluation, including:
6
+
7
+ * RQ1 – Correlation of classical retrieval metrics with factual-correctness
8
+ * RQ2 – Correlation of faithfulness metrics with expert judgements
9
+ * RQ3 – Retrieval-error ➜ hallucination propagation (χ² + conditional rates)
10
+ * RQ4 – Robustness under adversarial perturbations (Δ-metrics, Cohen's d)
11
+
12
+ Features
13
+ --------
14
+ * Incremental mode – pass **one** new --config and it is compared to all
16
+ previous runs already found under --outdir/<dataset>/.
16
+ * Saves:
17
+ - `results.jsonl`
18
+ - `aggregates.yaml`
19
+ - `rq1.yaml`, `rq2.yaml`, `rq3.yaml`, `rq4.yaml`
20
+ - pairwise Wilcoxon/Holm tables
21
+ - bar-, box-, scatter-plots (if --plots flag)
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import argparse
27
+ import itertools
28
+ import json
29
+ import logging
30
+ import os
31
+ from pathlib import Path
32
+ from typing import Any, Dict, Iterable, List, Mapping
33
+
34
+ import matplotlib.pyplot as plt
35
+ import numpy as np
36
+ import yaml
37
+
38
+ from evaluation import (
39
+ PipelineConfig,
40
+ RetrieverConfig,
41
+ GeneratorConfig,
42
+ CrossEncoderConfig,
43
+ StatsConfig,
44
+ LoggingConfig,
45
+ RAGPipeline,
46
+ )
47
+ from evaluation.stats import (
48
+ corr_ci,
49
+ wilcoxon_signed_rank,
50
+ holm_bonferroni,
51
+ conditional_failure_rate,
52
+ chi2_error_propagation,
53
+ delta_metric,
54
+ )
55
+ from evaluation.utils.logger import init_logging
56
+
57
+ # ─────────────────────────────── I/O helpers ────────────────────────────────
58
+
59
+
60
+ def read_jsonl(path: Path) -> List[Dict[str, Any]]:
61
+ with path.open() as f:
62
+ return [json.loads(line) for line in f]
63
+
64
+
65
+ def write_jsonl(path: Path, rows: Iterable[Mapping[str, Any]]) -> None:
66
+ path.parent.mkdir(parents=True, exist_ok=True)
67
+ with path.open("w") as f:
68
+ for row in rows:
69
+ f.write(json.dumps(row) + "\n")
70
+
71
+
72
+ def save_yaml(path: Path, obj: Mapping[str, Any]) -> None:
73
+ path.parent.mkdir(parents=True, exist_ok=True)
74
+ path.write_text(yaml.safe_dump(obj, sort_keys=False))
75
+
76
+
77
+ # ─────────────────────── config merge (same as earlier) ─────────────────────
78
+
79
+
80
+ def merge_dataclass(dc_cls, override: Mapping[str, Any]):
81
+ from dataclasses import asdict
82
+
83
+ base = asdict(dc_cls())
84
+ base.update({k: v for k, v in override.items() if v is not None})
85
+ return dc_cls(**base)
86
+
87
+
88
+ def load_pipeline_config(yaml_path: Path) -> PipelineConfig:
89
+ data = yaml.safe_load(yaml_path.read_text())
90
+ return PipelineConfig(
91
+ retriever=merge_dataclass(RetrieverConfig, data.get("retriever", {})),
92
+ generator=merge_dataclass(GeneratorConfig, data.get("generator", {})),
93
+ reranker=merge_dataclass(CrossEncoderConfig, data.get("reranker", {})),
94
+ stats=merge_dataclass(StatsConfig, data.get("stats", {})),
95
+ logging=merge_dataclass(LoggingConfig, data.get("logging", {})),
96
+ )
97
+
98
+
99
+ # ───────────────────────────── stats helpers ────────────────────────────────
100
+ def agg_mean(rows: List[dict[str, Any]]) -> dict[str, float]:
101
+ keys = rows[0]["metrics"].keys()
102
+ return {k: float(np.mean([r["metrics"][k] for r in rows])) for k in keys}
103
+
104
+
105
+ def rq1_correlation(rows, cfg: StatsConfig):
106
+ if "human_correct" not in rows[0]:
107
+ return {}
108
+ retrieval_keys = [k for k in rows[0]["metrics"] if k in {"mrr", "map", "precision@10"}]
109
+ gold = [1.0 if r["human_correct"] else 0.0 for r in rows]
110
+ out = {}
111
+ for k in retrieval_keys:
112
+ vec = [r["metrics"][k] for r in rows]
113
+ r, (lo, hi), p = corr_ci(vec, gold, method=cfg.correlation_method,
114
+ n_boot=cfg.n_boot, ci=cfg.ci)
115
+ out[k] = dict(r=r, ci=[lo, hi], p=p)
116
+ return out
117
+
118
+
119
+ def rq2_faithfulness(rows, cfg: StatsConfig):
120
+ if "human_faithful" not in rows[0]:
121
+ return {}
122
+ faith_keys = [k for k in rows[0]["metrics"] if k.lower().startswith(("faith", "qags", "fact", "ragas"))]
123
+ gold = [r["human_faithful"] for r in rows]
124
+ out = {}
125
+ for k in faith_keys:
126
+ vec = [r["metrics"][k] for r in rows]
127
+ r, (lo, hi), p = corr_ci(vec, gold, method=cfg.correlation_method,
128
+ n_boot=cfg.n_boot, ci=cfg.ci)
129
+ out[k] = dict(r=r, ci=[lo, hi], p=p)
130
+ return out
131
+
132
+
133
+ def rq3_error_propagation(rows):
134
+ if "retrieval_error" not in rows[0] or "hallucination" not in rows[0]:
135
+ return {}
136
+ ret_err = [r["retrieval_error"] for r in rows]
137
+ halluc = [r["hallucination"] for r in rows]
138
+ cond = conditional_failure_rate(ret_err, halluc)
139
+ chi2 = chi2_error_propagation(ret_err, halluc)
140
+ return {"conditional": cond, "chi2": chi2}
141
+
142
+
143
+ def rq4_robustness(orig_rows, pert_rows):
144
+ if pert_rows is None:
145
+ return {}
146
+ metrics = orig_rows[0]["metrics"].keys()
147
+ out = {}
148
+ for m in metrics:
149
+ d, eff = delta_metric(
150
+ [r["metrics"][m] for r in orig_rows],
151
+ [r["metrics"][m] for r in pert_rows],
152
+ )
153
+ out[m] = dict(delta=d, cohen_d=eff)
154
+ return out
155
+
156
+
157
+ # ─────────────────────────── plotting helpers ───────────────────────────────
158
+ def scatter_mrr_vs_correct(rows, path: Path):
159
+ x = [r["metrics"].get("mrr", np.nan) for r in rows]
160
+ y = [1 if r.get("human_correct") else 0 for r in rows]
161
+ plt.figure()
162
+ plt.scatter(x, y, alpha=0.5)
163
+ plt.xlabel("MRR"); plt.ylabel("Correct (1)")
164
+ plt.title("MRR vs. Human Correctness")
165
+ plt.tight_layout(); plt.savefig(path); plt.close()
166
+
167
+
168
+ # ────────────────────────────────── main ────────────────────────────────────
169
+ def main(argv: list[str] | None = None) -> None:
170
+ ap = argparse.ArgumentParser()
171
+ ap.add_argument("--configs", nargs="+", type=Path, required=True,
172
+ help="One or more YAML configs; if one, compared against prior runs.")
173
+ ap.add_argument("--datasets", nargs="+", type=Path, required=True)
174
+ ap.add_argument("--outdir", type=Path, default=Path("outputs/grid"))
175
+ ap.add_argument("--plots", action="store_true")
176
+ ap.add_argument("--perturbed-suffix", default="_pert",
177
+ help="If dataset perturbed version exists (name+suffix.jsonl) it's used for RQ4.")
178
+ args = ap.parse_args(argv)
179
+
180
+ init_logging(log_dir=args.outdir / "logs", level="INFO")
181
+ log = logging.getLogger("grid")
182
+
183
+ for dataset in args.datasets:
184
+ log.info("Dataset: %s", dataset.name)
185
+ queries = read_jsonl(dataset)
186
+ pert_path = dataset.with_stem(dataset.stem + args.perturbed_suffix)
187
+ pert_rows = read_jsonl(pert_path) if pert_path.exists() else None
188
+
189
+ # discover historical configs to compare against if incremental mode
190
+ hist_dirs = (args.outdir / dataset.stem).glob("*") if len(args.configs) == 1 else []
191
+ historical = {d.name: read_jsonl(d / "results.jsonl") for d in hist_dirs if d.is_dir()}
192
+
193
+ for cfg_yaml in args.configs:
194
+ cfg_name = cfg_yaml.stem
195
+ log.info(" Config: %s", cfg_name)
196
+ cfg = load_pipeline_config(cfg_yaml)
197
+ pipe = RAGPipeline(cfg)
198
+
199
+ # skip if results already exist
200
+ run_dir = args.outdir / dataset.stem / cfg_name
201
+ if (run_dir / "results.jsonl").exists():
202
+ log.info(" results already present – loading.")
203
+ rows = read_jsonl(run_dir / "results.jsonl")
204
+ else:
205
+ rows = [pipe.run(q["question"]) | q for q in queries]
206
+ write_jsonl(run_dir / "results.jsonl", rows)
207
+
208
+ # aggregates & RQ1–4
209
+ save_yaml(run_dir / "aggregates.yaml", agg_mean(rows))
210
+ save_yaml(run_dir / "rq1.yaml", rq1_correlation(rows, cfg.stats))
211
+ save_yaml(run_dir / "rq2.yaml", rq2_faithfulness(rows, cfg.stats))
212
+ save_yaml(run_dir / "rq3.yaml", rq3_error_propagation(rows))
213
+
214
+ if pert_rows:
215
+ save_yaml(run_dir / "rq4.yaml", rq4_robustness(rows, pert_rows))
216
+
217
+ if args.plots:
218
+ scatter_mrr_vs_correct(rows, run_dir / "mrr_vs_correct.png")
219
+
220
+ historical[cfg_name] = rows # include current for pairwise tests
221
+
222
+ # pairwise Wilcoxon on rag_score
223
+ if len(historical) > 1:
224
+ pairs = {}
225
+ names = list(historical)
226
+ for a, b in itertools.combinations(names, 2):
227
+ x = [r["metrics"]["rag_score"] for r in historical[a]]
228
+ y = [r["metrics"]["rag_score"] for r in historical[b]]
229
+ _, p = wilcoxon_signed_rank(x, y)
230
+ pairs[f"{a}~{b}"] = p
231
+ save_yaml(args.outdir / dataset.stem / "wilcoxon_rag_raw.yaml", pairs)
232
+ save_yaml(args.outdir / dataset.stem / "wilcoxon_rag_holm.yaml",
233
+ holm_bonferroni(pairs))
234
+
235
+ log.info(" Pairwise rag_score significance stored (Holm adjusted).")
236
+
237
+
238
+ if __name__ == "__main__":
239
+ main()