Spaces:
Sleeping
Sleeping
Commit
·
79bdbbe
1
Parent(s):
f868144
Resolved tests issues
Browse files
- evaluation/config.py +2 -2
- evaluation/metrics/composite.py +3 -3
- evaluation/pipeline.py +29 -22
- evaluation/retrievers/bm25.py +7 -7
- evaluation/retrievers/hybrid.py +47 -19
- evaluation/stats/robustness.py +2 -1
- evaluation/stats/significance.py +4 -1
- tests/test_metrics.py +0 -1
evaluation/config.py
CHANGED
|
@@ -38,7 +38,7 @@ class RetrieverConfig:
|
|
| 38 |
index_path: Optional[Union[str, Path]] = None # alias for bm25_index
|
| 39 |
|
| 40 |
# Specific to BM25
|
| 41 |
-
|
| 42 |
doc_store: Optional[Union[str, Path]] = None
|
| 43 |
|
| 44 |
# For dense-only
|
|
@@ -53,7 +53,7 @@ class RetrieverConfig:
|
|
| 53 |
def __post_init__(self):
|
| 54 |
# If index_path is provided (legacy), use it as bm25_index
|
| 55 |
if self.index_path:
|
| 56 |
-
self.
|
| 57 |
|
| 58 |
|
| 59 |
@dataclass
|
|
|
|
| 38 |
index_path: Optional[Union[str, Path]] = None # alias for bm25_index
|
| 39 |
|
| 40 |
# Specific to BM25
|
| 41 |
+
bm25_idx: Optional[Union[str, Path]] = None
|
| 42 |
doc_store: Optional[Union[str, Path]] = None
|
| 43 |
|
| 44 |
# For dense-only
|
|
|
|
| 53 |
def __post_init__(self):
|
| 54 |
# If index_path is provided (legacy), use it as bm25_idx
|
| 55 |
if self.index_path:
|
| 56 |
+
self.bm25_idx = self.index_path
|
| 57 |
|
| 58 |
|
| 59 |
@dataclass
|
evaluation/metrics/composite.py
CHANGED
|
@@ -5,12 +5,12 @@ from typing import Mapping
|
|
| 5 |
import math
|
| 6 |
|
| 7 |
|
| 8 |
-
def harmonic_mean(scores: Mapping[str, float]
|
| 9 |
"""Compute the harmonic mean of positive scores."""
|
| 10 |
if not scores:
|
| 11 |
return 0.0
|
| 12 |
-
inv_sum = sum(1.0 / (v
|
| 13 |
-
return len(scores) / inv_sum if inv_sum else 0.0
|
| 14 |
|
| 15 |
|
| 16 |
def rag_score(scores: Mapping[str, float]) -> float:
|
|
|
|
| 5 |
import math
|
| 6 |
|
| 7 |
|
| 8 |
+
def harmonic_mean(scores: Mapping[str, float]) -> float:
    """Compute the harmonic mean of positive scores.

    Args:
        scores: Mapping from a metric name to its score.

    Returns:
        ``len(scores) / sum(1 / v)`` when every score is strictly positive.
        ``0.0`` when the mapping is empty, when any score is zero, negative
        or NaN, or when the reciprocal sum itself is zero.
    """
    if not scores:
        return 0.0

    values = list(scores.values())

    # A non-positive score has no usable reciprocal. Dropping it from the
    # denominator while still counting it in the numerator would produce a
    # "mean" larger than every input, so the result collapses to 0.0 instead
    # (the limit of the harmonic mean as any component approaches zero).
    # `not v > 0` rather than `v <= 0` so that NaN is rejected as well.
    if any(not v > 0 for v in values):
        return 0.0

    inv_sum = sum(1.0 / v for v in values)
    # inv_sum can still be 0.0 when every score is infinite.
    return len(values) / inv_sum if inv_sum else 0.0
|
| 14 |
|
| 15 |
|
| 16 |
def rag_score(scores: Mapping[str, float]) -> float:
|
evaluation/pipeline.py
CHANGED
|
@@ -22,29 +22,18 @@ class RAGPipeline:
|
|
| 22 |
self.generator = HFGenerator(
|
| 23 |
model_name=cfg.generator.model_name, device=cfg.generator.device
|
| 24 |
)
|
| 25 |
-
|
| 26 |
-
CrossEncoderReranker(
|
| 27 |
cfg.reranker.model_name,
|
| 28 |
device=cfg.reranker.device,
|
| 29 |
-
|
| 30 |
)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
)
|
| 34 |
|
| 35 |
# ---------------------------------------------------------------------
|
| 36 |
# Public API
|
| 37 |
# ---------------------------------------------------------------------
|
| 38 |
-
def run_queries(self, queries: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
| 39 |
-
"""Accepts a list of {'question': str, 'id': Any}, returns list of result dicts."""
|
| 40 |
-
results: list[dict[str, Any]] = []
|
| 41 |
-
for entry in queries:
|
| 42 |
-
q = entry.get("question", "")
|
| 43 |
-
doc_id = entry.get("id")
|
| 44 |
-
answer = self.run_query(q)
|
| 45 |
-
results.append({"id": doc_id, "question": q, "answer": answer})
|
| 46 |
-
return results
|
| 47 |
-
|
| 48 |
def run(self, question: str) -> Dict[str, Any]:
|
| 49 |
"""Retrieve context and generate answer."""
|
| 50 |
logger.info("Question: %s", question)
|
|
@@ -58,22 +47,40 @@ class RAGPipeline:
|
|
| 58 |
|
| 59 |
__call__ = run # alias
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
# ---------------------------------------------------------------------
|
| 62 |
# Private helpers
|
| 63 |
# ---------------------------------------------------------------------
|
| 64 |
def _build_retriever(self, cfg: PipelineConfig) -> Retriever:
|
| 65 |
-
name = cfg.retriever.name
|
| 66 |
r=cfg.retriever
|
|
|
|
| 67 |
if name == "bm25":
|
| 68 |
-
|
| 69 |
-
index_path=str(r.bm25_index), doc_store_path=str(r.doc_store))
|
| 70 |
if name == "dense":
|
| 71 |
-
return dense.DenseRetriever(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
if name == "hybrid":
|
| 73 |
return hybrid.HybridRetriever(
|
| 74 |
-
|
| 75 |
-
|
|
|
|
| 76 |
alpha=r.alpha,
|
|
|
|
|
|
|
|
|
|
| 77 |
)
|
| 78 |
raise ValueError(f"Unsupported retriever '{name}'")
|
| 79 |
|
|
|
|
| 22 |
self.generator = HFGenerator(
|
| 23 |
model_name=cfg.generator.model_name, device=cfg.generator.device
|
| 24 |
)
|
| 25 |
+
if cfg.reranker.enable:
|
| 26 |
+
self.reranker = CrossEncoderReranker(
|
| 27 |
cfg.reranker.model_name,
|
| 28 |
device=cfg.reranker.device,
|
| 29 |
+
max_length=cfg.reranker.max_length,
|
| 30 |
)
|
| 31 |
+
else:
|
| 32 |
+
self.reranker = None
|
|
|
|
| 33 |
|
| 34 |
# ---------------------------------------------------------------------
|
| 35 |
# Public API
|
| 36 |
# ---------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def run(self, question: str) -> Dict[str, Any]:
|
| 38 |
"""Retrieve context and generate answer."""
|
| 39 |
logger.info("Question: %s", question)
|
|
|
|
| 47 |
|
| 48 |
__call__ = run # alias
|
| 49 |
|
| 50 |
+
def run_queries(self, queries: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Answer a batch of queries, one ``self.run`` call per entry.

    Args:
        queries: Entries shaped like ``{'question': str, 'id': Any}``. A
            missing ``'question'`` is treated as the empty string and a
            missing ``'id'`` as ``None``.

    Returns:
        One ``{'id', 'question', 'answer'}`` dict per input entry, in input
        order, where ``'answer'`` is whatever ``self.run`` returned.
    """
    collected: list[dict[str, Any]] = []
    for item in queries:
        question = item.get("question", "")
        collected.append(
            {
                "id": item.get("id"),
                "question": question,
                "answer": self.run(question),
            }
        )
    return collected
|
| 59 |
# ---------------------------------------------------------------------
|
| 60 |
# Private helpers
|
| 61 |
# ---------------------------------------------------------------------
|
| 62 |
def _build_retriever(self, cfg: PipelineConfig) -> Retriever:
|
|
|
|
| 63 |
r=cfg.retriever
|
| 64 |
+
name = r.name
|
| 65 |
if name == "bm25":
|
| 66 |
+
return bm25.BM25Retriever(bm25_idx=str(r.bm25_index), doc_store_path=str(r.doc_store))
|
|
|
|
| 67 |
if name == "dense":
|
| 68 |
+
return dense.DenseRetriever(
|
| 69 |
+
faiss_index=str(r.faiss_index),
|
| 70 |
+
doc_store=str(r.doc_store),
|
| 71 |
+
model_name=r.model_name,
|
| 72 |
+
embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
|
| 73 |
+
device=r.device,
|
| 74 |
+
)
|
| 75 |
if name == "hybrid":
|
| 76 |
return hybrid.HybridRetriever(
|
| 77 |
+
str(r.bm25_index),
|
| 78 |
+
str(r.faiss_index),
|
| 79 |
+
doc_store=str(r.doc_store),
|
| 80 |
alpha=r.alpha,
|
| 81 |
+
model_name=r.model_name,
|
| 82 |
+
embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
|
| 83 |
+
device=r.device,
|
| 84 |
)
|
| 85 |
raise ValueError(f"Unsupported retriever '{name}'")
|
| 86 |
|
evaluation/retrievers/bm25.py
CHANGED
|
@@ -18,11 +18,11 @@ class BM25Retriever(Retriever):
|
|
| 18 |
|
| 19 |
def __init__(
|
| 20 |
self,
|
| 21 |
-
|
| 22 |
doc_store_path: str | None = None,
|
| 23 |
threads: int = 1,
|
| 24 |
):
|
| 25 |
-
if
|
| 26 |
raise ValueError("BM25 retriever requires a path to a Pyserini index.")
|
| 27 |
|
| 28 |
# ❶ Attempt to import SimpleSearcher. If it fails (ImportError or Java mismatch),
|
|
@@ -39,24 +39,24 @@ class BM25Retriever(Retriever):
|
|
| 39 |
)
|
| 40 |
SimpleSearcher = None
|
| 41 |
|
| 42 |
-
self.
|
| 43 |
self.doc_store_path = doc_store_path
|
| 44 |
self.threads = threads
|
| 45 |
self.searcher = None
|
| 46 |
|
| 47 |
# ❷ If the index folder does not exist, attempt to build it from doc_store_path
|
| 48 |
-
if not Path(
|
| 49 |
if doc_store_path is None:
|
| 50 |
logger.warning(
|
| 51 |
"BM25 index %s not found and no `doc_store_path` supplied. "
|
| 52 |
"BM25Retriever.retrieve() will return no hits.",
|
| 53 |
-
|
| 54 |
)
|
| 55 |
else:
|
| 56 |
try:
|
| 57 |
logger.info(
|
| 58 |
"BM25 index %s missing – building from %s ...",
|
| 59 |
-
|
| 60 |
doc_store_path,
|
| 61 |
)
|
| 62 |
self._build_index(Path(doc_store_path), index_path, threads)
|
|
@@ -78,7 +78,7 @@ class BM25Retriever(Retriever):
|
|
| 78 |
logger.warning(
|
| 79 |
"Failed to instantiate SimpleSearcher on %s (%s). "
|
| 80 |
"BM25Retriever.retrieve() will return no hits.",
|
| 81 |
-
|
| 82 |
e,
|
| 83 |
)
|
| 84 |
self.searcher = None
|
|
|
|
| 18 |
|
| 19 |
def __init__(
|
| 20 |
self,
|
| 21 |
+
bm25_idx: str | None,
|
| 22 |
doc_store_path: str | None = None,
|
| 23 |
threads: int = 1,
|
| 24 |
):
|
| 25 |
+
if bm25_idx is None:
|
| 26 |
raise ValueError("BM25 retriever requires a path to a Pyserini index.")
|
| 27 |
|
| 28 |
# ❶ Attempt to import SimpleSearcher. If it fails (ImportError or Java mismatch),
|
|
|
|
| 39 |
)
|
| 40 |
SimpleSearcher = None
|
| 41 |
|
| 42 |
+
self.bm25_idx = bm25_idx
|
| 43 |
self.doc_store_path = doc_store_path
|
| 44 |
self.threads = threads
|
| 45 |
self.searcher = None
|
| 46 |
|
| 47 |
# ❷ If the index folder does not exist, attempt to build it from doc_store_path
|
| 48 |
+
if not Path(bm25_idx).exists():
|
| 49 |
if doc_store_path is None:
|
| 50 |
logger.warning(
|
| 51 |
"BM25 index %s not found and no `doc_store_path` supplied. "
|
| 52 |
"BM25Retriever.retrieve() will return no hits.",
|
| 53 |
+
bm25_idx,
|
| 54 |
)
|
| 55 |
else:
|
| 56 |
try:
|
| 57 |
logger.info(
|
| 58 |
"BM25 index %s missing – building from %s ...",
|
| 59 |
+
bm25_idx,
|
| 60 |
doc_store_path,
|
| 61 |
)
|
| 62 |
self._build_index(Path(doc_store_path), index_path, threads)
|
|
|
|
| 78 |
logger.warning(
|
| 79 |
"Failed to instantiate SimpleSearcher on %s (%s). "
|
| 80 |
"BM25Retriever.retrieve() will return no hits.",
|
| 81 |
+
bm25_idx,
|
| 82 |
e,
|
| 83 |
)
|
| 84 |
self.searcher = None
|
evaluation/retrievers/hybrid.py
CHANGED
|
@@ -1,11 +1,8 @@
|
|
| 1 |
-
"""Hybrid retriever that combines sparse and dense scores (linear sum)."""
|
| 2 |
-
|
| 3 |
from __future__ import annotations
|
| 4 |
-
from typing import List, Optional
|
| 5 |
-
from pathlib import Path
|
| 6 |
import logging
|
|
|
|
| 7 |
|
| 8 |
-
from .base import
|
| 9 |
from .bm25 import BM25Retriever
|
| 10 |
from .dense import DenseRetriever
|
| 11 |
|
|
@@ -13,32 +10,63 @@ logger = logging.getLogger(__name__)
|
|
| 13 |
|
| 14 |
|
| 15 |
class HybridRetriever(Retriever):
|
| 16 |
-
"""Combine BM25 and Dense retrievers by
|
| 17 |
|
| 18 |
-
def __init__(
|
| 19 |
-
self
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
self.dense = DenseRetriever(
|
| 21 |
faiss_index=faiss_idx,
|
| 22 |
doc_store=doc_store,
|
| 23 |
model_name=model_name,
|
| 24 |
embedder_cache=embedder_cache,
|
| 25 |
-
device=device
|
|
|
|
|
|
|
| 26 |
if not 0 <= alpha <= 1:
|
| 27 |
-
raise ValueError("alpha must be in [0,
|
| 28 |
self.alpha = alpha
|
| 29 |
|
| 30 |
def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
|
| 31 |
-
|
| 32 |
-
|
|
|
|
| 33 |
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
merged: List[Context] = []
|
| 36 |
-
for doc_id in ids:
|
| 37 |
-
sparse_score = sparse_ctxs.get(doc_id, Context(doc_id, "", 0.0)).score
|
| 38 |
-
dense_score = dense_ctxs.get(doc_id, Context(doc_id, "", 0.0)).score
|
| 39 |
-
score = self.alpha * sparse_score + (1 - self.alpha) * dense_score
|
| 40 |
-
text = sparse_ctxs.get(doc_id, dense_ctxs.get(doc_id)).text # type: ignore
|
| 41 |
-
merged.append(Context(id=doc_id, text=text, score=score))
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
merged.sort(key=lambda c: c.score, reverse=True)
|
| 44 |
return merged[:top_k]
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
|
|
|
|
|
|
| 2 |
import logging
|
| 3 |
+
from typing import List, Optional
|
| 4 |
|
| 5 |
+
from .base import Context, Retriever
|
| 6 |
from .bm25 import BM25Retriever
|
| 7 |
from .dense import DenseRetriever
|
| 8 |
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class HybridRetriever(Retriever):
    """Fuse BM25 and dense retrieval with an alpha-weighted sum of raw scores.

    The final score of a document is
    ``alpha * bm25_score + (1 - alpha) * dense_score``; a document returned by
    only one retriever contributes ``0.0`` for the other side. Scores are
    combined as returned by the two retrievers, with no rescaling applied here.
    """

    def __init__(
        self,
        bm25_idx: str,
        faiss_idx: str,
        doc_store: str,
        *,
        alpha: float = 0.5,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        embedder_cache: Optional[str] = None,
        device: str = "cpu",
    ):
        """Build both underlying retrievers, then validate ``alpha``.

        Args:
            bm25_idx: Index path handed to ``BM25Retriever``.
            faiss_idx: Index path handed to ``DenseRetriever``.
            doc_store: Document store path shared by both retrievers.
            alpha: Weight of the sparse score; the dense score gets
                ``1 - alpha``.
            model_name: Embedding model name forwarded to ``DenseRetriever``.
            embedder_cache: Optional cache location forwarded to
                ``DenseRetriever``.
            device: Device string forwarded to ``DenseRetriever``.

        Raises:
            ValueError: If ``alpha`` lies outside ``[0, 1]``.
        """
        # Sparse side.
        self.bm25 = BM25Retriever(bm25_idx, doc_store_path=doc_store)

        # Dense side.
        self.dense = DenseRetriever(
            faiss_index=faiss_idx,
            doc_store=doc_store,
            model_name=model_name,
            embedder_cache=embedder_cache,
            device=device,
        )

        if alpha < 0 or alpha > 1:
            raise ValueError("alpha must be in [0, 1]")
        self.alpha = alpha

    def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
        """Return the ``top_k`` contexts ranked by the fused score.

        Each retriever is asked for ``top_k`` hits; the union of their
        document ids is re-scored and sorted in descending order.
        """
        sparse_by_id = {hit.id: hit for hit in self.bm25.retrieve(query, top_k=top_k)}
        dense_by_id = {hit.id: hit for hit in self.dense.retrieve(query, top_k=top_k)}

        fused: List[Context] = []
        for doc_id in set(sparse_by_id) | set(dense_by_id):
            sparse_hit = sparse_by_id.get(doc_id)
            dense_hit = dense_by_id.get(doc_id)

            # A side that did not return this document contributes 0.0.
            s_score = sparse_hit.score if sparse_hit is not None else 0.0
            d_score = dense_hit.score if dense_hit is not None else 0.0
            combined_score = self.alpha * s_score + (1 - self.alpha) * d_score

            # The sparse hit's text wins when both sides have the document.
            source = sparse_hit if doc_id in sparse_by_id else dense_by_id[doc_id]
            fused.append(Context(id=doc_id, text=source.text, score=combined_score))

        fused.sort(key=lambda ctx: ctx.score, reverse=True)
        return fused[:top_k]
|
evaluation/stats/robustness.py
CHANGED
|
@@ -79,4 +79,5 @@ def chi2_error_propagation(
|
|
| 79 |
chi2, p, dof, expected = chi2_contingency(table)
|
| 80 |
return dict(chi2=chi2, p=p, dof=dof, expected=expected, table=table)
|
| 81 |
except ValueError:
|
| 82 |
-
|
|
|
|
|
|
| 79 |
chi2, p, dof, expected = chi2_contingency(table)
|
| 80 |
return dict(chi2=chi2, p=p, dof=dof, expected=expected, table=table)
|
| 81 |
except ValueError:
|
| 82 |
+
default_expected = [[0, 0], [0, 0]]
|
| 83 |
+
return dict(chi2=0.0, p=1.0, dof=0, expected=default_expected, table=table)
|
evaluation/stats/significance.py
CHANGED
|
@@ -41,4 +41,7 @@ def holm_bonferroni(pvalues: Mapping[str, float]) -> Mapping[str, float]:
|
|
| 41 |
|
| 42 |
def delta_metric(base: Sequence[float], new: Sequence[float]) -> list[float]:
|
| 43 |
"""Compute per‐element differences `new[i] - base[i]` as a list of floats."""
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
def delta_metric(base: Sequence[float], new: Sequence[float]) -> list[float]:
    """Compute per-element differences ``new[i] - base[i]`` as a list of floats.

    Args:
        base: Baseline values.
        new: Values to compare against the baseline, position by position.

    Returns:
        A list with ``float(new[i] - base[i])`` for each paired position.
        Pairing uses ``zip``, so when the inputs differ in length the result
        is silently truncated to the shorter one.
    """
    return [float(n - b) for b, n in zip(base, new)]
|
tests/test_metrics.py
CHANGED
|
@@ -25,7 +25,6 @@ def test_retrieval_metrics_simple():
|
|
| 25 |
assert recall_at_k(retrieved, relevant, 2) == pytest.approx(1 / 3, rel=1e-6)
|
| 26 |
assert recall_at_k(retrieved, relevant, 4) == pytest.approx(2 / 3, rel=1e-6)
|
| 27 |
assert mean_reciprocal_rank(retrieved, relevant) == pytest.approx(0.5, rel=1e-6)
|
| 28 |
-
# AP = (1/2 + 2/4)/3 = 1/3
|
| 29 |
assert average_precision(retrieved, relevant) == pytest.approx(1 / 3, rel=1e-6)
|
| 30 |
|
| 31 |
|
|
|
|
| 25 |
assert recall_at_k(retrieved, relevant, 2) == pytest.approx(1 / 3, rel=1e-6)
|
| 26 |
assert recall_at_k(retrieved, relevant, 4) == pytest.approx(2 / 3, rel=1e-6)
|
| 27 |
assert mean_reciprocal_rank(retrieved, relevant) == pytest.approx(0.5, rel=1e-6)
|
|
|
|
| 28 |
assert average_precision(retrieved, relevant) == pytest.approx(1 / 3, rel=1e-6)
|
| 29 |
|
| 30 |
|