"""High-level RAG pipeline orchestration."""
from __future__ import annotations

import logging
from typing import Dict, Any, List

from .config import PipelineConfig
from .retrievers import bm25, dense, hybrid
from .generators.hf_generator import HFGenerator
from .retrievers.base import Retriever, Context
from .rerankers.cross_encoder import CrossEncoderReranker

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class RAGPipeline:
    """Run retrieval → optional cross-encoder reranking → generation.

    The pipeline is configured once from a :class:`PipelineConfig` and then
    answers questions via :meth:`run` (also exposed as ``__call__``).
    """

    def __init__(self, cfg: PipelineConfig):
        self.cfg = cfg
        self.retriever: Retriever = self._build_retriever(cfg)
        self.generator = HFGenerator(
            model_name=cfg.generator.model_name, device=cfg.generator.device
        )
        # The cross-encoder second stage is optional; None disables it.
        self.reranker: CrossEncoderReranker | None = (
            CrossEncoderReranker(
                cfg.reranker.model_name,
                device=cfg.reranker.device,
                max_length=cfg.reranker.max_length,
            )
            if cfg.reranker.enable
            else None
        )

    # ---------------------------------------------------------------------
    # Public API
    # ---------------------------------------------------------------------
    def run(self, question: str) -> Dict[str, Any]:
        """Answer *question*, returning all intermediate artifacts.

        Returns:
            Dict with keys ``question``, ``raw_retrieval`` (first-stage hits
            with their retrieval scores), ``reranked`` (empty list when
            reranking is disabled), ``contexts`` (texts actually fed to the
            generator) and ``answer``.
        """
        logger.info("Question: %s", question)

        # 1. First-stage retrieval. Pull a wider candidate pool when a
        #    reranker will prune it afterwards.
        initial: List[Context] = self.retriever.retrieve(
            question, top_k=self._first_stage_k()
        )
        raw_hits = [
            {"text": c.text, "id": c.id, "score": getattr(c, "retrieval_score", None)}
            for c in initial
        ]

        # 2. Optional cross-encoder reranking.
        if self.reranker:
            reranked: List[Context] = self.reranker.rerank(
                question, initial, k=self._final_k()
            )
            reranked_hits = [
                {
                    "text": c.text,
                    "id": c.id,
                    "score": getattr(c, "cross_encoder_score", None),
                }
                for c in reranked
            ]
            contexts_for_gen = reranked
        else:
            reranked_hits = []
            contexts_for_gen = initial

        # 3. Generation conditioned on the selected contexts.
        answer = self.generator.generate(
            question,
            [c.text for c in contexts_for_gen],
            max_new_tokens=self.cfg.generator.max_new_tokens,
            temperature=self.cfg.generator.temperature,
        )
        return {
            "question": question,
            "raw_retrieval": raw_hits,
            "reranked": reranked_hits,
            "contexts": [c.text for c in contexts_for_gen],
            "answer": answer,
        }

    __call__ = run  # alias: pipeline(question) == pipeline.run(question)

    def run_queries(self, queries: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Run the pipeline over a batch of ``{'question': str, 'id': Any}`` entries.

        Each result carries the entry's original ``id`` and question plus the
        full :meth:`run` output dict under the ``answer`` key.
        """
        results: list[dict[str, Any]] = []
        for entry in queries:
            question = entry.get("question", "")
            results.append(
                {
                    "id": entry.get("id"),
                    "question": question,
                    "answer": self.run(question),
                }
            )
        return results

    # ---------------------------------------------------------------------
    # Private helpers
    # ---------------------------------------------------------------------
    def _first_stage_k(self) -> int:
        """How many candidates to retrieve before the (optional) rerank stage."""
        if self.reranker:
            return self.cfg.reranker.first_stage_k
        return self.cfg.retriever.top_k

    def _final_k(self) -> int:
        """How many contexts survive reranking (falls back to retriever.top_k)."""
        return self.cfg.reranker.final_k or self.cfg.retriever.top_k

    def _build_retriever(self, cfg: PipelineConfig) -> Retriever:
        """Instantiate the retriever named by ``cfg.retriever.name``.

        Raises:
            ValueError: if the name is not one of 'bm25', 'dense', 'hybrid'.
        """
        r = cfg.retriever
        name = r.name
        if name == "bm25":
            # NOTE(review): this branch reads `r.bm25_idx` while the hybrid
            # branch below reads `r.bm25_index` — confirm both attributes
            # exist on the config, or unify the field name.
            return bm25.BM25Retriever(bm25_idx=str(r.bm25_idx), doc_store=str(r.doc_store))
        if name == "dense":
            return dense.DenseRetriever(
                faiss_index=str(r.faiss_index),
                doc_store=str(r.doc_store),
                model_name=r.model_name,
                embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
                device=r.device,
            )
        if name == "hybrid":
            return hybrid.HybridRetriever(
                str(r.bm25_index),
                str(r.faiss_index),
                doc_store=str(r.doc_store),
                alpha=r.alpha,
                model_name=r.model_name,
                embedder_cache=str(r.embedder_cache) if r.embedder_cache else None,
                device=r.device,
            )
        raise ValueError(f"Unsupported retriever '{name}'")

    def _retrieve(self, question: str) -> List[Context]:
        """Retrieve (and, when enabled, rerank) contexts for *question*."""
        logger.info("Retrieving top‑%d passages", self.cfg.retriever.top_k)
        contexts = self.retriever.retrieve(question, top_k=self._first_stage_k())
        if self.reranker:
            logger.info("Re-ranking %d docs with cross-encoder ...", len(contexts))
            contexts = self.reranker.rerank(question, contexts, k=self._final_k())
        return contexts

    def _generate(self, question: str, contexts: List[Context]) -> str:
        """Generate an answer from *question* and the given context passages."""
        texts = [c.text for c in contexts]
        logger.info("Generating answer with %d context passages", len(texts))
        return self.generator.generate(
            question,
            texts,
            max_new_tokens=self.cfg.generator.max_new_tokens,
            temperature=self.cfg.generator.temperature,
        )