File size: 2,323 Bytes
80b95e8 593e022 80b95e8 593e022 80b95e8 593e022 80b95e8 593e022 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
"""
Retriever module for semantic document search using FAISS.
Provides functions to perform similarity-based lookups over embedded document vectors.
Integrates with FAISS for efficient vector search and returns relevant document matches.
"""
import functools
import os
import pickle

import faiss
from sentence_transformers import SentenceTransformer
@functools.lru_cache(maxsize=2)
def _load_embedding_model(model_path: str):
    """Return a SentenceTransformer for *model_path*, loading it at most once per path.

    Loading a model is typically far more expensive than encoding one query,
    so repeated searches in the same process reuse the cached instance. The
    cache is bounded so at most two models stay pinned in memory.
    """
    return SentenceTransformer(model_path)


def search_documents(query: str, config: dict, *, cache_dir: str = "vector_cache") -> list:
    """
    Search for semantically similar documents using the FAISS index.

    Args:
        query (str): Search query.
        config (dict): Configuration dictionary. Reads
            ``config["embedding"]["model_path"]`` and, optionally,
            ``config["retrieval"]["top_k"]`` (default 5) and
            ``config["retrieval"]["similarity_threshold"]`` (default 0.75).
        cache_dir (str): Directory containing ``faiss_index.bin`` and
            ``metadata.pkl``. Defaults to ``"vector_cache"``, resolved
            relative to the current working directory.

    Returns:
        list: One ``"[filename] (score: X.XXX): text..."`` string per hit whose
        score is at or above the threshold (text truncated to 200 chars).
        If no hit qualifies, a single "No matches found above threshold ..."
        message. An empty list when the index file does not exist, and a
        single "Search failed: ..." message when loading or searching fails.
    """
    index_path = os.path.join(cache_dir, "faiss_index.bin")
    metadata_path = os.path.join(cache_dir, "metadata.pkl")

    # Check if FAISS index exists
    if not os.path.exists(index_path):
        print("No FAISS index found. Please run 'init' command first.")
        return []

    try:
        # Load FAISS index and the parallel text/filename metadata
        index = faiss.read_index(index_path)
        # NOTE(review): pickle.load can execute arbitrary code. This is only
        # acceptable while metadata.pkl is written locally by this tool and is
        # never taken from an untrusted source.
        with open(metadata_path, "rb") as f:
            metadata = pickle.load(f)
        texts = metadata["texts"]
        filenames = metadata["filenames"]

        # Embed the query (model is cached across calls)
        model = _load_embedding_model(config["embedding"]["model_path"])
        query_embedding = model.encode([query]).astype("float32")
        # L2-normalise the query. Presumably the index is an inner-product index
        # over normalised vectors, so scores are cosine similarities (higher is
        # better). TODO confirm against the code that builds the index.
        faiss.normalize_L2(query_embedding)

        # `or {}` also tolerates an explicit `retrieval: null` in the config
        retrieval_cfg = config.get("retrieval") or {}
        top_k = retrieval_cfg.get("top_k", 5)
        similarity_threshold = retrieval_cfg.get("similarity_threshold", 0.75)

        scores, indices = index.search(query_embedding, top_k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads with -1 when the index holds fewer than top_k vectors;
            # without this guard, filenames[-1] would silently pick the last entry.
            if idx < 0:
                break
            # Hits are returned best-first (assuming a similarity metric), so the
            # first one under the threshold means no later hit can qualify.
            if score < similarity_threshold:
                break
            results.append(
                f"[{filenames[idx]}] (score: {score:.3f}): {texts[idx][:200]}..."
            )

        if not results:
            results.append(f"No matches found above threshold {similarity_threshold}")
        return results
    except (
        OSError,  # includes FileNotFoundError, e.g. metadata.pkl missing
        EOFError,  # pickle.load on an empty or truncated metadata file
        pickle.UnpicklingError,
        KeyError,  # missing metadata or config keys
        IndexError,  # metadata shorter than the index (out-of-sync cache)
        ValueError,
        RuntimeError,  # faiss bindings raise this for C++-side failures (e.g. corrupt index)
    ) as e:
        print(f"Error during search: {e}")
        return [f"Search failed: {e}"]
|