Commit cdf4160 by Rom89823974978 committed (parent: 8521f60)

Updated work
evaluation/config.py CHANGED
@@ -15,6 +15,13 @@ class RetrieverConfig:
     faiss_index: Optional[Path] = None
     doc_store: Optional[Path] = None
     device: str = "cpu"
+
+    # hybrid only
+    alpha: float = 0.5  # sparse ↔ dense weight
+
+    # dense only
+    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
+    embedder_cache: Optional[Path] = None
 
 
 @dataclass
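A quick reference for the new fields (an illustrative sketch, not part of the commit; `name` and `top_k` are existing RetrieverConfig fields outside this hunk, and the paths are placeholders):

    from pathlib import Path

    from evaluation.config import RetrieverConfig

    dense_cfg = RetrieverConfig(
        name="dense",
        top_k=5,
        faiss_index=Path("indexes/dense.faiss"),  # built automatically if missing
        doc_store=Path("data/docs.jsonl"),        # JSONL, one {"id": ..., "text": ...} per line
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        embedder_cache=Path(".cache/embedder"),   # optional
        device="cpu",
    )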
evaluation/pipeline.py CHANGED
@@ -44,16 +44,16 @@ class RAGPipeline:
     # ---------------------------------------------------------------------
     def _build_retriever(self, cfg: PipelineConfig) -> Retriever:
         name = cfg.retriever.name
+        r = cfg.retriever
         if name == "bm25":
-            return bm25.BM25Retriever(str(cfg.retriever.index_path))
+            return bm25.BM25Retriever(index_path=str(r.bm25_index))
         if name == "dense":
-            return dense.DenseRetriever(str(cfg.retriever.index_path))
+            return dense.DenseRetriever(faiss_index=str(r.faiss_index), doc_store=r.doc_store, model_name=r.model_name, embedder_cache=r.embedder_cache, device=r.device)
         if name == "hybrid":
-            # In a real setting one would supply two paths; simplified here.
             return hybrid.HybridRetriever(
-                bm25_idx=str(cfg.retriever.index_path),
-                dense_idx=str(cfg.retriever.index_path),
-                alpha=0.5,
+                bm25_idx=str(r.bm25_index),
+                dense_idx=str(r.faiss_index),
+                alpha=r.alpha,
             )
         raise ValueError(f"Unsupported retriever '{name}'")
 
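The hybrid branch now pulls both index paths and the interpolation weight from the config instead of reusing a single `index_path`. A matching config would look roughly like this (sketch; `bm25_index`, `name` and `top_k` are RetrieverConfig fields defined outside the config.py hunk above, and how `alpha` is applied is up to HybridRetriever):

    hybrid_cfg = RetrieverConfig(
        name="hybrid",
        top_k=5,
        bm25_index=Path("indexes/bm25"),          # Lucene directory for the sparse side
        faiss_index=Path("indexes/dense.faiss"),  # FAISS index for the dense side
        alpha=0.5,                                # sparse/dense interpolation weight
    )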
evaluation/retrievers/bm25.py CHANGED
@@ -1,30 +1,88 @@
-"""BM25 sparse retriever backed by Pyserini SimpleSearcher."""
+"""BM25 sparse retriever backed by Pyserini SimpleSearcher, with auto-indexing."""
 
 from __future__ import annotations
-from typing import List
 import logging
+import os
+import subprocess
+from pathlib import Path
+from typing import List, Optional
 
 from pyserini.search import SimpleSearcher
 
 from .base import Retriever, Context
 
-
 logger = logging.getLogger(__name__)
 
 
 class BM25Retriever(Retriever):
-    """Thin wrapper around Pyserini's BM25 searcher."""
+    """Pyserini BM25 searcher that will create the Lucene index on-the-fly."""
 
-    def __init__(self, index_path: str | None):
+    def __init__(
+        self,
+        index_path: str | os.PathLike | None,
+        *,
+        doc_store_path: Optional[str | os.PathLike] = None,
+        threads: int = 4,
+    ):
         if index_path is None:
-            raise ValueError("BM25 retriever requires a path to a Pyserini index.")
-        self.searcher = SimpleSearcher(index_path)
+            raise ValueError("`index_path` (directory) is required.")
+
+        index_path = Path(index_path)
+
+        # ------------------------------------------------------------------
+        # Build index if it does not already exist
+        # ------------------------------------------------------------------
+        if not index_path.exists():
+            if doc_store_path is None:
+                raise FileNotFoundError(
+                    f"BM25 index {index_path} not found and no `doc_store_path` supplied."
+                )
+            logger.info("BM25 index %s missing – building from %s ...",
+                        index_path, doc_store_path)
+            self._build_index(Path(doc_store_path), index_path, threads)
+
+        # ------------------------------------------------------------------
+        # Searcher
+        # ------------------------------------------------------------------
+        self.searcher = SimpleSearcher(str(index_path))
         self.searcher.set_bm25()
         logger.info("BM25Retriever initialised with index: %s", index_path)
 
+    # ------------------------------------------------------------------ #
+    # Public API
+    # ------------------------------------------------------------------ #
     def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
         hits = self.searcher.search(query, k=top_k)
         return [
             Context(id=str(hit.docid), text=hit.raw, score=hit.score)  # type: ignore[attr-defined]
             for hit in hits
         ]
+
+    # ------------------------------------------------------------------ #
+    # Helpers
+    # ------------------------------------------------------------------ #
+    @staticmethod
+    def _build_index(
+        doc_store: Path,
+        index_dir: Path,
+        threads: int,
+    ):
+        """Call Pyserini's CLI to build a Lucene index from JSONL documents.
+
+        `doc_store` must be a JSONL file or directory containing JSONL files
+        with at least {"id": ..., "text": ...} per line.
+        """
+        index_dir.mkdir(parents=True, exist_ok=True)
+
+        cmd = [
+            "python", "-m", "pyserini.index",
+            "-collection", "JsonCollection",
+            "-generator", "DefaultLuceneDocumentGenerator",
+            "-input", str(doc_store),
+            "-index", str(index_dir),
+            "-threads", str(threads),
+            "-storePositions", "-storeDocvectors", "-storeRaw",
+        ]
+        logger.info("Running Pyserini indexer: %s", " ".join(cmd))
+        subprocess.run(cmd, check=True)  # raises if indexing fails
+        logger.info("Finished building Lucene index in %s", index_dir)
evaluation/retrievers/dense.py CHANGED
@@ -1,25 +1,115 @@
-"""Dense vector retriever placeholder (FAISS)."""
+"""Dense vector retriever with automatic FAISS index construction."""
 
 from __future__ import annotations
-from typing import List
+
+import json
 import logging
+import os
+from pathlib import Path
+from typing import List, Optional, Sequence, Union
+
+import faiss  # type: ignore
+import numpy as np
+from sentence_transformers import SentenceTransformer
 
-from .base import Retriever, Context
+from .base import Context, Retriever
 
 logger = logging.getLogger(__name__)
 
 
 class DenseRetriever(Retriever):
-    """A dense vector retriever using FAISS (placeholder implementation)."""
-
-    def __init__(self, index_path: str | None):
-        if index_path is None:
-            raise ValueError("Dense retriever requires a FAISS index file.")
-        import faiss  # pylint: disable=import-error
-
-        self.index = faiss.read_index(index_path)
-        logger.info("DenseRetriever initialised with FAISS index: %s", index_path)
+    """Sentence-Transformers + FAISS ANN search.
+
+    * If `faiss_index` does **not** exist, it is built from `doc_store`.
+    * Embedding model (and its cache location) are configurable.
+    """
+
+    def __init__(
+        self,
+        faiss_index: Union[str, Path],
+        *,
+        doc_store: Union[str, Path],
+        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
+        embedder_cache: Optional[Union[str, Path]] = None,
+        device: str = "cpu",
+    ):
+        self.faiss_index = Path(faiss_index)
+        self.doc_store = Path(doc_store)
+
+        # ------------------------------------------------------------------
+        # Sentence-Transformers embedder
+        # ------------------------------------------------------------------
+        self.embedder = SentenceTransformer(
+            model_name,
+            device=device,
+            cache_folder=str(embedder_cache) if embedder_cache else None,
+        )
+        logger.info("Embedder '%s' ready (device=%s)", model_name, device)
+
+        # ------------------------------------------------------------------
+        # Build FAISS index if absent
+        # ------------------------------------------------------------------
+        if not self.faiss_index.exists():
+            logger.info("FAISS index %s missing – building ...", self.faiss_index)
+            self._build_index()
 
+        self.index = faiss.read_index(str(self.faiss_index))
+        logger.info("Loaded FAISS index with %d vectors", self.index.ntotal)
+
+        # Keep doc texts in memory for convenience
+        self._texts: List[str] = []
+        with self.doc_store.open() as f:
+            for line in f:
+                obj = json.loads(line)
+                self._texts.append(obj.get("text", ""))
+
+    # ------------------------------------------------------------------ #
+    # Public API
+    # ------------------------------------------------------------------ #
     def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
-        # TODO: embed the query via a sentence transformer or similar.
-        raise NotImplementedError("DenseRetriever embedding is not implemented yet.")
+        vec = self._embed(query)
+        vec = np.asarray(vec, dtype="float32")[None, :]
+        dists, idxs = self.index.search(vec, top_k)
+        dists, idxs = dists[0], idxs[0]
+
+        results: List[Context] = []
+        for i, score in zip(idxs, dists):
+            if i == -1:
+                continue
+            if self.index.metric_type == faiss.METRIC_L2:
+                score = -score
+            text = self._texts[i] if i < len(self._texts) else ""
+            results.append(Context(id=str(i), text=text, score=float(score)))
+
+        results.sort(key=lambda c: c.score, reverse=True)
+        return results
+
+    # ------------------------------------------------------------------ #
+    # Internal helpers
+    # ------------------------------------------------------------------ #
+    def _embed(self, text: str) -> Sequence[float]:
+        return self.embedder.encode(text, normalize_embeddings=True).tolist()
+
+    def _build_index(self):
+        """Read all texts, embed them, and write a FAISS IP index."""
+        logger.info("Reading documents from %s", self.doc_store)
+        ids, texts = [], []
+        with self.doc_store.open() as f:
+            for line in f:
+                obj = json.loads(line)
+                ids.append(int(obj["id"]))
+                texts.append(obj["text"])
+
+        logger.info("Embedding %d documents ...", len(ids))
+        embs = self.embedder.encode(
+            texts,
+            batch_size=128,
+            show_progress_bar=True,
+            normalize_embeddings=True,
+        ).astype("float32")
+
+        logger.info("Creating FAISS index (Inner-Product)")
+        index = faiss.IndexFlatIP(embs.shape[1])
+        index.add(embs)
+        faiss.write_index(index, str(self.faiss_index))
+        logger.info("Saved FAISS index to %s", self.faiss_index)
tests/conftest.py ADDED
@@ -0,0 +1,64 @@
+import json
+import shutil
+import tempfile
+from pathlib import Path
+from types import SimpleNamespace
+from typing import List
+
+import numpy as np
+import pytest
+
+
+@pytest.fixture(scope="session")
+def tmp_doc_store(tmp_path_factory):
+    """Create a tiny JSONL doc store for testing."""
+    docs = [
+        {"id": 0, "text": "Retrieval Augmented Generation combines retrieval and generation."},
+        {"id": 1, "text": "BM25 is a strong lexical baseline in information retrieval."},
+        {"id": 2, "text": "FAISS enables efficient similarity search over dense embeddings."},
+    ]
+    doc_path = tmp_path_factory.mktemp("docs") / "docs.jsonl"
+    with doc_path.open("w") as f:
+        for doc in docs:
+            f.write(json.dumps(doc) + "\n")
+    return doc_path
+
+
+class _DummyEmbedder:
+    """Fast, deterministic replacement for SentenceTransformer during tests.
+
+    * Encodes each text into a 16-dim vector seeded by a hash of the text.
+    * Normalises vectors so the retriever workflow (IP metric) is preserved.
+    * Mirrors SentenceTransformer.encode: a single string yields a 1-D vector,
+      a list of strings yields a 2-D array.
+    """
+
+    _dim = 16
+
+    def __init__(self, *args, **kwargs):
+        self.rs = np.random.RandomState(42)
+
+    def encode(self, texts, **kw):
+        single = isinstance(texts, str)
+        if single:
+            texts = [texts]
+        vecs = []
+        for t in texts:
+            # Simple hash-based seed for determinism within a test run
+            h = abs(hash(t)) % (2**32)
+            self.rs.seed(h)
+            v = self.rs.randn(self._dim)
+            v = v / np.linalg.norm(v)
+            vecs.append(v.astype("float32"))
+        out = np.stack(vecs)
+        return out[0] if single else out
+
+    def __str__(self):
+        return "DummyEmbedder"
+
+
+@pytest.fixture(autouse=True)
+def patch_sentence_transformers(monkeypatch):
+    """Monkeypatch SentenceTransformer to a lightweight dummy implementation."""
+    # Patch the name as imported inside our retriever module
+    from evaluation.retrievers import dense as dense_mod
+
+    monkeypatch.setattr(dense_mod, "SentenceTransformer", _DummyEmbedder)
+    yield
tests/test_dense_retriever.py ADDED
@@ -0,0 +1,26 @@
+import faiss
+import numpy as np
+from pathlib import Path
+
+from evaluation.retrievers.dense import DenseRetriever
+
+
+def test_dense_retriever_build_and_search(tmp_doc_store, tmp_path):
+    faiss_index = tmp_path / "dense.index"
+
+    # Build index automatically
+    retriever = DenseRetriever(
+        faiss_index=faiss_index,
+        doc_store=tmp_doc_store,
+        model_name="dummy/ignored",  # ignored by dummy embedder
+        device="cpu",
+    )
+    assert faiss_index.exists(), "FAISS index should have been auto-created"
+
+    # Basic retrieval
+    results = retriever.retrieve("What enables similarity search?", top_k=3)
+    assert results, "Should return at least one context"
+    # Check score ordering descending
+    assert all(results[i].score >= results[i + 1].score for i in range(len(results) - 1))
+    # IDs must be strings by contract
+    assert isinstance(results[0].id, str)
tests/test_metrics.py ADDED
@@ -0,0 +1,26 @@
+from evaluation.metrics import (
+    precision_at_k,
+    recall_at_k,
+    mean_reciprocal_rank,
+    average_precision,
+    rag_score,
+)
+
+
+def test_retrieval_metrics_simple():
+    retrieved = ["a", "b", "c", "d"]
+    relevant = {"b", "d"}
+
+    assert precision_at_k(retrieved, relevant, 2) == 0.5
+    assert recall_at_k(retrieved, relevant, 4) == 1.0
+    assert mean_reciprocal_rank(retrieved, relevant) == 1 / 2
+    assert 0 < average_precision(retrieved, relevant) <= 1
+
+
+def test_rag_score_harmonic_mean():
+    r = {"prec": 0.8, "rec": 0.6}
+    g = {"bleu": 0.7}
+    s = rag_score(r, g)
+    assert 0 <= s <= 1
+    # The harmonic mean is never larger than the arithmetic mean
+    assert s <= (sum(r.values()) / len(r) + sum(g.values()) / len(g)) / 2
tests/test_pipeline_end_to_end.py ADDED
@@ -0,0 +1,39 @@
+from pathlib import Path
+
+import pytest
+
+from evaluation.config import PipelineConfig, RetrieverConfig, GeneratorConfig
+from evaluation.pipeline import RAGPipeline
+
+
+class _DummyGenerator:
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def generate(self, *args, **kw):
+        return "dummy answer"
+
+
+def test_pipeline_with_dense(tmp_doc_store, monkeypatch, tmp_path):
+    # Monkeypatch HFGenerator with a dummy, fast implementation
+    from evaluation import generators as gens_pkg  # noqa: F401
+    from evaluation.generators import hf_generator
+
+    monkeypatch.setattr(hf_generator, "HFGenerator", _DummyGenerator)
+
+    cfg = PipelineConfig(
+        retriever=RetrieverConfig(
+            name="dense",
+            top_k=2,
+            faiss_index=tmp_path / "dense.idx",
+            doc_store=tmp_doc_store,
+            device="cpu",
+            model_name="dummy/ignored",
+        ),
+        generator=GeneratorConfig(model_name="dummy"),
+    )
+
+    pipeline = RAGPipeline(cfg)
+    result = pipeline("What is BM25?")
+    assert result["answer"] == "dummy answer"
+    assert len(result["contexts"]) > 0
tests/test_smoke.py DELETED
@@ -1,2 +0,0 @@
-def test_smoke():
-    assert True