Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import chromadb | |
# Locate the ChromaDB storage directory: <two directory levels above this
# file>/data/chromadb, always expressed as an absolute path.
_module_dir = os.path.dirname(os.path.abspath(__file__))
base_dir = os.path.dirname(_module_dir)
DB_DIR = os.path.join(base_dir, "data", "chromadb")

# Create the storage directory up front; an already-existing one is fine.
os.makedirs(DB_DIR, exist_ok=True)

# Module-level ChromaDB client, created with DB_DIR as its storage path.
client = chromadb.PersistentClient(path=DB_DIR)
def add_embedding(collection_name, doc_id, embedding, metadata=None):
    """
    Add a single embedding to a ChromaDB collection.

    The collection is looked up by name; if the lookup raises, a new
    collection with that name is created instead. The entry is stored under
    the basename of ``doc_id``, not the full value passed in.

    Args:
        collection_name: Name of the collection
        doc_id: Document ID; only ``os.path.basename(doc_id)`` is used
        embedding: Document embedding vector (numpy array, or a value that is
            passed through unchanged)
        metadata: Optional metadata dict. When omitted or empty, it defaults
            to ``{"source": <basename of doc_id>}``.

    Returns:
        True if every step completed without raising; False otherwise (the
        error is printed rather than propagated).
    """
    try:
        # Convert numpy array to list if needed
        if isinstance(embedding, np.ndarray):
            embedding_list = embedding.tolist()
        else:
            embedding_list = embedding

        # Use basename for document ID
        doc_id = os.path.basename(doc_id)

        # Apply the default for an empty dict as well as for None.
        # NOTE(review): ChromaDB's metadata validation has rejected empty
        # dicts, which made an explicit `{}` fail the whole add; confirm
        # against the installed ChromaDB version.
        if not metadata:
            metadata = {"source": doc_id}

        # Get or create collection
        try:
            collection = client.get_collection(name=collection_name)
            print(f"Using existing collection '{collection_name}'")
        except Exception:
            # Any lookup failure is treated as "collection does not exist".
            collection = client.create_collection(name=collection_name)
            print(f"Created new collection '{collection_name}'")

        # Add embedding to collection (single-item batch)
        collection.add(
            ids=[doc_id],
            embeddings=[embedding_list],
            metadatas=[metadata],
        )
        print(f"Successfully added '{doc_id}' to collection '{collection_name}'")
        return True
    except Exception as e:
        # Best-effort API: report the failure and signal it via the return value.
        print(f"Error adding embedding: {e}")
        return False
def search_embedding(embedding, collection_name="pdf_images", top_k=5):
    """
    Search for similar embeddings in ChromaDB.

    Args:
        embedding: Query embedding vector (numpy array, or a value that is
            passed through unchanged)
        collection_name: Name of the collection to search
        top_k: Number of results to request from the collection

    Returns:
        List of matches, each a dict with keys ``"id"`` and ``"metadata"``.
        ``"metadata"`` is always a dict (``{}`` when the result carries no
        metadata for that item). An empty list is returned when the
        collection cannot be fetched or when any step raises.
    """
    try:
        # Convert numpy array to list if needed
        if isinstance(embedding, np.ndarray):
            embedding_list = embedding.tolist()
        else:
            embedding_list = embedding

        # Get collection
        try:
            collection = client.get_collection(name=collection_name)
        except Exception as e:
            print(f"Collection '{collection_name}' not found: {e}")
            return []

        # Query collection with a single query embedding
        results = collection.query(
            query_embeddings=[embedding_list],
            n_results=top_k,
        )

        # Results are indexed per query embedding; exactly one was sent, so
        # only the first row is read.
        ids_per_query = results["ids"]
        if not ids_per_query:
            return []
        hit_ids = ids_per_query[0]

        metadatas_per_query = results["metadatas"]
        if metadatas_per_query:
            hit_metadatas = metadatas_per_query[0]
        else:
            hit_metadatas = [None] * len(hit_ids)

        # Normalise a missing per-item metadata (None) to {} so callers can
        # always treat "metadata" as a dict.
        return [
            {"id": hit_id, "metadata": hit_metadata or {}}
            for hit_id, hit_metadata in zip(hit_ids, hit_metadatas)
        ]
    except Exception as e:
        # Best-effort API: report the failure and return no matches.
        print(f"Error searching embeddings: {e}")
        return []