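"""Gradio RAG app over a persistent Chroma collection of markdown documents.

Uses SentenceTransformer embeddings for retrieval and either OpenAI's
GPT-4o-mini or Google's Gemini 1.5 Flash to generate answers.
"""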
import os
import sys
import logging
from pathlib import Path
import json
import hashlib
from datetime import datetime
import threading
import queue
from typing import List, Dict, Any, Tuple, Optional

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Importing necessary libraries
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.utils import embedding_functions
import gradio as gr
from openai import OpenAI
import google.generativeai as genai
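# Third-party packages imported above (pip package names assumed to be):
# torch, numpy, sentence-transformers, chromadb, gradio, openai, google-generativeai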

# Configuration class
class Config:
    """Configuration for vector store and RAG"""

    def __init__(self,
                 local_dir: str = "./chroma_data",
                 batch_size: int = 20,
                 max_workers: int = 4,
                 embedding_model: str = "all-MiniLM-L6-v2",
                 collection_name: str = "markdown_docs"):
        self.local_dir = local_dir
        self.batch_size = batch_size
        self.max_workers = max_workers
        self.checkpoint_file = Path(local_dir) / "checkpoint.json"
        self.embedding_model = embedding_model
        self.collection_name = collection_name

        # Create local directory for checkpoints and Chroma
        Path(local_dir).mkdir(parents=True, exist_ok=True)

# Embedding engine
class EmbeddingEngine:
    """Handle embeddings with a lightweight model"""

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        # Use GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Try multiple model options in order of preference
        model_options = [
            model_name,
            "all-MiniLM-L6-v2",
            "paraphrase-MiniLM-L3-v2",
            "all-mpnet-base-v2"  # Higher quality but larger model
        ]
        self.model = None

        # Try each model in order until one works
        for model_option in model_options:
            try:
                logger.info(f"Attempting to load model: {model_option}")
                self.model = SentenceTransformer(model_option)
                # Move model to device
                self.model.to(self.device)
                logger.info(f"Successfully loaded model: {model_option}")
                self.model_name = model_option
                self.vector_size = self.model.get_sentence_embedding_dimension()
                break
            except Exception as e:
                logger.warning(f"Failed to load model {model_option}: {str(e)}")

        if self.model is None:
            logger.error("Failed to load any embedding model. Exiting.")
            sys.exit(1)

    def encode(self, text, batch_size=32):
        """Get embedding for a text or list of texts"""
        # Handle a single text or a list of texts uniformly
        texts = [text] if isinstance(text, str) else text

        # Truncate texts if necessary to avoid tokenization issues
        truncated_texts = [t[:50000] if len(t) > 50000 else t for t in texts]

        # Generate embeddings
        try:
            embeddings = self.model.encode(truncated_texts, batch_size=batch_size,
                                           show_progress_bar=False, convert_to_numpy=True)
            return embeddings
        except Exception as e:
            logger.error(f"Error generating embeddings: {e}")
            # Return zero embeddings as fallback
            return np.zeros((len(truncated_texts), self.vector_size))

class VectorStoreManager:
    """Manage Chroma vector store operations - upload, query, etc."""

    def __init__(self, config: Config):
        self.config = config

        # Initialize Chroma client (local persistence)
        logger.info(f"Initializing Chroma at {config.local_dir}")
        self.client = chromadb.PersistentClient(path=config.local_dir)

        # Get or create collection
        try:
            # Initialize the embedding model (used here to resolve a working model name)
            logger.info("Loading embedding model...")
            self.embedding_engine = EmbeddingEngine(config.embedding_model)
            logger.info(f"Using model: {self.embedding_engine.model_name}")

            # Create embedding function
            sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=self.embedding_engine.model_name
            )

            # Try to get the existing collection
            try:
                self.collection = self.client.get_collection(
                    name=config.collection_name,
                    embedding_function=sentence_transformer_ef
                )
                logger.info(f"Using existing collection: {config.collection_name}")
            except Exception:
                # Create a new collection if it doesn't exist
                self.collection = self.client.create_collection(
                    name=config.collection_name,
                    embedding_function=sentence_transformer_ef,
                    metadata={"hnsw:space": "cosine"}
                )
                logger.info(f"Created new collection: {config.collection_name}")
        except Exception as e:
            logger.error(f"Error initializing Chroma collection: {e}")
            sys.exit(1)

    def query(self, query_text: str, n_results: int = 5) -> List[Dict]:
        """Query the vector store with a text query"""
        try:
            # Query the collection
            search_results = self.collection.query(
                query_texts=[query_text],
                n_results=n_results,
                include=["documents", "metadatas", "distances"]
            )

            # Format results
            results = []
            if search_results["documents"] and len(search_results["documents"][0]) > 0:
                for i in range(len(search_results["documents"][0])):
                    results.append({
                        'document': search_results["documents"][0][i],
                        'metadata': search_results["metadatas"][0][i],
                        'score': 1.0 - search_results["distances"][0][i]  # Convert cosine distance to similarity
                    })
            return results
        except Exception as e:
            logger.error(f"Error querying collection: {e}")
            return []

    def get_statistics(self) -> Dict[str, Any]:
        """Get statistics about the vector store"""
        stats = {}
        try:
            # Get the collection count
            doc_count = self.collection.count()
            stats['total_documents'] = doc_count
            # With no chunking, each stored document corresponds to one file
            stats['unique_files'] = doc_count
        except Exception as e:
            logger.error(f"Error getting statistics: {e}")
            stats['error'] = str(e)
        return stats
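
# Rough standalone sketch of the retrieval path (assumes the collection under
# ./chroma_data is already populated; the names mirror the classes above):
#
#   store = VectorStoreManager(Config())
#   for hit in store.query("vector databases", n_results=3):
#       print(f"{hit['score']:.2f}", hit['metadata'].get('filename'), hit['document'][:80])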

class RAGSystem:
    """Retrieval-Augmented Generation with multiple LLM providers"""

    def __init__(self, vector_store: VectorStoreManager):
        self.vector_store = vector_store
        self.openai_client = None
        self.gemini_configured = False

    def setup_openai(self, api_key: str):
        """Set up the OpenAI client with an API key"""
        try:
            self.openai_client = OpenAI(api_key=api_key)
            return True
        except Exception as e:
            logger.error(f"Error initializing OpenAI client: {e}")
            return False

    def setup_gemini(self, api_key: str):
        """Set up Gemini with an API key"""
        try:
            genai.configure(api_key=api_key)
            self.gemini_configured = True
            return True
        except Exception as e:
            logger.error(f"Error configuring Gemini: {e}")
            return False

    def format_context(self, documents: List[Dict]) -> str:
        """Format retrieved documents into context for the LLM"""
        if not documents:
            return "No relevant documents found."

        context_parts = []
        for i, doc in enumerate(documents):
            metadata = doc['metadata']
            title = metadata.get('title', metadata.get('filename', 'Unknown document'))

            # For readability, limit the length of each context document
            doc_text = doc['document']
            if len(doc_text) > 10000:  # Limit long documents in context
                doc_text = doc_text[:10000] + "... [Document truncated for context]"
            context_parts.append(f"Document {i+1} - {title}:\n{doc_text}\n")
        return "\n".join(context_parts)

    def generate_response_openai(self, query: str, context: str) -> str:
        """Generate a response using an OpenAI model with the retrieved context"""
        if not self.openai_client:
            return "Error: OpenAI API key not configured. Please enter an API key in the settings tab."

        system_prompt = """
        You are a helpful assistant that answers questions based on the context provided.
        Use the information from the context to answer the user's question.
        If the context doesn't contain the information needed, say so clearly.
        Always cite the specific sections from the context that you used in your answer.
        """

        try:
            response = self.openai_client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
                ],
                temperature=0.3,  # Lower temperature for more factual responses
                max_tokens=1000,
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"Error generating response with OpenAI: {e}")
            return f"Error generating response with OpenAI: {str(e)}"

    def generate_response_gemini(self, query: str, context: str) -> str:
        """Generate a response using Gemini with the retrieved context"""
        if not self.gemini_configured:
            return "Error: Google AI API key not configured. Please enter an API key in the settings tab."

        prompt = f"""
        You are a helpful assistant that answers questions based on the context provided.
        Use the information from the context to answer the user's question.
        If the context doesn't contain the information needed, say so clearly.
        Always cite the specific sections from the context that you used in your answer.

        Context:
        {context}

        Question: {query}
        """

        try:
            model = genai.GenerativeModel('gemini-1.5-flash')
            response = model.generate_content(prompt)
            return response.text
        except Exception as e:
            logger.error(f"Error generating response with Gemini: {e}")
            return f"Error generating response with Gemini: {str(e)}"

    def query_and_generate(self, query: str, n_results: int = 5, model: str = "openai") -> str:
        """Retrieve relevant documents and generate a response using the specified model"""
        # Query the vector store
        documents = self.vector_store.query(query, n_results=n_results)
        if not documents:
            return "No relevant documents found to answer your question."

        # Format context
        context = self.format_context(documents)

        # Generate a response with the selected provider
        if model == "openai":
            return self.generate_response_openai(query, context)
        elif model == "gemini":
            return self.generate_response_gemini(query, context)
        else:
            return f"Unknown model: {model}"

def rag_chat(query, n_results, model_choice, rag_system):
    """Handle RAG chat queries"""
    return rag_system.query_and_generate(query, n_results=int(n_results), model=model_choice)


def simple_query(query, n_results, vector_store):
    """Handle simple vector store queries"""
    results = vector_store.query(query, n_results=int(n_results))

    # Format results for display
    formatted = []
    for i, res in enumerate(results):
        metadata = res['metadata']
        title = metadata.get('title', metadata.get('filename', 'Unknown'))
        # Limit preview text for display
        preview = res['document'][:800] + '...' if len(res['document']) > 800 else res['document']
        formatted.append(f"**Result {i+1}** (Similarity: {res['score']:.2f})\n\n"
                         f"**Source:** {title}\n\n"
                         f"**Content:**\n{preview}\n\n"
                         f"---\n")
    return "\n".join(formatted) if formatted else "No results found."


def get_db_stats(vector_store):
    """Get vector store statistics as a display string"""
    stats = vector_store.get_statistics()
    return (f"Total documents: {stats.get('total_documents', 0)}\n"
            f"Unique files: {stats.get('unique_files', 0)}")


def update_api_keys(openai_key, gemini_key, rag_system):
    """Update API keys for the RAG system"""
    success_msg = []
    if openai_key:
        if rag_system.setup_openai(openai_key):
            success_msg.append("✓ OpenAI API key configured successfully")
        else:
            success_msg.append("✗ Failed to configure OpenAI API key")
    if gemini_key:
        if rag_system.setup_gemini(gemini_key):
            success_msg.append("✓ Google AI API key configured successfully")
        else:
            success_msg.append("✗ Failed to configure Google AI API key")
    if not success_msg:
        return "Please enter at least one API key"
    return "\n".join(success_msg)

# Main function to run the application
def main():
    # Set up the path for the existing Chroma database
    chroma_dir = Path("./chroma_data")

    # Initialize the configuration
    config = Config(
        local_dir=str(chroma_dir),
        collection_name="markdown_docs"
    )

    # Initialize the vector store manager with the existing collection
    vector_store = VectorStoreManager(config)

    # Initialize the RAG system without API keys initially
    rag_system = RAGSystem(vector_store)

    # Wrapper callbacks that bind the shared objects for the Gradio app
    def rag_chat_wrapper(query, n_results, model_choice):
        return rag_chat(query, n_results, model_choice, rag_system)

    def simple_query_wrapper(query, n_results):
        return simple_query(query, n_results, vector_store)

    def update_api_keys_wrapper(openai_key, gemini_key):
        return update_api_keys(openai_key, gemini_key, rag_system)

    # Create the Gradio interface
    with gr.Blocks(title="Markdown RAG System") as app:
        gr.Markdown("# RAG System with Multiple LLM Providers")

        with gr.Tab("Chat with Documents"):
            with gr.Row():
                with gr.Column(scale=3):
                    query_input = gr.Textbox(label="Question", placeholder="Ask a question about your documents...")
                    num_results = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of documents to retrieve")
                    model_choice = gr.Radio(
                        choices=["openai", "gemini"],
                        value="openai",
                        label="Choose LLM Provider",
                        info="Select which model to use for generating answers"
                    )
                    query_button = gr.Button("Ask", variant="primary")
                with gr.Column(scale=7):
                    response_output = gr.Markdown(label="Response")

            # Database stats
            stats_display = gr.Textbox(label="Database Statistics", value=get_db_stats(vector_store))
            refresh_button = gr.Button("Refresh Statistics")

        with gr.Tab("Document Search"):
            search_input = gr.Textbox(label="Search Query", placeholder="Search your documents...")
            search_num = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of results")
            search_button = gr.Button("Search", variant="primary")
            search_output = gr.Markdown(label="Search Results")

        with gr.Tab("Settings"):
            gr.Markdown("""
            ## API Keys Configuration

            This application can use either OpenAI's GPT-4o-mini or Google's Gemini 1.5 Flash for generating responses.
            You need to provide at least one API key to use the chat functionality.
            """)
            openai_key_input = gr.Textbox(
                label="OpenAI API Key",
                placeholder="Enter your OpenAI API key here...",
                type="password"
            )
            gemini_key_input = gr.Textbox(
                label="Google AI API Key",
                placeholder="Enter your Google AI API key here...",
                type="password"
            )
            save_keys_button = gr.Button("Save API Keys", variant="primary")
            api_status = gr.Markdown("")

        # Set up events
        query_button.click(
            fn=rag_chat_wrapper,
            inputs=[query_input, num_results, model_choice],
            outputs=response_output
        )
        refresh_button.click(
            fn=lambda: get_db_stats(vector_store),
            inputs=None,
            outputs=stats_display
        )
        search_button.click(
            fn=simple_query_wrapper,
            inputs=[search_input, search_num],
            outputs=search_output
        )
        save_keys_button.click(
            fn=update_api_keys_wrapper,
            inputs=[openai_key_input, gemini_key_input],
            outputs=api_status
        )

    # Launch the interface
    app.launch()


if __name__ == "__main__":
    main()
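
# To run locally (the filename below is assumed), install the dependencies listed
# near the imports and start the app with:
#   python app.py
# Gradio prints a local URL; open it and enter an OpenAI or Google AI key in the
# Settings tab before using the chat tab.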