refactor: Update retrievers.py to improve retriever configuration and naming conventions
Browse files - libs/retrievers.py +5 -7
libs/retrievers.py
CHANGED
|
@@ -12,7 +12,7 @@ from .config import FAISS_DB_INDEX, BM25_INDEX
|
|
| 12 |
def load_bm25_retriever():
    """Load the pickled BM25 retriever stored at ``BM25_INDEX``.

    The file is opened in binary mode and deserialized with ``pickle.load``;
    the resulting object is returned unchanged.

    NOTE: unpickling can execute arbitrary code, so ``BM25_INDEX`` must point
    to a file produced by a trusted source.
    """
    with open(BM25_INDEX, "rb") as index_file:
        return pickle.load(index_file)
|
| 16 |
|
| 17 |
|
| 18 |
def load_faiss_retriever(embeddings):
|
|
@@ -20,21 +20,19 @@ def load_faiss_retriever(embeddings):
|
|
| 20 |
FAISS_DB_INDEX, embeddings, allow_dangerous_deserialization=True
|
| 21 |
)
|
| 22 |
faiss_retriever = faiss_db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
|
| 23 |
-
return faiss_retriever
|
| 24 |
|
| 25 |
|
| 26 |
def load_retrievers(embeddings):
|
| 27 |
-
faiss_retriever = load_faiss_retriever(embeddings)
|
| 28 |
-
run_name="FaissRetriever"
|
| 29 |
-
)
|
| 30 |
|
| 31 |
-
bm25_retriever = load_bm25_retriever()
|
| 32 |
|
| 33 |
ensemble_retriever = EnsembleRetriever(
|
| 34 |
retrievers=[bm25_retriever, faiss_retriever],
|
| 35 |
weights=[0.7, 0.3],
|
| 36 |
search_type="mmr",
|
| 37 |
-
)
|
| 38 |
|
| 39 |
compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
|
| 40 |
compression_retriever = ContextualCompressionRetriever(
|
|
|
|
| 12 |
def load_bm25_retriever():
    """Load the pickled BM25 retriever and tag it with a run name.

    The file at ``BM25_INDEX`` is opened in binary mode and deserialized with
    ``pickle.load``. The function returns the result of calling
    ``with_config(run_name="BM25Retriever")`` on the loaded object
    (presumably so the retriever is labelled in run traces -- confirm against
    the retriever class in use).

    NOTE: unpickling can execute arbitrary code, so ``BM25_INDEX`` must point
    to a file produced by a trusted source.
    """
    with open(BM25_INDEX, "rb") as index_file:
        loaded = pickle.load(index_file)
    named = loaded.with_config(run_name="BM25Retriever")
    return named
|
| 16 |
|
| 17 |
|
| 18 |
def load_faiss_retriever(embeddings):
|
|
|
|
| 20 |
FAISS_DB_INDEX, embeddings, allow_dangerous_deserialization=True
|
| 21 |
)
|
| 22 |
faiss_retriever = faiss_db.as_retriever(search_type="mmr", search_kwargs={"k": 10})
|
| 23 |
+
return faiss_retriever.with_config(run_name="FaissRetriever")
|
| 24 |
|
| 25 |
|
| 26 |
def load_retrievers(embeddings):
|
| 27 |
+
faiss_retriever = load_faiss_retriever(embeddings)
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
bm25_retriever = load_bm25_retriever()
|
| 30 |
|
| 31 |
ensemble_retriever = EnsembleRetriever(
|
| 32 |
retrievers=[bm25_retriever, faiss_retriever],
|
| 33 |
weights=[0.7, 0.3],
|
| 34 |
search_type="mmr",
|
| 35 |
+
)
|
| 36 |
|
| 37 |
compressor = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
|
| 38 |
compression_retriever = ContextualCompressionRetriever(
|