Rom89823974978 committed on
Commit 79bdbbe · 1 Parent(s): f868144

Resolved test issues

evaluation/config.py CHANGED

@@ -38,7 +38,7 @@ class RetrieverConfig:
     index_path: Optional[Union[str, Path]] = None # alias for bm25_index
 
     # Specific to BM25
-    bm25_index: Optional[Union[str, Path]] = None
+    bm25_idx: Optional[Union[str, Path]] = None
     doc_store: Optional[Union[str, Path]] = None
 
     # For dense-only
@@ -53,7 +53,7 @@ class RetrieverConfig:
     def __post_init__(self):
         # If index_path is provided (legacy), use it as bm25_index
         if self.index_path:
-            self.bm25_index = self.index_path
+            self.bm25_idx = self.index_path
 
 
 @dataclass
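
For orientation only (illustrative, not part of the commit): after the rename above, the legacy index_path field is copied into bm25_idx by __post_init__. A minimal sketch, assuming the remaining RetrieverConfig fields keep their defaults and the module is importable as evaluation.config:

from evaluation.config import RetrieverConfig  # import path assumed from the file name

cfg = RetrieverConfig(index_path="indexes/bm25")  # legacy alias still accepted
assert cfg.bm25_idx == "indexes/bm25"             # __post_init__ copies it into the renamed field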
evaluation/metrics/composite.py CHANGED

@@ -5,12 +5,12 @@ from typing import Mapping
 import math
 
 
-def harmonic_mean(scores: Mapping[str, float], eps: float = 1e-6) -> float:
+def harmonic_mean(scores: Mapping[str, float]) -> float:
     """Compute the harmonic mean of positive scores."""
     if not scores:
         return 0.0
-    inv_sum = sum(1.0 / (v + eps) for v in scores.values() if v > 0)
-    return len(scores) / inv_sum if inv_sum else 0.0
+    inv_sum = sum(1.0 / (v) for v in scores.values() if v > 0)
+    return len(scores) / inv_sum if inv_sum and inv_sum != 0 else 0.0
 
 
 def rag_score(scores: Mapping[str, float]) -> float:
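
A quick worked check of the updated harmonic_mean (illustrative values only, not part of the commit):

# Assumes the harmonic_mean definition shown in the diff above.
scores = {"faithfulness": 0.5, "relevance": 0.25}
# inv_sum = 1/0.5 + 1/0.25 = 6.0, so the result is len(scores) / inv_sum = 2 / 6.0
print(harmonic_mean(scores))  # ≈ 0.3333
print(harmonic_mean({}))      # 0.0 (an empty mapping short-circuits)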
evaluation/pipeline.py CHANGED

@@ -22,29 +22,18 @@ class RAGPipeline:
         self.generator = HFGenerator(
             model_name=cfg.generator.model_name, device=cfg.generator.device
         )
-        self.reranker = (
-            CrossEncoderReranker(
+        if cfg.reranker.enable:
+            self.reranker = CrossEncoderReranker(
                 cfg.reranker.model_name,
                 device=cfg.reranker.device,
-                max_len=cfg.reranker.max_length,
+                max_length=cfg.reranker.max_length,
             )
-            if cfg.reranker.enable
-            else None
-        )
+        else:
+            self.reranker = None
 
     # ---------------------------------------------------------------------
     # Public API
     # ---------------------------------------------------------------------
-    def run_queries(self, queries: list[dict[str, Any]]) -> list[dict[str, Any]]:
-        """Accepts a list of {'question': str, 'id': Any}, returns list of result dicts."""
-        results: list[dict[str, Any]] = []
-        for entry in queries:
-            q = entry.get("question", "")
-            doc_id = entry.get("id")
-            answer = self.run_query(q)
-            results.append({"id": doc_id, "question": q, "answer": answer})
-        return results
-
     def run(self, question: str) -> Dict[str, Any]:
         """Retrieve context and generate answer."""
         logger.info("Question: %s", question)
@@ -58,22 +47,40 @@ class RAGPipeline:
 
     __call__ = run # alias
 
+    def run_queries(self, queries: list[dict[str, Any]]) -> list[dict[str, Any]]:
+        """Accepts a list of {'question': str, 'id': Any}, returns list of result dicts."""
+        results: list[dict[str, Any]] = []
+        for entry in queries:
+            q = entry.get("question", "")
+            doc_id = entry.get("id")
+            answer = self.run(q)
+            results.append({"id": doc_id, "question": q, "answer": answer})
+        return results
     # ---------------------------------------------------------------------
     # Private helpers
     # ---------------------------------------------------------------------
     def _build_retriever(self, cfg: PipelineConfig) -> Retriever:
-        name = cfg.retriever.name
         r=cfg.retriever
+        name = r.name
         if name == "bm25":
-            return bm25.BM25Retriever(
-                index_path=str(r.bm25_index), doc_store_path=str(r.doc_store))
+            return bm25.BM25Retriever(bm25_idx=str(r.bm25_index), doc_store_path=str(r.doc_store))
         if name == "dense":
-            return dense.DenseRetriever(faiss_index=str(r.faiss_index),doc_store=r.doc_store,model_name=r.model_name,embedder_cache=r.embedder_cache,device=r.device)
+            return dense.DenseRetriever(
+                faiss_index=str(r.faiss_index),
+                doc_store=str(r.doc_store),
+                model_name=r.model_name,
+                embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
+                device=r.device,
+            )
         if name == "hybrid":
             return hybrid.HybridRetriever(
-                bm25_idx=str(r.bm25_index),
-                dense_idx=str(r.faiss_index),
+                str(r.bm25_index),
+                str(r.faiss_index),
+                doc_store=str(r.doc_store),
                 alpha=r.alpha,
+                model_name=r.model_name,
+                embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
+                device=r.device,
+            )
         raise ValueError(f"Unsupported retriever '{name}'")
 
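A minimal usage sketch of the relocated run_queries entry point (illustrative only, not part of the commit; assumes RAGPipeline is constructed from a fully populated PipelineConfig named cfg):

pipeline = RAGPipeline(cfg)  # constructor signature assumed
results = pipeline.run_queries([
    {"id": 1, "question": "What is BM25?"},
    {"id": 2, "question": "How does dense retrieval differ?"},
])
# Each result dict has the shape {"id": ..., "question": ..., "answer": ...},
# where "answer" comes from pipeline.run(question).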
evaluation/retrievers/bm25.py CHANGED

@@ -18,11 +18,11 @@ class BM25Retriever(Retriever):
 
     def __init__(
         self,
-        index_path: str | None,
+        bm25_idx: str | None,
         doc_store_path: str | None = None,
         threads: int = 1,
     ):
-        if index_path is None:
+        if bm25_idx is None:
             raise ValueError("BM25 retriever requires a path to a Pyserini index.")
 
         # ❶ Attempt to import SimpleSearcher. If it fails (ImportError or Java mismatch),
@@ -39,24 +39,24 @@ class BM25Retriever(Retriever):
             )
             SimpleSearcher = None
 
-        self.index_path = index_path
+        self.bm25_idx = bm25_idx
         self.doc_store_path = doc_store_path
         self.threads = threads
         self.searcher = None
 
         # ❷ If the index folder does not exist, attempt to build it from doc_store_path
-        if not Path(index_path).exists():
+        if not Path(bm25_idx).exists():
             if doc_store_path is None:
                 logger.warning(
                     "BM25 index %s not found and no `doc_store_path` supplied. "
                     "BM25Retriever.retrieve() will return no hits.",
-                    index_path,
+                    bm25_idx,
                 )
             else:
                 try:
                     logger.info(
                         "BM25 index %s missing – building from %s ...",
-                        index_path,
+                        bm25_idx,
                         doc_store_path,
                     )
                     self._build_index(Path(doc_store_path), index_path, threads)
@@ -78,7 +78,7 @@ class BM25Retriever(Retriever):
                 logger.warning(
                     "Failed to instantiate SimpleSearcher on %s (%s). "
                     "BM25Retriever.retrieve() will return no hits.",
-                    index_path,
+                    bm25_idx,
                     e,
                 )
                 self.searcher = None
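
For orientation (illustrative only, not part of the commit; paths are placeholders): after the rename, the constructor takes the index location as bm25_idx, and retrieval falls back to returning no hits when neither the index nor a document store is available.

retriever = BM25Retriever(bm25_idx="indexes/bm25", doc_store_path="data/docs.jsonl")
hits = retriever.retrieve("what is reciprocal rank fusion?", top_k=5)  # retrieve() signature assumed from its use in HybridRetriever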
evaluation/retrievers/hybrid.py CHANGED

@@ -1,11 +1,8 @@
-"""Hybrid retriever that combines sparse and dense scores (linear sum)."""
-
 from __future__ import annotations
-from typing import List, Optional
-from pathlib import Path
 import logging
+from typing import List, Optional
 
-from .base import Retriever, Context
+from .base import Context, Retriever
 from .bm25 import BM25Retriever
 from .dense import DenseRetriever
 
@@ -13,32 +10,63 @@ logger = logging.getLogger(__name__)
 
 
 class HybridRetriever(Retriever):
-    """Combine BM25 and Dense retrievers by score normalisation and sum."""
+    """Combine BM25 and Dense retrievers by normalising and summing scores."""
 
-    def __init__(self, bm25_idx: str, faiss_idx: str, *, doc_store: str, alpha: float = 0.5, model_name: str = "sentence-transformers/all-MiniLM-L6-v2", embedder_cache: Optional[str] = None, device: str = "cpu"):
-        self.bm25 = BM25Retriever(index_path=bm25_idx, doc_store_path=doc_store)
+    def __init__(
+        self,
+        bm25_idx: str,
+        faiss_idx: str,
+        doc_store: str,
+        *,
+        alpha: float = 0.5,
+        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
+        embedder_cache: Optional[str] = None,
+        device: str = "cpu",
+    ):
+        # 1) BM25 retriever
+        self.bm25 = BM25Retriever(bm25_idx, doc_store_path=doc_store)
+
+        # 2) Dense retriever
         self.dense = DenseRetriever(
             faiss_index=faiss_idx,
             doc_store=doc_store,
             model_name=model_name,
             embedder_cache=embedder_cache,
-            device=device)
+            device=device,
+        )
+
         if not 0 <= alpha <= 1:
-            raise ValueError("alpha must be in [0, 1]")
+            raise ValueError("alpha must be in [0, 1]")
         self.alpha = alpha
 
     def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
-        sparse_ctxs = {c.id: c for c in self.sparse.retrieve(query, top_k=top_k)}
-        dense_ctxs = {c.id: c for c in self.dense.retrieve(query, top_k=top_k)}
+        # 1) Get sparse hits
+        sparse_hits = self.bm25.retrieve(query, top_k=top_k)
+        sparse_dict = {ctx.id: ctx for ctx in sparse_hits}
 
-        ids = list(set(sparse_ctxs) | set(dense_ctxs))
+        # 2) Get dense hits
+        dense_hits = self.dense.retrieve(query, top_k=top_k)
+        dense_dict = {ctx.id: ctx for ctx in dense_hits}
+
+        # 3) Union of all IDs
+        all_ids = set(sparse_dict) | set(dense_dict)
         merged: List[Context] = []
-        for doc_id in ids:
-            sparse_score = sparse_ctxs.get(doc_id, Context(doc_id, "", 0.0)).score
-            dense_score = dense_ctxs.get(doc_id, Context(doc_id, "", 0.0)).score
-            score = self.alpha * sparse_score + (1 - self.alpha) * dense_score
-            text = sparse_ctxs.get(doc_id, dense_ctxs.get(doc_id)).text # type: ignore
-            merged.append(Context(id=doc_id, text=text, score=score))
 
+        for doc_id in all_ids:
+            s_score = sparse_dict.get(doc_id, Context(doc_id, "", 0.0)).score
+            d_score = dense_dict.get(doc_id, Context(doc_id, "", 0.0)).score
+
+            combined_score = self.alpha * s_score + (1 - self.alpha) * d_score
+
+            # Prefer the text from whichever retriever has this doc_id present;
+            # if only one side has it, grab that text.
+            if doc_id in sparse_dict:
+                text = sparse_dict[doc_id].text
+            else:
+                text = dense_dict[doc_id].text
+
+            merged.append(Context(id=doc_id, text=text, score=combined_score))
+
+        # 4) Sort by score descending
         merged.sort(key=lambda c: c.score, reverse=True)
         return merged[:top_k]
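
To make the blending rule concrete (illustrative numbers only, not part of the commit): with alpha = 0.5, a document scored 0.8 by BM25 and 0.4 by the dense retriever receives 0.5 * 0.8 + 0.5 * 0.4 = 0.6, and a document returned by only one retriever is combined with a 0.0 score from the missing side.

# Reproduces the combination rule from HybridRetriever.retrieve with made-up scores.
alpha = 0.5
s_score, d_score = 0.8, 0.4                   # BM25 and dense scores for the same doc
combined = alpha * s_score + (1 - alpha) * d_score
print(round(combined, 3))                     # 0.6
print(alpha * 0.0 + (1 - alpha) * 0.9)        # dense-only document -> 0.45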
evaluation/stats/robustness.py CHANGED

@@ -79,4 +79,5 @@ def chi2_error_propagation(
         chi2, p, dof, expected = chi2_contingency(table)
         return dict(chi2=chi2, p=p, dof=dof, expected=expected, table=table)
     except ValueError:
-        return dict(chi2=0.0, p=1.0, dof=dof, expected=expected, table=table)
+        default_expected = [[0, 0], [0, 0]]
+        return dict(chi2=0.0, p=1.0, dof=0, expected=default_expected, table=table)
evaluation/stats/significance.py CHANGED

@@ -41,4 +41,7 @@ def holm_bonferroni(pvalues: Mapping[str, float]) -> Mapping[str, float]:
 
 def delta_metric(base: Sequence[float], new: Sequence[float]) -> list[float]:
     """Compute per‐element differences `new[i] - base[i]` as a list of floats."""
-    return [float(n - b) for b, n in zip(base, new)]
+    diffs: list[float] = []
+    for b, n in zip(base, new):
+        diffs.append(float(n - b))
+    return diffs
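
A quick check of the rewritten delta_metric on illustrative values (not part of the commit):

print(delta_metric([0.50, 0.60, 0.75], [0.70, 0.40, 0.75]))
# ≈ [0.2, -0.2, 0.0]  (element-wise new[i] - base[i], up to floating-point rounding)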
tests/test_metrics.py CHANGED

@@ -25,7 +25,6 @@ def test_retrieval_metrics_simple():
     assert recall_at_k(retrieved, relevant, 2) == pytest.approx(1 / 3, rel=1e-6)
     assert recall_at_k(retrieved, relevant, 4) == pytest.approx(2 / 3, rel=1e-6)
     assert mean_reciprocal_rank(retrieved, relevant) == pytest.approx(0.5, rel=1e-6)
-    # AP = (1/2 + 2/4)/3 = 1/3
     assert average_precision(retrieved, relevant) == pytest.approx(1 / 3, rel=1e-6)
 
 