eudr_retriever / app /vectorstore_interface.py
mtyrrell's picture
updated for test storage module, plus prelim generalized approach to multi data source
08a352f
raw
history blame
3.73 kB
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from gradio_client import Client
import logging
import os
import time
class VectorStoreInterface(ABC):
    """Common contract that every vector store backend must implement."""

    @abstractmethod
    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Return documents similar to ``query``.

        Args:
            query: Text to search for.
            top_k: Number of results requested.
            **kwargs: Backend-specific options; each implementation decides
                which keys it understands.

        Returns:
            A list of result dictionaries. The dictionary schema is defined
            by the concrete backend, not by this interface.
        """
        ...
class HuggingFaceSpacesVectorStore(VectorStoreInterface):
    """Vector store implementation for Hugging Face Spaces with MCP endpoints.

    Searches are delegated to the Space's ``/search_text`` endpoint through a
    ``gradio_client.Client``.
    """

    def __init__(self, space_url: str, collection_name: str, hf_token: Optional[str] = None):
        """Open a client connection to the Space.

        Args:
            space_url: Identifier of the Space, passed unchanged to
                ``gradio_client.Client``.
            collection_name: Collection sent along with every search request.
            hf_token: Access token for the Space. When omitted or empty, the
                ``HF_TOKEN`` environment variable is used instead; when
                neither is set the client connects without a token.
        """
        # An explicitly supplied token takes precedence; the environment
        # variable is only a fallback. (Previously the argument was ignored.)
        token = hf_token or os.getenv("HF_TOKEN")

        logging.info(f"Connecting to Hugging Face Space: {space_url}")
        if token:
            self.client = Client(space_url, hf_token=token)
        else:
            self.client = Client(space_url)
        self.collection_name = collection_name

    def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
        """Search using the Space's ``/search_text`` endpoint.

        Args:
            query: Text to search for.
            top_k: Number of results requested.
            **kwargs: Only ``model_name`` is read; it is forwarded to the
                endpoint and is ``None`` when not supplied.

        Returns:
            Whatever the endpoint returns, passed through unchanged.

        Raises:
            Exception: Any error raised while calling the endpoint is logged
                and then re-raised as-is.
        """
        try:
            # Use the /search_text endpoint as documented in the API
            result = self.client.predict(
                query=query,
                collection_name=self.collection_name,
                model_name=kwargs.get('model_name'),
                top_k=top_k,
                api_name="/search_text"
            )
            logging.info(f"Successfully retrieved {len(result) if result else 0} documents")
            return result
        except Exception as e:
            logging.error(f"Error searching Hugging Face Spaces: {str(e)}")
            # Bare raise re-raises the active exception with its original traceback.
            raise
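# NOTE: the direct Qdrant backend below is disabled (kept for reference only);
# its search cannot work until an embedding model is configured.
# class QdrantVectorStore(VectorStoreInterface):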
# """Vector store implementation for direct Qdrant connection."""
# # needs to be generalized for other vector stores (or add a new class for each vector store)
# def __init__(self, host: str, port: int, collection_name: str, api_key: Optional[str] = None):
# from qdrant_client import QdrantClient
# from langchain_community.vectorstores import Qdrant
# self.client = QdrantClient(
# host=host,
# port=port,
# api_key=api_key
# )
# self.collection_name = collection_name
# # Embedding model not implemented
# def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
# """Search using direct Qdrant connection."""
# # Embedding model not implemented
# raise NotImplementedError("Direct Qdrant search needs embedding model configuration")
def create_vectorstore(config: Any) -> "VectorStoreInterface":
    """Factory function to create the vector store selected by configuration.

    All settings are read from the ``vectorstore`` section of ``config``.

    Args:
        config: Object exposing ``get(section, option, fallback=...)``
            (presumably a ``configparser.ConfigParser`` -- confirm with caller).

    Returns:
        The vector store instance matching the ``TYPE`` option
        (compared case-insensitively).

    Raises:
        NotImplementedError: ``TYPE`` is ``qdrant`` but this module does not
            define a ``QdrantVectorStore`` class.
        ValueError: ``TYPE`` names an unknown vector store.
    """
    vectorstore_type = config.get("vectorstore", "TYPE")
    kind = vectorstore_type.lower()

    if kind == "huggingface_spaces":
        space_url = config.get("vectorstore", "SPACE_URL")
        collection_name = config.get("vectorstore", "COLLECTION_NAME")
        hf_token = config.get("vectorstore", "HF_TOKEN", fallback=None)
        return HuggingFaceSpacesVectorStore(space_url, collection_name, hf_token)

    if kind == "qdrant":
        # Resolve the class at call time: if it is not defined in this module,
        # fail with a clear message instead of an opaque NameError.
        qdrant_cls = globals().get("QdrantVectorStore")
        if qdrant_cls is None:
            raise NotImplementedError(
                "Vector store type 'qdrant' is configured, but QdrantVectorStore "
                "is not available in this module"
            )
        host = config.get("vectorstore", "HOST")
        port = int(config.get("vectorstore", "PORT"))
        collection_name = config.get("vectorstore", "COLLECTION_NAME")
        api_key = config.get("vectorstore", "API_KEY", fallback=None)
        return qdrant_cls(host, port, collection_name, api_key)

    raise ValueError(f"Unsupported vector store type: {vectorstore_type}")