MikeMann committed on
Commit
fbf9ef6
·
1 Parent(s): 773aee4

added Hybrid Search with BM25

Browse files
Files changed (1) hide show
  1. app.py +32 -3
app.py CHANGED
@@ -28,6 +28,7 @@ from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
28
  from langchain_core.indexing import index
29
  from langchain_core.vectorstores import VectorStore
30
  from llama_index.core.node_parser import TextSplitter
 
31
  from llama_index.legacy.vector_stores import FaissVectorStore
32
  from pycparser.ply.yacc import token
33
  from ragatouille import RAGPretrainedModel
@@ -90,6 +91,10 @@ class BSIChatbot:
90
  global vectorstore
91
  RAW_KNOWLEDGE_BASE = []
92
 
 
 
 
 
93
  #Embedding, Vector generation and storing:
94
  self.embedding_model = HuggingFaceEmbeddings(
95
  model_name=self.word_and_embed_model_path,
@@ -259,6 +264,19 @@ class BSIChatbot:
259
  print(f"printing first chunk to see whats inside: {retrieved_chunks[0]}")
260
  return retrieved_chunks
261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  def initializeLLM(self):
263
  bnb_config = BitsAndBytesConfig(
264
  load_in_8bit=True,
@@ -292,8 +310,19 @@ class BSIChatbot:
292
  rerankingModel = RAGPretrainedModel.from_pretrained(self.rerankModelPath)
293
 
294
 
295
- def retrieval(self, query, rerankingStep):
296
- retrieved_chunks = self.retrieveSimiliarEmbedding(query)
 
 
 
 
 
 
 
 
 
 
 
297
  retrieved_chunks_text = []
298
  # TODO Irgendwas stimmt hier mit den Listen nicht
299
  for chunk in retrieved_chunks:
@@ -315,7 +344,7 @@ class BSIChatbot:
315
  self.initializeRerankingModel()
316
  print("Starting Reranking Chunks...")
317
  rerankingModel
318
- retrieved_chunks_text = rerankingModel.rerank(query, retrieved_chunks_text, k=5)
319
  retrieved_chunks_text = [chunk["content"] for chunk in retrieved_chunks_text]
320
 
321
  i = 1
 
28
  from langchain_core.indexing import index
29
  from langchain_core.vectorstores import VectorStore
30
  from llama_index.core.node_parser import TextSplitter
31
+ from langchain.retrievers import BM25Retriever, EnsembleRetriever
32
  from llama_index.legacy.vector_stores import FaissVectorStore
33
  from pycparser.ply.yacc import token
34
  from ragatouille import RAGPretrainedModel
 
91
  global vectorstore
92
  RAW_KNOWLEDGE_BASE = []
93
 
94
+ #Qdrant:
95
+ #client = QdrantClient(path=saved_db_path)
96
+ #db = Qdrant(client=client, collection_name=self.collection_name, embeddings=embeddings, )
97
+
98
  #Embedding, Vector generation and storing:
99
  self.embedding_model = HuggingFaceEmbeddings(
100
  model_name=self.word_and_embed_model_path,
 
264
  print(f"printing first chunk to see whats inside: {retrieved_chunks[0]}")
265
  return retrieved_chunks
266
 
267
def retrieveDocFromFaiss(self):
    """Return every document stored in the global FAISS vectorstore.

    Walks the FAISS ``index_to_docstore_id`` mapping and resolves each
    docstore id to its document, so the result covers the whole index.
    Used to feed the BM25 retriever for hybrid search.

    Returns:
        list: all document objects currently held in the vectorstore.
    """
    global vectorstore

    # BUG FIX: the original body referenced an undefined name
    # `vector_store`; the module-level global is `vectorstore`,
    # so every call raised NameError.
    all_documents = []

    # Iterate over all ids in index_to_docstore_id and fetch each
    # document from the docstore.
    # NOTE(review): subscript access `docstore[doc_id]` is kept from the
    # original; confirm the docstore implementation supports __getitem__
    # (InMemoryDocstore typically exposes `.search(doc_id)` instead).
    for doc_id in vectorstore.index_to_docstore_id.values():
        document = vectorstore.docstore[doc_id]
        all_documents.append(document)

    return all_documents
279
+
280
  def initializeLLM(self):
281
  bnb_config = BitsAndBytesConfig(
282
  load_in_8bit=True,
 
310
  rerankingModel = RAGPretrainedModel.from_pretrained(self.rerankModelPath)
311
 
312
 
313
+ def retrieval(self, query, rerankingStep, hybridSearch):
314
+ global vectorstore
315
+ if hybridSearch == True:
316
+ allDocs = self.retrieveDocFromFaiss()
317
+ bm25_retriever = BM25Retriever.from_documents(allDocs.page_content)
318
+ #TODO!
319
+ bm25_retriever.k= 4
320
+ vectordb = vectorstore.as_retriever(search_kwargs={"k":4})
321
+ ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, vectordb], weights=[0.5, 0.5])
322
+ retrieved_chunks = ensemble_retriever.get_relevant_documents(query)
323
+ print("DBG: Number of Chunks retrieved" +len(retrieved_chunks))
324
+ else:
325
+ retrieved_chunks = self.retrieveSimiliarEmbedding(query)
326
  retrieved_chunks_text = []
327
  # TODO Irgendwas stimmt hier mit den Listen nicht
328
  for chunk in retrieved_chunks:
 
344
  self.initializeRerankingModel()
345
  print("Starting Reranking Chunks...")
346
  rerankingModel
347
+ retrieved_chunks_text = rerankingModel.rerank(query, retrieved_chunks_text, k=15)
348
  retrieved_chunks_text = [chunk["content"] for chunk in retrieved_chunks_text]
349
 
350
  i = 1