import os
import sys
import logging
from pathlib import Path
import json
import hashlib
from datetime import datetime
import threading
import queue
from typing import List, Dict, Any, Tuple, Optional

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Third-party libraries
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.utils import embedding_functions
import gradio as gr
from openai import OpenAI
import google.generativeai as genai
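# Overview: this app serves a Gradio UI over a locally persisted Chroma collection
# of markdown documents. Queries are embedded with a SentenceTransformer model, the
# most similar documents are retrieved, and they can be passed as context to either
# OpenAI (gpt-4o-mini) or Google Gemini (gemini-1.5-flash) to generate an answer.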
# Configuration class
class Config:
    """Configuration for vector store and RAG"""

    def __init__(self,
                 local_dir: str = "./chroma_data",
                 batch_size: int = 20,
                 max_workers: int = 4,
                 embedding_model: str = "all-MiniLM-L6-v2",
                 collection_name: str = "markdown_docs"):
        self.local_dir = local_dir
        self.batch_size = batch_size
        self.max_workers = max_workers
        self.checkpoint_file = Path(local_dir) / "checkpoint.json"
        self.embedding_model = embedding_model
        self.collection_name = collection_name

        # Create local directory for checkpoints and Chroma
        Path(local_dir).mkdir(parents=True, exist_ok=True)
# Embedding engine
class EmbeddingEngine:
    """Handle embeddings with a lightweight model"""

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        # Use GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {self.device}")

        # Try multiple model options in order of preference
        model_options = [
            model_name,
            "all-MiniLM-L6-v2",
            "paraphrase-MiniLM-L3-v2",
            "all-mpnet-base-v2"  # Higher quality but larger model
        ]
        self.model = None

        # Try each model in order until one works
        for model_option in model_options:
            try:
                logger.info(f"Attempting to load model: {model_option}")
                self.model = SentenceTransformer(model_option)

                # Move model to device
                self.model.to(self.device)

                logger.info(f"Successfully loaded model: {model_option}")
                self.model_name = model_option
                self.vector_size = self.model.get_sentence_embedding_dimension()
                break
            except Exception as e:
                logger.warning(f"Failed to load model {model_option}: {str(e)}")

        if self.model is None:
            logger.error("Failed to load any embedding model. Exiting.")
            sys.exit(1)

    def encode(self, text, batch_size=32):
        """Get embeddings for a text or list of texts"""
        # Handle single text
        if isinstance(text, str):
            texts = [text]
        else:
            texts = text

        # Truncate texts if necessary to avoid tokenization issues
        truncated_texts = [t[:50000] if len(t) > 50000 else t for t in texts]

        # Generate embeddings
        try:
            embeddings = self.model.encode(truncated_texts, batch_size=batch_size,
                                           show_progress_bar=False, convert_to_numpy=True)
            return embeddings
        except Exception as e:
            logger.error(f"Error generating embeddings: {e}")
            # Return zero embeddings as fallback
            return np.zeros((len(truncated_texts), self.vector_size))
class VectorStoreManager:
    """Manage Chroma vector store operations - upload, query, etc."""

    def __init__(self, config: Config):
        self.config = config

        # Initialize Chroma client (local persistence)
        logger.info(f"Initializing Chroma at {config.local_dir}")
        self.client = chromadb.PersistentClient(path=config.local_dir)

        # Get or create collection
        try:
            # Initialize embedding model
            logger.info("Loading embedding model...")
            self.embedding_engine = EmbeddingEngine(config.embedding_model)
            logger.info(f"Using model: {self.embedding_engine.model_name}")

            # Create embedding function
            sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=self.embedding_engine.model_name
            )

            # Try to get existing collection
            try:
                self.collection = self.client.get_collection(
                    name=config.collection_name,
                    embedding_function=sentence_transformer_ef
                )
                logger.info(f"Using existing collection: {config.collection_name}")
            except Exception:
                # Create new collection if it doesn't exist
                self.collection = self.client.create_collection(
                    name=config.collection_name,
                    embedding_function=sentence_transformer_ef,
                    metadata={"hnsw:space": "cosine"}
                )
                logger.info(f"Created new collection: {config.collection_name}")
        except Exception as e:
            logger.error(f"Error initializing Chroma collection: {e}")
            sys.exit(1)

    def query(self, query_text: str, n_results: int = 5) -> List[Dict]:
        """Query the vector store with a text query"""
        try:
            # Query the collection
            search_results = self.collection.query(
                query_texts=[query_text],
                n_results=n_results,
                include=["documents", "metadatas", "distances"]
            )

            # Format results
            results = []
            if search_results["documents"] and len(search_results["documents"][0]) > 0:
                for i in range(len(search_results["documents"][0])):
                    results.append({
                        'document': search_results["documents"][0][i],
                        'metadata': search_results["metadatas"][0][i],
                        'score': 1.0 - search_results["distances"][0][i]  # Convert cosine distance to similarity
                    })
            return results
        except Exception as e:
            logger.error(f"Error querying collection: {e}")
            return []

    def get_statistics(self) -> Dict[str, Any]:
        """Get statistics about the vector store"""
        stats = {}
        try:
            # Get collection count
            collection_info = self.collection.count()
            stats['total_documents'] = collection_info

            # Estimate unique files - with no chunking, each document is a file
            stats['unique_files'] = collection_info
        except Exception as e:
            logger.error(f"Error getting statistics: {e}")
            stats['error'] = str(e)
        return stats
class RAGSystem:
    """Retrieval-Augmented Generation with multiple LLM providers"""

    def __init__(self, vector_store: VectorStoreManager):
        self.vector_store = vector_store
        self.openai_client = None
        self.gemini_configured = False

    def setup_openai(self, api_key: str):
        """Set up OpenAI client with API key"""
        try:
            self.openai_client = OpenAI(api_key=api_key)
            return True
        except Exception as e:
            logger.error(f"Error initializing OpenAI client: {e}")
            return False

    def setup_gemini(self, api_key: str):
        """Set up Gemini with API key"""
        try:
            genai.configure(api_key=api_key)
            self.gemini_configured = True
            return True
        except Exception as e:
            logger.error(f"Error configuring Gemini: {e}")
            return False

    def format_context(self, documents: List[Dict]) -> str:
        """Format retrieved documents into context for the LLM"""
        if not documents:
            return "No relevant documents found."

        context_parts = []
        for i, doc in enumerate(documents):
            metadata = doc['metadata']
            title = metadata.get('title', metadata.get('filename', 'Unknown document'))

            # For readability, limit the length of each context document
            doc_text = doc['document']
            if len(doc_text) > 10000:  # Limit long documents in context
                doc_text = doc_text[:10000] + "... [Document truncated for context]"

            context_parts.append(f"Document {i+1} - {title}:\n{doc_text}\n")
        return "\n".join(context_parts)

    def generate_response_openai(self, query: str, context: str) -> str:
        """Generate a response using an OpenAI model with context"""
        if not self.openai_client:
            return "Error: OpenAI API key not configured. Please enter an API key in the settings tab."

        system_prompt = """
        You are a helpful assistant that answers questions based on the context provided.
        Use the information from the context to answer the user's question.
        If the context doesn't contain the information needed, say so clearly.
        Always cite the specific sections from the context that you used in your answer.
        """

        try:
            response = self.openai_client.chat.completions.create(
                model="gpt-4o-mini",  # Use GPT-4o mini
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
                ],
                temperature=0.3,  # Lower temperature for more factual responses
                max_tokens=1000,
            )
            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"Error generating response with OpenAI: {e}")
            return f"Error generating response with OpenAI: {str(e)}"

    def generate_response_gemini(self, query: str, context: str) -> str:
        """Generate a response using Gemini with context"""
        if not self.gemini_configured:
            return "Error: Google AI API key not configured. Please enter an API key in the settings tab."

        prompt = f"""
        You are a helpful assistant that answers questions based on the context provided.
        Use the information from the context to answer the user's question.
        If the context doesn't contain the information needed, say so clearly.
        Always cite the specific sections from the context that you used in your answer.

        Context:
        {context}

        Question: {query}
        """

        try:
            model = genai.GenerativeModel('gemini-1.5-flash')
            response = model.generate_content(prompt)
            return response.text
        except Exception as e:
            logger.error(f"Error generating response with Gemini: {e}")
            return f"Error generating response with Gemini: {str(e)}"

    def query_and_generate(self, query: str, n_results: int = 5, model: str = "openai") -> str:
        """Retrieve relevant documents and generate a response using the specified model"""
        # Query vector store
        documents = self.vector_store.query(query, n_results=n_results)
        if not documents:
            return "No relevant documents found to answer your question."

        # Format context
        context = self.format_context(documents)

        # Generate response with the appropriate model
        if model == "openai":
            return self.generate_response_openai(query, context)
        elif model == "gemini":
            return self.generate_response_gemini(query, context)
        else:
            return f"Unknown model: {model}"
def rag_chat(query, n_results, model_choice, rag_system):
    """Handle RAG chat queries"""
    return rag_system.query_and_generate(query, n_results=int(n_results), model=model_choice)


def simple_query(query, n_results, vector_store):
    """Handle simple vector store queries"""
    results = vector_store.query(query, n_results=int(n_results))

    # Format results for display
    formatted = []
    for i, res in enumerate(results):
        metadata = res['metadata']
        title = metadata.get('title', metadata.get('filename', 'Unknown'))

        # Limit preview text for display
        preview = res['document'][:800] + '...' if len(res['document']) > 800 else res['document']
        formatted.append(f"**Result {i+1}** (Similarity: {res['score']:.2f})\n\n"
                         f"**Source:** {title}\n\n"
                         f"**Content:**\n{preview}\n\n"
                         f"---\n")
    return "\n".join(formatted) if formatted else "No results found."


def get_db_stats(vector_store):
    """Get vector store statistics"""
    stats = vector_store.get_statistics()
    return (f"Total documents: {stats.get('total_documents', 0)}\n"
            f"Unique files: {stats.get('unique_files', 0)}")


def update_api_keys(openai_key, gemini_key, rag_system):
    """Update API keys for the RAG system"""
    success_msg = []
    if openai_key:
        if rag_system.setup_openai(openai_key):
            success_msg.append("✅ OpenAI API key configured successfully")
        else:
            success_msg.append("❌ Failed to configure OpenAI API key")
    if gemini_key:
        if rag_system.setup_gemini(gemini_key):
            success_msg.append("✅ Google AI API key configured successfully")
        else:
            success_msg.append("❌ Failed to configure Google AI API key")
    if not success_msg:
        return "Please enter at least one API key"
    return "\n".join(success_msg)
# Main function to run the application
def main():
    # Set up path for the existing Chroma database
    chroma_dir = Path("./chroma_data")

    # Initialize the system
    config = Config(
        local_dir=str(chroma_dir),
        collection_name="markdown_docs"
    )

    # Initialize vector store manager with existing collection
    vector_store = VectorStoreManager(config)

    # Initialize RAG system without API keys initially
    rag_system = RAGSystem(vector_store)

    # Wrappers that bind the Gradio callbacks to this app's vector store and RAG system
    def rag_chat_wrapper(query, n_results, model_choice):
        return rag_chat(query, n_results, model_choice, rag_system)

    def simple_query_wrapper(query, n_results):
        return simple_query(query, n_results, vector_store)

    def update_api_keys_wrapper(openai_key, gemini_key):
        return update_api_keys(openai_key, gemini_key, rag_system)

    # Create the Gradio interface
    with gr.Blocks(title="Markdown RAG System") as app:
        gr.Markdown("# RAG System with Multiple LLM Providers")

        with gr.Tab("Chat with Documents"):
            with gr.Row():
                with gr.Column(scale=3):
                    query_input = gr.Textbox(label="Question", placeholder="Ask a question about your documents...")
                    num_results = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of documents to retrieve")
                    model_choice = gr.Radio(
                        choices=["openai", "gemini"],
                        value="openai",
                        label="Choose LLM Provider",
                        info="Select which model to use for generating answers"
                    )
                    query_button = gr.Button("Ask", variant="primary")
                with gr.Column(scale=7):
                    response_output = gr.Markdown(label="Response")

            # Database stats
            stats_display = gr.Textbox(label="Database Statistics", value=get_db_stats(vector_store))
            refresh_button = gr.Button("Refresh Statistics")

        with gr.Tab("Document Search"):
            search_input = gr.Textbox(label="Search Query", placeholder="Search your documents...")
            search_num = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of results")
            search_button = gr.Button("Search", variant="primary")
            search_output = gr.Markdown(label="Search Results")

        with gr.Tab("Settings"):
            gr.Markdown("""
            ## API Keys Configuration
            This application can use either OpenAI's GPT-4o-mini or Google's Gemini 1.5 Flash for generating responses.
            You need to provide at least one API key to use the chat functionality.
            """)
            openai_key_input = gr.Textbox(
                label="OpenAI API Key",
                placeholder="Enter your OpenAI API key here...",
                type="password"
            )
            gemini_key_input = gr.Textbox(
                label="Google AI API Key",
                placeholder="Enter your Google AI API key here...",
                type="password"
            )
            save_keys_button = gr.Button("Save API Keys", variant="primary")
            api_status = gr.Markdown("")

        # Set up events
        query_button.click(
            fn=rag_chat_wrapper,
            inputs=[query_input, num_results, model_choice],
            outputs=response_output
        )
        refresh_button.click(
            fn=lambda: get_db_stats(vector_store),
            inputs=None,
            outputs=stats_display
        )
        search_button.click(
            fn=simple_query_wrapper,
            inputs=[search_input, search_num],
            outputs=search_output
        )
        save_keys_button.click(
            fn=update_api_keys_wrapper,
            inputs=[openai_key_input, gemini_key_input],
            outputs=api_status
        )

    # Launch the interface
    app.launch()
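# Note: app.launch() serves the UI on Gradio's default local address. This script only
# queries the existing "markdown_docs" collection under ./chroma_data; it assumes that
# collection was populated beforehand, presumably by a separate ingestion step.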
if __name__ == "__main__":
    main()
