23RAG7 / retrieval.py
cb1716pics's picture
Upload 7 files
1b04b96 verified
raw
history blame
1.24 kB
import json
import numpy as np
from langchain.schema import Document
from langchain.vectorstores import faiss
from data_processing import embedding_model, index
# Retrieval Function
def retrieve_documents(query, top_k=5):
    """Return the documents whose embeddings are nearest to *query*.

    The query is embedded with the module-level ``embedding_model``, the
    module-level FAISS ``index`` is searched for the ``top_k`` nearest
    vectors, and the matching entries of ``data_local/rag7_docs.json`` are
    wrapped in ``Document`` objects.

    Args:
        query: The text to search for.
        top_k: Maximum number of neighbours to request from the index.

    Returns:
        A list of ``Document`` objects, nearest first. It may hold fewer
        than ``top_k`` items when the index has fewer vectors than that.
    """
    # embed_documents takes a batch, so a one-element list yields a
    # (1, dim) float32 matrix -- the layout index.search expects.
    query_embedding = np.array(
        embedding_model.embed_documents([query]), dtype=np.float32
    )

    # Distances are not needed; only the row ids of the neighbours are used.
    _, nearest_indices = index.search(query_embedding, top_k)

    # A forward slash is used on purpose: the previous literal
    # "data_local\rag7_docs.json" contained "\r", which Python reads as a
    # carriage return, so the wrong file name was being opened.
    # NOTE(review): assumes the JSON holds a list whose positions match the
    # ids stored in the FAISS index -- confirm against data_processing.
    with open("data_local/rag7_docs.json", "r", encoding="utf-8") as f:
        documents = json.load(f)

    # FAISS pads the result with -1 when fewer than top_k neighbours exist;
    # indexing with -1 would silently return the *last* document, so skip it.
    return [
        Document(page_content=documents[int(doc_id)])
        for doc_id in nearest_indices[0]
        if doc_id >= 0
    ]
def remove_duplicate_documents(documents):
    """Drop documents whose ``page_content`` has already been seen.

    The first document carrying a given ``page_content`` is kept and any
    later ones with equal content are discarded; the relative order of the
    kept documents is unchanged.

    Args:
        documents: Iterable of objects exposing a hashable ``page_content``.

    Returns:
        A new list containing one document per distinct ``page_content``.
    """
    # dicts preserve insertion order, and setdefault only stores the first
    # document seen for each content value.
    first_by_content = {}
    for candidate in documents:
        first_by_content.setdefault(candidate.page_content, candidate)
    return list(first_by_content.values())