# RAG_Eval/evaluation/pipeline.py
# (Hugging Face Spaces viewer artifact — commit 12409b1, "Updated codebase")
"""High‑level RAG pipeline orchestration."""
from __future__ import annotations
import logging
from typing import Dict, Any, List
from .config import PipelineConfig
from .retrievers import bm25, dense, hybrid
from .generators.hf_generator import HFGenerator
from .retrievers.base import Retriever, Context
from .rerankers.cross_encoder import CrossEncoderReranker
# Module-level logger, named after the package path for hierarchical filtering.
logger = logging.getLogger(__name__)
# NOTE(review): calling basicConfig at import time configures the root logger
# as a module side effect — questionable for a library module; consider
# leaving log configuration to the application entry point. Left in place to
# preserve current behavior.
logging.basicConfig(level=logging.INFO)
class RAGPipeline:
    """Run retrieval → generation → scoring in a single object.

    Stages:
        1. first-stage retrieval (bm25 / dense / hybrid),
        2. optional cross-encoder reranking,
        3. answer generation with a HuggingFace model.
    """

    def __init__(self, cfg: PipelineConfig):
        """Build the retriever, generator and (optional) reranker from *cfg*."""
        self.cfg = cfg
        self.retriever: Retriever = self._build_retriever(cfg)
        self.generator = HFGenerator(
            model_name=cfg.generator.model_name, device=cfg.generator.device
        )
        # ``self.reranker is None`` disables the reranking stage entirely.
        if cfg.reranker.enable:
            self.reranker = CrossEncoderReranker(
                cfg.reranker.model_name,
                device=cfg.reranker.device,
                max_length=cfg.reranker.max_length,
            )
        else:
            self.reranker = None

    # ---------------------------------------------------------------------
    # Public API
    # ---------------------------------------------------------------------
    def run(self, question: str) -> dict[str, Any]:
        """Answer *question*, returning the answer plus retrieval traces.

        Returns:
            Dict with keys ``question``, ``raw_retrieval`` (first-stage hits
            and their retrieval scores), ``reranked`` (cross-encoder hits;
            empty list when reranking is disabled), ``contexts`` (texts fed
            to the generator) and ``answer``.
        """
        logger.info("Question: %s", question)

        # 1. first-stage retrieval (over-fetches when a reranker will prune).
        initial: list[Context] = self.retriever.retrieve(
            question, top_k=self._first_stage_k()
        )
        raw_hits = [
            {"text": c.text, "id": c.id, "score": getattr(c, "retrieval_score", None)}
            for c in initial
        ]

        # 2. optional cross-encoder reranking.
        if self.reranker:
            final_k = self.cfg.reranker.final_k or self.cfg.retriever.top_k
            reranked: list[Context] = self.reranker.rerank(question, initial, k=final_k)
            reranked_hits = [
                {
                    "text": c.text,
                    "id": c.id,
                    "score": getattr(c, "cross_encoder_score", None),
                }
                for c in reranked
            ]
            contexts_for_gen = reranked
        else:
            reranked_hits = []
            contexts_for_gen = initial

        # 3. generation.
        answer = self._generate(question, contexts_for_gen)
        return {
            "question": question,
            "raw_retrieval": raw_hits,
            "reranked": reranked_hits,
            "contexts": [c.text for c in contexts_for_gen],
            "answer": answer,
        }

    __call__ = run  # alias: pipeline(question) == pipeline.run(question)

    def run_queries(self, queries: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Run the pipeline over a batch of ``{'question': str, 'id': Any}``.

        Each result is ``{"id", "question", "answer"}`` where ``answer`` holds
        the *full* :meth:`run` output dict, not just the answer string.
        """
        results: list[dict[str, Any]] = []
        for entry in queries:
            question = entry.get("question", "")
            results.append(
                {
                    "id": entry.get("id"),
                    "question": question,
                    "answer": self.run(question),
                }
            )
        return results

    # ---------------------------------------------------------------------
    # Private helpers
    # ---------------------------------------------------------------------
    def _first_stage_k(self) -> int:
        """How many docs to fetch in stage 1 (over-fetch when reranking)."""
        if self.reranker:
            return self.cfg.reranker.first_stage_k
        return self.cfg.retriever.top_k

    def _build_retriever(self, cfg: PipelineConfig) -> Retriever:
        """Instantiate the retriever selected by ``cfg.retriever.name``.

        Raises:
            ValueError: for any name other than ``bm25``/``dense``/``hybrid``.
        """
        r = cfg.retriever
        if r.name == "bm25":
            # NOTE(review): this branch reads ``r.bm25_idx`` while the hybrid
            # branch reads ``r.bm25_index`` — confirm the config really
            # exposes both attribute names.
            return bm25.BM25Retriever(
                bm25_idx=str(r.bm25_idx), doc_store=str(r.doc_store)
            )
        if r.name == "dense":
            return dense.DenseRetriever(
                faiss_index=str(r.faiss_index),
                doc_store=str(r.doc_store),
                model_name=r.model_name,
                embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
                device=r.device,
            )
        if r.name == "hybrid":
            return hybrid.HybridRetriever(
                str(r.bm25_index),
                str(r.faiss_index),
                doc_store=str(r.doc_store),
                alpha=r.alpha,
                model_name=r.model_name,
                embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
                device=r.device,
            )
        raise ValueError(f"Unsupported retriever '{r.name}'")

    def _retrieve(self, question: str) -> list[Context]:
        """Retrieve (and, when enabled, rerank) contexts for *question*."""
        logger.info("Retrieving top‑%d passages", self.cfg.retriever.top_k)
        hits = self.retriever.retrieve(question, top_k=self._first_stage_k())
        if self.reranker:
            final_k = self.cfg.reranker.final_k or self.cfg.retriever.top_k
            logger.info("Re-ranking %d docs with cross-encoder ...", len(hits))
            hits = self.reranker.rerank(question, hits, k=final_k)
        return hits

    def _generate(self, question: str, contexts: list[Context]) -> str:
        """Generate an answer from *question* and the retrieved *contexts*."""
        texts = [c.text for c in contexts]
        logger.info("Generating answer with %d context passages", len(texts))
        return self.generator.generate(
            question,
            texts,
            max_new_tokens=self.cfg.generator.max_new_tokens,
            temperature=self.cfg.generator.temperature,
        )