Spaces:
Sleeping
Sleeping
updated for test storage module, plus prelim generalized approach to multi data source
Browse files- app/main.py +11 -2
- app/retriever.py +27 -62
- app/vectorstore_interface.py +89 -0
- params.cfg +9 -5
- requirements.txt +3 -1
app/main.py
CHANGED
|
@@ -1,5 +1,14 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from .retriever import retrieve_context
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
# ---------------------------------------------------------------------
|
| 5 |
# Gradio Interface with MCP support
|
|
@@ -78,7 +87,7 @@ ui = gr.Interface(
|
|
| 78 |
if __name__ == "__main__":
|
| 79 |
ui.launch(
|
| 80 |
server_name="0.0.0.0",
|
| 81 |
-
server_port=7860,
|
| 82 |
mcp_server=True,
|
| 83 |
show_error=True
|
| 84 |
)
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from .retriever import retrieve_context, get_vectorstore
|
| 3 |
+
|
| 4 |
+
# Initialize vector store at startup
|
| 5 |
+
print("Initializing vector store connection...")
|
| 6 |
+
try:
|
| 7 |
+
vectorstore = get_vectorstore()
|
| 8 |
+
print("Vector store connection initialized successfully")
|
| 9 |
+
except Exception as e:
|
| 10 |
+
print(f"Failed to initialize vector store: {e}")
|
| 11 |
+
raise
|
| 12 |
|
| 13 |
# ---------------------------------------------------------------------
|
| 14 |
# Gradio Interface with MCP support
|
|
|
|
| 87 |
if __name__ == "__main__":
|
| 88 |
ui.launch(
|
| 89 |
server_name="0.0.0.0",
|
| 90 |
+
server_port=7860, # Different port from reranker
|
| 91 |
mcp_server=True,
|
| 92 |
show_error=True
|
| 93 |
)
|
app/retriever.py
CHANGED
|
@@ -2,6 +2,7 @@ from typing import List, Dict, Any, Optional
|
|
| 2 |
from qdrant_client.http import models as rest
|
| 3 |
from langchain.schema import Document
|
| 4 |
from .utils import getconfig
|
|
|
|
| 5 |
import logging
|
| 6 |
|
| 7 |
# Load configuration
|
|
@@ -11,6 +12,20 @@ config = getconfig("params.cfg")
|
|
| 11 |
RETRIEVER_TOP_K = int(config.get("retriever", "TOP_K"))
|
| 12 |
SCORE_THRESHOLD = float(config.get("retriever", "SCORE_THRESHOLD"))
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
def create_filter(
|
| 15 |
reports: List[str] = None,
|
| 16 |
sources: str = None,
|
|
@@ -74,37 +89,9 @@ def create_filter(
|
|
| 74 |
return rest.Filter(must=conditions)
|
| 75 |
return None
|
| 76 |
|
| 77 |
-
def get_vectorstore():
|
| 78 |
-
"""
|
| 79 |
-
Initialize and return the vectorstore connection.
|
| 80 |
-
This function should be implemented based on your specific vectorstore setup.
|
| 81 |
-
|
| 82 |
-
Returns:
|
| 83 |
-
Vectorstore instance (e.g., Qdrant, Pinecone, etc.)
|
| 84 |
-
"""
|
| 85 |
-
# TODO: Implement based on your external vector database
|
| 86 |
-
# Example for Qdrant:
|
| 87 |
-
# from langchain_community.vectorstores import Qdrant
|
| 88 |
-
# from qdrant_client import QdrantClient
|
| 89 |
-
#
|
| 90 |
-
# client = QdrantClient(
|
| 91 |
-
# host=config.get("vectorstore", "HOST"),
|
| 92 |
-
# port=config.get("vectorstore", "PORT"),
|
| 93 |
-
# api_key=config.get("vectorstore", "API_KEY", fallback=None)
|
| 94 |
-
# )
|
| 95 |
-
#
|
| 96 |
-
# vectorstore = Qdrant(
|
| 97 |
-
# client=client,
|
| 98 |
-
# collection_name=config.get("vectorstore", "COLLECTION_NAME"),
|
| 99 |
-
# embeddings=your_embedding_model # You'll need to configure this
|
| 100 |
-
# )
|
| 101 |
-
#
|
| 102 |
-
# return vectorstore
|
| 103 |
-
|
| 104 |
-
raise NotImplementedError("Please implement vectorstore connection based on your setup")
|
| 105 |
-
|
| 106 |
def retrieve_context(
|
| 107 |
query: str,
|
|
|
|
| 108 |
reports: List[str] = None,
|
| 109 |
sources: str = None,
|
| 110 |
subtype: str = None,
|
|
@@ -116,6 +103,7 @@ def retrieve_context(
|
|
| 116 |
|
| 117 |
Args:
|
| 118 |
query: The search query
|
|
|
|
| 119 |
reports: List of specific report filenames to search within
|
| 120 |
sources: Source type to filter by
|
| 121 |
subtype: Document subtype to filter by
|
|
@@ -126,48 +114,25 @@ def retrieve_context(
|
|
| 126 |
List of dictionaries with 'page_content' and 'metadata' keys
|
| 127 |
"""
|
| 128 |
try:
|
| 129 |
-
#
|
| 130 |
-
vectorstore = get_vectorstore()
|
| 131 |
-
|
| 132 |
-
# Create metadata filter
|
| 133 |
-
filter_obj = create_filter(
|
| 134 |
-
reports=reports or [],
|
| 135 |
-
sources=sources,
|
| 136 |
-
subtype=subtype,
|
| 137 |
-
year=year or []
|
| 138 |
-
)
|
| 139 |
-
|
| 140 |
-
# Set up search parameters
|
| 141 |
k = top_k or RETRIEVER_TOP_K
|
|
|
|
|
|
|
| 142 |
search_kwargs = {
|
| 143 |
-
"
|
| 144 |
-
"k": k
|
| 145 |
}
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
retriever = vectorstore.as_retriever(
|
| 152 |
-
search_type="similarity_score_threshold",
|
| 153 |
-
search_kwargs=search_kwargs
|
| 154 |
-
)
|
| 155 |
|
| 156 |
# Perform retrieval
|
| 157 |
-
retrieved_docs
|
| 158 |
|
| 159 |
logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")
|
| 160 |
|
| 161 |
-
|
| 162 |
-
results = [
|
| 163 |
-
{
|
| 164 |
-
"page_content": doc.page_content,
|
| 165 |
-
"metadata": doc.metadata
|
| 166 |
-
}
|
| 167 |
-
for doc in retrieved_docs
|
| 168 |
-
]
|
| 169 |
-
|
| 170 |
-
return results
|
| 171 |
|
| 172 |
except Exception as e:
|
| 173 |
logging.error(f"Error during retrieval: {str(e)}")
|
|
|
|
| 2 |
from qdrant_client.http import models as rest
|
| 3 |
from langchain.schema import Document
|
| 4 |
from .utils import getconfig
|
| 5 |
+
from .vectorstore_interface import create_vectorstore, VectorStoreInterface
|
| 6 |
import logging
|
| 7 |
|
| 8 |
# Load configuration
|
|
|
|
| 12 |
RETRIEVER_TOP_K = int(config.get("retriever", "TOP_K"))
|
| 13 |
SCORE_THRESHOLD = float(config.get("retriever", "SCORE_THRESHOLD"))
|
| 14 |
|
| 15 |
+
# Initialize vector store connection at module import time
|
| 16 |
+
logging.info("Initializing vector store connection...")
|
| 17 |
+
vectorstore = create_vectorstore(config)
|
| 18 |
+
logging.info("Vector store connection initialized successfully")
|
| 19 |
+
|
| 20 |
+
def get_vectorstore() -> VectorStoreInterface:
|
| 21 |
+
"""
|
| 22 |
+
Return the pre-initialized vector store connection.
|
| 23 |
+
|
| 24 |
+
Returns:
|
| 25 |
+
VectorStoreInterface instance
|
| 26 |
+
"""
|
| 27 |
+
return vectorstore
|
| 28 |
+
|
| 29 |
def create_filter(
|
| 30 |
reports: List[str] = None,
|
| 31 |
sources: str = None,
|
|
|
|
| 89 |
return rest.Filter(must=conditions)
|
| 90 |
return None
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
def retrieve_context(
|
| 93 |
query: str,
|
| 94 |
+
vectorstore,
|
| 95 |
reports: List[str] = None,
|
| 96 |
sources: str = None,
|
| 97 |
subtype: str = None,
|
|
|
|
| 103 |
|
| 104 |
Args:
|
| 105 |
query: The search query
|
| 106 |
+
vectorstore: Pre-initialized vector store instance
|
| 107 |
reports: List of specific report filenames to search within
|
| 108 |
sources: Source type to filter by
|
| 109 |
subtype: Document subtype to filter by
|
|
|
|
| 114 |
List of dictionaries with 'page_content' and 'metadata' keys
|
| 115 |
"""
|
| 116 |
try:
|
| 117 |
+
# Use the passed vector store instead of calling get_vectorstore()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
k = top_k or RETRIEVER_TOP_K
|
| 119 |
+
|
| 120 |
+
# For Hugging Face Spaces, we pass the model name from config
|
| 121 |
search_kwargs = {
|
| 122 |
+
"model_name": config.get("embeddings", "MODEL_NAME")
|
|
|
|
| 123 |
}
|
| 124 |
|
| 125 |
+
# Note: Filtering is currently limited for Hugging Face Spaces
|
| 126 |
+
# as the API doesn't expose filtering capabilities
|
| 127 |
+
if any([reports, sources, subtype, year]):
|
| 128 |
+
logging.warning("Filtering not supported for Hugging Face Spaces API")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
# Perform retrieval
|
| 131 |
+
retrieved_docs = vectorstore.search(query, k, **search_kwargs)
|
| 132 |
|
| 133 |
logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")
|
| 134 |
|
| 135 |
+
return retrieved_docs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
except Exception as e:
|
| 138 |
logging.error(f"Error during retrieval: {str(e)}")
|
app/vectorstore_interface.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import List, Dict, Any, Optional
|
| 3 |
+
from gradio_client import Client
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
class VectorStoreInterface(ABC):
|
| 9 |
+
"""Abstract interface for different vector store implementations."""
|
| 10 |
+
|
| 11 |
+
@abstractmethod
|
| 12 |
+
def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
|
| 13 |
+
"""Search for similar documents."""
|
| 14 |
+
pass
|
| 15 |
+
|
| 16 |
+
class HuggingFaceSpacesVectorStore(VectorStoreInterface):
|
| 17 |
+
"""Vector store implementation for Hugging Face Spaces with MCP endpoints."""
|
| 18 |
+
|
| 19 |
+
def __init__(self, space_url: str, collection_name: str, hf_token: Optional[str] = None):
|
| 20 |
+
token = os.getenv("HF_TOKEN")
|
| 21 |
+
repo_id = space_url
|
| 22 |
+
|
| 23 |
+
logging.info(f"Connecting to Hugging Face Space: {repo_id}")
|
| 24 |
+
|
| 25 |
+
if token:
|
| 26 |
+
self.client = Client(repo_id, hf_token=token)
|
| 27 |
+
else:
|
| 28 |
+
self.client = Client(repo_id)
|
| 29 |
+
|
| 30 |
+
self.collection_name = collection_name
|
| 31 |
+
|
| 32 |
+
def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
|
| 33 |
+
"""Search using Hugging Face Spaces MCP API."""
|
| 34 |
+
try:
|
| 35 |
+
# Use the /search_text endpoint as documented in the API
|
| 36 |
+
result = self.client.predict(
|
| 37 |
+
query=query,
|
| 38 |
+
collection_name=self.collection_name,
|
| 39 |
+
model_name=kwargs.get('model_name'),
|
| 40 |
+
top_k=top_k,
|
| 41 |
+
api_name="/search_text"
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
logging.info(f"Successfully retrieved {len(result) if result else 0} documents")
|
| 45 |
+
return result
|
| 46 |
+
|
| 47 |
+
except Exception as e:
|
| 48 |
+
logging.error(f"Error searching Hugging Face Spaces: {str(e)}")
|
| 49 |
+
raise e
|
| 50 |
+
|
| 51 |
+
# class QdrantVectorStore(VectorStoreInterface):
|
| 52 |
+
# """Vector store implementation for direct Qdrant connection."""
|
| 53 |
+
# # needs to be generalized for other vector stores (or add a new class for each vector store)
|
| 54 |
+
# def __init__(self, host: str, port: int, collection_name: str, api_key: Optional[str] = None):
|
| 55 |
+
# from qdrant_client import QdrantClient
|
| 56 |
+
# from langchain_community.vectorstores import Qdrant
|
| 57 |
+
|
| 58 |
+
# self.client = QdrantClient(
|
| 59 |
+
# host=host,
|
| 60 |
+
# port=port,
|
| 61 |
+
# api_key=api_key
|
| 62 |
+
# )
|
| 63 |
+
# self.collection_name = collection_name
|
| 64 |
+
# # Embedding model not implemented
|
| 65 |
+
|
| 66 |
+
# def search(self, query: str, top_k: int, **kwargs) -> List[Dict[str, Any]]:
|
| 67 |
+
# """Search using direct Qdrant connection."""
|
| 68 |
+
# # Embedding model not implemented
|
| 69 |
+
# raise NotImplementedError("Direct Qdrant search needs embedding model configuration")
|
| 70 |
+
|
| 71 |
+
def create_vectorstore(config: Any) -> VectorStoreInterface:
|
| 72 |
+
"""Factory function to create appropriate vector store based on configuration."""
|
| 73 |
+
vectorstore_type = config.get("vectorstore", "TYPE")
|
| 74 |
+
|
| 75 |
+
if vectorstore_type.lower() == "huggingface_spaces":
|
| 76 |
+
space_url = config.get("vectorstore", "SPACE_URL")
|
| 77 |
+
collection_name = config.get("vectorstore", "COLLECTION_NAME")
|
| 78 |
+
hf_token = config.get("vectorstore", "HF_TOKEN", fallback=None)
|
| 79 |
+
return HuggingFaceSpacesVectorStore(space_url, collection_name, hf_token)
|
| 80 |
+
|
| 81 |
+
elif vectorstore_type.lower() == "qdrant":
|
| 82 |
+
host = config.get("vectorstore", "HOST")
|
| 83 |
+
port = int(config.get("vectorstore", "PORT"))
|
| 84 |
+
collection_name = config.get("vectorstore", "COLLECTION_NAME")
|
| 85 |
+
api_key = config.get("vectorstore", "API_KEY", fallback=None)
|
| 86 |
+
return QdrantVectorStore(host, port, collection_name, api_key)
|
| 87 |
+
|
| 88 |
+
else:
|
| 89 |
+
raise ValueError(f"Unsupported vector store type: {vectorstore_type}")
|
params.cfg
CHANGED
|
@@ -3,11 +3,15 @@ TOP_K = 10
|
|
| 3 |
SCORE_THRESHOLD = 0.6
|
| 4 |
|
| 5 |
[vectorstore]
|
| 6 |
-
TYPE =
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
[embeddings]
|
| 13 |
MODEL_NAME = BAAI/bge-m3
|
|
|
|
| 3 |
SCORE_THRESHOLD = 0.6
|
| 4 |
|
| 5 |
[vectorstore]
|
| 6 |
+
TYPE = huggingface_spaces
|
| 7 |
+
SPACE_URL = GIZ/audit_data
|
| 8 |
+
COLLECTION_NAME = docling
|
| 9 |
+
# For future direct Qdrant usage:
|
| 10 |
+
# TYPE = qdrant
|
| 11 |
+
# HOST = ip address
|
| 12 |
+
# PORT = 6333
|
| 13 |
+
# COLLECTION_NAME = "collection name"
|
| 14 |
+
# API_KEY = api key for source
|
| 15 |
|
| 16 |
[embeddings]
|
| 17 |
MODEL_NAME = BAAI/bge-m3
|
requirements.txt
CHANGED
|
@@ -2,4 +2,6 @@ gradio[mcp]
|
|
| 2 |
langchain
|
| 3 |
langchain-community
|
| 4 |
qdrant-client
|
| 5 |
-
sentence-transformers
|
|
|
|
|
|
|
|
|
| 2 |
langchain
|
| 3 |
langchain-community
|
| 4 |
qdrant-client
|
| 5 |
+
sentence-transformers
|
| 6 |
+
gradio_client>=0.10.0
|
| 7 |
+
huggingface_hub>=0.20.0
|