talk_to_pdf / app /chroma_utils.py
sapatevaibhav
- Implement image embedding using CLIP and a fallback method
9089af6
import os
import numpy as np
import chromadb
# Resolve storage locations relative to the project root (the parent of the
# directory holding this module), so they do not depend on the working dir.
_module_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = os.path.dirname(_module_dir)
DB_DIR = os.path.join(base_dir, "data", "chromadb")
os.makedirs(DB_DIR, exist_ok=True)

# One shared, disk-backed ChromaDB client rooted at DB_DIR; every helper
# below goes through this object.
client = chromadb.PersistentClient(path=DB_DIR)
def add_embedding(collection_name, doc_id, embedding, metadata=None):
    """
    Add (or replace) an embedding in ChromaDB.

    The record is written with ``upsert``, so re-adding a document whose ID
    already exists overwrites the stored vector and metadata. Chroma's
    ``add`` silently ignores existing IDs, which previously made this
    function report success while keeping the stale embedding.

    Args:
        collection_name: Name of the collection (created if it does not exist)
        doc_id: Document ID or file path; only its basename is used as the ID
        embedding: Document embedding vector (numpy array or list of floats)
        metadata: Optional metadata dict; when None or empty it defaults to
            {"source": <basename of doc_id>}

    Returns:
        True if the record was written, False if any error occurred (the
        error is printed, not raised).
    """
    try:
        # Chroma expects plain Python lists, not numpy arrays.
        # NOTE(review): a 2-D array such as shape (1, d) becomes a nested
        # list here -- confirm callers always pass 1-D vectors.
        if isinstance(embedding, np.ndarray):
            embedding_list = embedding.tolist()
        else:
            embedding_list = embedding

        # Use the basename so a file path maps to a path-independent ID.
        # Files with the same name in different folders share one ID.
        doc_id = os.path.basename(doc_id)

        # Some Chroma versions reject an empty metadata dict, so fall back
        # to the source ID for {} as well as for None.
        if not metadata:
            metadata = {"source": doc_id}

        # get_or_create_collection avoids treating *any* get_collection
        # failure as "missing" and masking the real error with a create.
        collection = client.get_or_create_collection(name=collection_name)
        print(f"Using collection '{collection_name}'")

        # upsert (not add): add() ignores IDs that already exist.
        collection.upsert(
            ids=[doc_id],
            embeddings=[embedding_list],
            metadatas=[metadata],
        )
        print(f"Successfully added '{doc_id}' to collection '{collection_name}'")
        return True
    except Exception as e:
        # Best-effort API: report and signal failure via the return value.
        print(f"Error adding embedding: {e}")
        return False
def search_embedding(embedding, collection_name="pdf_images", top_k=5):
    """
    Search for similar embeddings in ChromaDB.

    Args:
        embedding: Query embedding vector (numpy array or list of floats)
        collection_name: Name of the collection to search
        top_k: Maximum number of results to return

    Returns:
        List of matches in the order Chroma returns them. Each match is a
        dict with keys "id", "metadata" (always a dict, never None) and
        "distance" (None if the query result carried no distances).
        Returns an empty list when the collection is missing or empty,
        when top_k < 1, or on any error (the error is printed).
    """
    try:
        # Chroma expects plain Python lists, not numpy arrays.
        if isinstance(embedding, np.ndarray):
            embedding_list = embedding.tolist()
        else:
            embedding_list = embedding

        # Get collection
        try:
            collection = client.get_collection(name=collection_name)
        except Exception as e:
            print(f"Collection '{collection_name}' not found: {e}")
            return []

        # Never request more neighbours than the collection holds: some
        # Chroma versions raise instead of truncating, and n_results must
        # be a positive integer.
        n_results = min(top_k, collection.count())
        if n_results < 1:
            return []

        results = collection.query(
            query_embeddings=[embedding_list],
            n_results=n_results,
        )

        # query() returns one inner list per query embedding; we sent one.
        ids = results["ids"][0] if results.get("ids") else []
        metadatas = results["metadatas"][0] if results.get("metadatas") else None
        distances = results["distances"][0] if results.get("distances") else None

        matches = []
        for i, match_id in enumerate(ids):
            metadata = metadatas[i] if metadatas else None
            matches.append({
                "id": match_id,
                # Records stored without metadata come back as None.
                "metadata": metadata or {},
                "distance": distances[i] if distances else None,
            })
        return matches
    except Exception as e:
        print(f"Error searching embeddings: {e}")
        return []