23RAG7 / retrieval.py
cb1716pics's picture
Upload retrieval.py
6c9babd verified
raw
history blame
1.2 kB
import json
import numpy as np
from langchain.schema import Document
import faiss
from data_processing import embedding_model #, index, actual_docs
# Module-level placeholder, initialised to None at import time.
# NOTE(review): no function visible in this file declares it `global`, so
# nothing here rebinds it -- confirm whether any other module relies on it.
retrieved_docs = None
# Retrieval Function
def retrieve_documents(query, top_k=5):
    """Return the stored documents nearest to *query* in the FAISS index.

    The index (``data_local/rag7_index.faiss``) and the document list
    (``data_local/rag7_docs.json``) are read from disk on every call. The
    query is embedded with ``embedding_model`` and searched against the index.

    Args:
        query: Text to search for.
        top_k: Maximum number of documents to return.

    Returns:
        A list of at most ``top_k`` ``Document`` objects, in the order the
        index search returned them. It can be shorter than ``top_k`` when
        the index holds fewer vectors than requested.
    """
    faiss_index_path = "data_local/rag7_index.faiss"
    index = faiss.read_index(faiss_index_path)

    # The index expects a 2-D float32 array: one row per query.
    query_embedding = np.array(
        embedding_model.embed_documents([query]), dtype=np.float32
    )
    _, nearest_indices = index.search(query_embedding, top_k)

    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open("data_local/rag7_docs.json", "r", encoding="utf-8") as f:
        documents = json.load(f)  # Contains all documents for this dataset

    # FAISS pads the result with -1 when fewer than top_k neighbours exist.
    # Indexing with -1 would silently return the last document, so skip them.
    return [
        Document(page_content=documents[i])
        for i in nearest_indices[0]
        if i >= 0
    ]
def remove_duplicate_documents(documents):
    """Drop documents whose ``page_content`` has already been seen.

    The first document carrying a given ``page_content`` is kept; later ones
    with the same content are discarded. Relative order is preserved.
    """
    first_by_content = {}
    for candidate in documents:
        # setdefault stores only the first document seen for each content key;
        # dict insertion order gives first-occurrence ordering.
        first_by_content.setdefault(candidate.page_content, candidate)
    return list(first_by_content.values())