Commit cdf4160 · 1 parent 8521f60 · Updated work

Files changed:

- evaluation/config.py +7 -0
- evaluation/pipeline.py +6 -6
- evaluation/retrievers/bm25.py +65 -7
- evaluation/retrievers/dense.py +102 -12
- tests/conftest.py +64 -0
- tests/test_dense_retriever.py +26 -0
- tests/test_metrics.py +26 -0
- tests/test_pipeline_end_to_end.py +39 -0
- tests/test_smoke.py +0 -2
evaluation/config.py CHANGED

```diff
@@ -15,6 +15,13 @@ class RetrieverConfig:
     faiss_index: Optional[Path] = None
     doc_store: Optional[Path] = None
     device: str = "cpu"
+
+    # hybrid only
+    alpha: float = 0.5  # sparse ↔ dense weight
+
+    # dense-only
+    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
+    embedder_cache: Optional[Path] = None
 
 
 @dataclass
```
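For orientation, a minimal sketch of how the expanded `RetrieverConfig` might be filled in for a hybrid run. The field names `name`, `top_k`, and `bm25_index` are taken from the pipeline and test code elsewhere in this commit; the paths are illustrative:

```python
from pathlib import Path

from evaluation.config import RetrieverConfig

cfg = RetrieverConfig(
    name="hybrid",
    top_k=5,
    bm25_index=Path("indexes/bm25"),        # illustrative path
    faiss_index=Path("indexes/dense.idx"),  # illustrative path
    doc_store=Path("data/docs.jsonl"),
    alpha=0.5,  # sparse ↔ dense weight, per the dataclass comment
)
```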
evaluation/pipeline.py CHANGED

```diff
@@ -44,16 +44,16 @@ class RAGPipeline:
     # ---------------------------------------------------------------------
     def _build_retriever(self, cfg: PipelineConfig) -> Retriever:
         name = cfg.retriever.name
+        r = cfg.retriever
         if name == "bm25":
-            return bm25.BM25Retriever(str(
+            return bm25.BM25Retriever(index_path=str(r.bm25_index))
         if name == "dense":
-            return dense.DenseRetriever(str(
+            return dense.DenseRetriever(faiss_index=str(r.faiss_index), doc_store=r.doc_store, model_name=r.model_name, embedder_cache=r.embedder_cache, device=r.device)
         if name == "hybrid":
-            # In a real setting one would supply two paths; simplified here.
             return hybrid.HybridRetriever(
-                bm25_idx=str(
-                dense_idx=str(
-                alpha=
+                bm25_idx=str(r.bm25_index),
+                dense_idx=str(r.faiss_index),
+                alpha=r.alpha,
             )
         raise ValueError(f"Unsupported retriever '{name}'")
```
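End to end, the dispatch above is exercised roughly like this; a sketch assuming `RAGPipeline` is callable and returns an `answer`/`contexts` dict, as the new end-to-end test below relies on (the generator model name is a placeholder):

```python
from pathlib import Path

from evaluation.config import GeneratorConfig, PipelineConfig, RetrieverConfig
from evaluation.pipeline import RAGPipeline

cfg = PipelineConfig(
    retriever=RetrieverConfig(name="bm25", top_k=5, bm25_index=Path("indexes/bm25")),
    generator=GeneratorConfig(model_name="google/flan-t5-base"),  # placeholder model
)
pipeline = RAGPipeline(cfg)  # _build_retriever picks BM25Retriever from cfg
result = pipeline("What is BM25?")
print(result["answer"], len(result["contexts"]))
```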
evaluation/retrievers/bm25.py CHANGED

```diff
@@ -1,30 +1,88 @@
-"""BM25 sparse retriever backed by Pyserini SimpleSearcher."""
+"""BM25 sparse retriever backed by Pyserini SimpleSearcher, with auto-indexing."""
 
 from __future__ import annotations
-from typing import List
 import logging
+import os
+import subprocess
+from pathlib import Path
+from typing import List, Optional
 
 from pyserini.search import SimpleSearcher
 
 from .base import Retriever, Context
 
-
 logger = logging.getLogger(__name__)
 
 
 class BM25Retriever(Retriever):
-    """
+    """Pyserini BM25 searcher that will create the Lucene index on-the-fly."""
 
-    def __init__(
+    def __init__(
+        self,
+        index_path: str | os.PathLike | None,
+        *,
+        doc_store_path: Optional[str | os.PathLike] = None,
+        threads: int = 4,
+    ):
         if index_path is None:
-            raise ValueError("
-
+            raise ValueError("`index_path` (directory) is required.")
+
+        index_path = Path(index_path)
+
+        # ------------------------------------------------------------------
+        # Build index if it does not already exist
+        # ------------------------------------------------------------------
+        if not index_path.exists():
+            if doc_store_path is None:
+                raise FileNotFoundError(
+                    f"BM25 index {index_path} not found and no `doc_store_path` supplied."
+                )
+            logger.info("BM25 index %s missing – building from %s ...",
+                        index_path, doc_store_path)
+            self._build_index(Path(doc_store_path), index_path, threads)
+
+        # ------------------------------------------------------------------
+        # Searcher
+        # ------------------------------------------------------------------
+        self.searcher = SimpleSearcher(str(index_path))
         self.searcher.set_bm25()
         logger.info("BM25Retriever initialised with index: %s", index_path)
 
+    # ------------------------------------------------------------------ #
+    # Public API
+    # ------------------------------------------------------------------ #
     def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
         hits = self.searcher.search(query, k=top_k)
         return [
             Context(id=str(hit.docid), text=hit.raw, score=hit.score)  # type: ignore[attr-defined]
             for hit in hits
         ]
+
+    # ------------------------------------------------------------------ #
+    # Helpers
+    # ------------------------------------------------------------------ #
+    @staticmethod
+    def _build_index(
+        doc_store: Path,
+        index_dir: Path,
+        threads: int,
+    ):
+        """Call Pyserini's CLI to build a Lucene index from JSONL documents.
+
+        `doc_store` must be a JSONL file or directory containing JSONL files
+        with at least {"id": ..., "text": ...} per line.
+        """
+        index_dir.mkdir(parents=True, exist_ok=True)
+
+        cmd = [
+            "python", "-m", "pyserini.index",
+            "-collection", "JsonCollection",
+            "-generator", "DefaultLuceneDocumentGenerator",
+            "-input", str(doc_store),
+            "-index", str(index_dir),
+            "-threads", str(threads),
+            "-storePositions", "-storeDocvectors", "-storeRaw",
+        ]
+        logger.info("Running Pyserini indexer: %s", " ".join(cmd))
+        subprocess.run(cmd, check=True)  # raises if indexing fails
+        logger.info("Finished building Lucene index in %s", index_dir)
```
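A usage sketch of the new auto-indexing path; the paths are illustrative, and the JSONL format follows the `_build_index` docstring:

```python
from evaluation.retrievers.bm25 import BM25Retriever

# First run: "indexes/bm25" does not exist yet, so the Lucene index is
# built from the JSONL corpus via the Pyserini CLI before searching.
retriever = BM25Retriever(
    "indexes/bm25",                    # illustrative index directory
    doc_store_path="data/docs.jsonl",  # JSONL with {"id": ..., "text": ...} per line
    threads=4,
)
contexts = retriever.retrieve("what is retrieval augmented generation", top_k=5)
for ctx in contexts:
    print(ctx.id, round(ctx.score, 3))
```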
evaluation/retrievers/dense.py CHANGED

```diff
@@ -1,25 +1,115 @@
-"""Dense vector retriever
+"""Dense vector retriever with automatic FAISS index construction."""
 
 from __future__ import annotations
-
+
+import json
 import logging
+import os
+from pathlib import Path
+from typing import List, Optional, Sequence, Union
+
+import faiss  # type: ignore
+import numpy as np
+from sentence_transformers import SentenceTransformer
 
-from .base import
+from .base import Context, Retriever
 
 logger = logging.getLogger(__name__)
 
 
 class DenseRetriever(Retriever):
-    """
+    """Sentence-Transformers + FAISS ANN search.
+
+    * If `faiss_index` does **not** exist, it is built from `doc_store`.
+    * Embedding model (and its cache location) are configurable.
+    """
 
-
-
-    raise ValueError("Dense retriever requires a FAISS index file.")
-    import faiss  # pylint: disable=import-error
-
-
+    def __init__(
+        self,
+        faiss_index: Union[str, Path],
+        *,
+        doc_store: Union[str, Path],
+        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
+        embedder_cache: Optional[Union[str, Path]] = None,
+        device: str = "cpu",
+    ):
+        self.faiss_index = Path(faiss_index)
+        self.doc_store = Path(doc_store)
+
+        # ------------------------------------------------------------------
+        # Sentence-Transformers embedder
+        # ------------------------------------------------------------------
+        self.embedder = SentenceTransformer(
+            model_name,
+            device=device,
+            cache_folder=str(embedder_cache) if embedder_cache else None,
+        )
+        logger.info("Embedder '%s' ready (device=%s)", model_name, device)
+
+        # ------------------------------------------------------------------
+        # Build FAISS index if absent
+        # ------------------------------------------------------------------
+        if not self.faiss_index.exists():
+            logger.info("FAISS index %s missing – building ...", self.faiss_index)
+            self._build_index()
 
+        self.index = faiss.read_index(str(self.faiss_index))
+        logger.info("Loaded FAISS index with %d vectors", self.index.ntotal)
 
+        # Keep doc texts in memory for convenience
+        self._texts: List[str] = []
+        with self.doc_store.open() as f:
+            for line in f:
+                obj = json.loads(line)
+                self._texts.append(obj.get("text", ""))
 
+    # ------------------------------------------------------------------ #
+    # Public API
+    # ------------------------------------------------------------------ #
     def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
-
-
+        vec = self._embed(query)
+        vec = np.asarray(vec, dtype="float32")[None, :]
+        dists, idxs = self.index.search(vec, top_k)
+        dists, idxs = dists[0], idxs[0]
+
+        results: List[Context] = []
+        for i, score in zip(idxs, dists):
+            if i == -1:
+                continue
+            if self.index.metric_type == faiss.METRIC_L2:
+                score = -score
+            text = self._texts[i] if i < len(self._texts) else ""
+            results.append(Context(id=str(i), text=text, score=float(score)))
+
+        results.sort(key=lambda c: c.score, reverse=True)
+        return results
+
+    # ------------------------------------------------------------------ #
+    # Internal helpers
+    # ------------------------------------------------------------------ #
+    def _embed(self, text: str) -> Sequence[float]:
+        return self.embedder.encode(text, normalize_embeddings=True).tolist()
+
+    def _build_index(self):
+        """Read all texts, embed them, and write a FAISS IP index."""
+        logger.info("Reading documents from %s", self.doc_store)
+        ids, vectors = [], []
+        with self.doc_store.open() as f:
+            for line in f:
+                obj = json.loads(line)
+                ids.append(int(obj["id"]))
+                vectors.append(obj["text"])
+
+        logger.info("Embedding %d documents ...", len(ids))
+        embs = self.embedder.encode(
+            vectors,
+            batch_size=128,
+            show_progress_bar=True,
+            normalize_embeddings=True,
+        ).astype("float32")
+
+        logger.info("Creating FAISS index (Inner-Product)")
+        index = faiss.IndexFlatIP(embs.shape[1])
+        index.add(embs)
+        faiss.write_index(index, str(self.faiss_index))
+        logger.info("Saved FAISS index to %s", self.faiss_index)
```
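Because documents are embedded with `normalize_embeddings=True` and stored in an `IndexFlatIP`, the inner-product scores are cosine similarities; the `METRIC_L2` branch in `retrieve` flips the sign so that larger is always better. A usage sketch with illustrative paths:

```python
from evaluation.retrievers.dense import DenseRetriever

# If "indexes/dense.idx" is missing, _build_index() embeds every "text"
# in the JSONL doc store and writes a flat inner-product index first.
retriever = DenseRetriever(
    "indexes/dense.idx",          # illustrative path, created on first run
    doc_store="data/docs.jsonl",  # JSONL with {"id": ..., "text": ...} per line
    device="cpu",
)
for ctx in retriever.retrieve("efficient similarity search", top_k=3):
    print(ctx.id, round(ctx.score, 3))  # cosine similarity, descending
```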
tests/conftest.py ADDED

```python
import json
import shutil
import tempfile
from pathlib import Path
from types import SimpleNamespace
from typing import List

import numpy as np
import pytest


@pytest.fixture(scope="session")
def tmp_doc_store(tmp_path_factory):
    """Create a tiny JSONL doc store for testing."""
    docs = [
        {"id": 0, "text": "Retrieval Augmented Generation combines retrieval and generation."},
        {"id": 1, "text": "BM25 is a strong lexical baseline in information retrieval."},
        {"id": 2, "text": "FAISS enables efficient similarity search over dense embeddings."},
    ]
    doc_path = tmp_path_factory.mktemp("docs") / "docs.jsonl"
    with doc_path.open("w") as f:
        for doc in docs:
            f.write(json.dumps(doc) + "\n")
    return doc_path


class _DummyEmbedder:
    """Fast, deterministic replacement for SentenceTransformer during tests.

    * Encodes text into a 16-dim vector with a fixed random seed.
    * Normalises vectors so the retriever workflow (IP metric) is preserved.
    """

    _dim = 16

    def __init__(self, *args, **kwargs):
        self.rs = np.random.RandomState(42)

    def encode(self, texts, **kw):
        if isinstance(texts, str):
            texts = [texts]
        vecs = []
        for t in texts:
            # Simple hash-based seed for determinism
            h = abs(hash(t)) % (2**32)
            self.rs.seed(h)
            v = self.rs.randn(self._dim)
            v = v / np.linalg.norm(v)
            vecs.append(v.astype("float32"))
        return np.stack(vecs)

    # SentenceTransformer __str__ compatibility
    def __str__(self):
        return "DummyEmbedder"


@pytest.fixture(autouse=True)
def patch_sentence_transformers(monkeypatch):
    """Monkeypatch SentenceTransformer to a lightweight dummy implementation."""

    # Import path inside our retriever module
    from evaluation.retrievers import dense as dense_mod

    monkeypatch.setattr(dense_mod, "SentenceTransformer", _DummyEmbedder)
    yield
```
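Because the fixture is `autouse`, every test in the suite gets the dummy embedder, so `DenseRetriever` never downloads a real model. One caveat: Python's built-in `hash()` on strings is randomized per process, so `encode` is deterministic within a single pytest run but not necessarily across runs. A quick check of the property the tests rely on (hypothetical snippet):

```python
emb = _DummyEmbedder()
v1 = emb.encode("hello")[0]
v2 = emb.encode("hello")[0]
assert (v1 == v2).all()                    # same text, same vector (within one process)
assert abs((v1 ** 2).sum() - 1.0) < 1e-5   # unit norm, matching the IP metric
```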
tests/test_dense_retriever.py ADDED

```python
import faiss
import numpy as np
from pathlib import Path

from evaluation.retrievers.dense import DenseRetriever


def test_dense_retriever_build_and_search(tmp_doc_store, tmp_path):
    faiss_index = tmp_path / "dense.index"

    # Build index automatically
    retriever = DenseRetriever(
        faiss_index=faiss_index,
        doc_store=tmp_doc_store,
        model_name="dummy/ignored",  # ignored by dummy embedder
        device="cpu",
    )
    assert faiss_index.exists(), "FAISS index should have been auto-created"

    # Basic retrieval
    results = retriever.retrieve("What enables similarity search?", top_k=3)
    assert results, "Should return at least one context"
    # Check score ordering descending
    assert all(results[i].score >= results[i + 1].score for i in range(len(results) - 1))
    # IDs must be strings by contract
    assert isinstance(results[0].id, str)
```
tests/test_metrics.py ADDED

```python
from evaluation.metrics import (
    precision_at_k,
    recall_at_k,
    mean_reciprocal_rank,
    average_precision,
    rag_score,
)


def test_retrieval_metrics_simple():
    retrieved = ["a", "b", "c", "d"]
    relevant = {"b", "d"}

    assert precision_at_k(retrieved, relevant, 2) == 0.5
    assert recall_at_k(retrieved, relevant, 4) == 1.0
    assert mean_reciprocal_rank(retrieved, relevant) == 1 / 2
    assert 0 < average_precision(retrieved, relevant) <= 1


def test_rag_score_harmonic_mean():
    r = {"prec": 0.8, "rec": 0.6}
    g = {"bleu": 0.7}
    s = rag_score(r, g)
    assert 0 <= s <= 1
    # harmonic mean must be less than or equal to arithmetic mean
    assert s <= (sum(r.values()) / len(r) + sum(g.values()) / len(g)) / 2
```
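For the first test's data, assuming the standard definition of average precision (the mean of precision at each relevant hit's rank), the expected value can be worked out by hand: the relevant documents "b" and "d" sit at ranks 2 and 4, so AP = (P@2 + P@4) / 2 = (1/2 + 2/4) / 2 = 0.5, which satisfies the asserted bound 0 < AP <= 1.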
tests/test_pipeline_end_to_end.py ADDED

```python
from pathlib import Path

import pytest

from evaluation.config import PipelineConfig, RetrieverConfig, GeneratorConfig
from evaluation.pipeline import RAGPipeline


class _DummyGenerator:
    def __init__(self, *args, **kwargs):
        pass

    def generate(self, *args, **kw):
        return "dummy answer"


def test_pipeline_with_dense(tmp_doc_store, monkeypatch, tmp_path):
    # Monkey-patch HFGenerator with dummy fast implementation
    from evaluation import generators as gens_pkg  # noqa: F401
    from evaluation.generators import hf_generator

    monkeypatch.setattr(hf_generator, "HFGenerator", _DummyGenerator)

    cfg = PipelineConfig(
        retriever=RetrieverConfig(
            name="dense",
            top_k=2,
            faiss_index=tmp_path / "dense.idx",
            doc_store=tmp_doc_store,
            device="cpu",
            model_name="dummy/ignored",
        ),
        generator=GeneratorConfig(model_name="dummy"),
    )

    pipeline = RAGPipeline(cfg)
    result = pipeline("What is BM25?")
    assert result["answer"] == "dummy answer"
    assert len(result["contexts"]) > 0
```
tests/test_smoke.py DELETED

```diff
@@ -1,2 +0,0 @@
-def test_smoke():
-    assert True
```