Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,64 +1,467 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
            -
            from  | 
|  | |
| 3 |  | 
| 4 | 
            -
             | 
| 5 | 
            -
             | 
| 6 | 
            -
            """
         | 
| 7 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 8 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 9 |  | 
| 10 | 
            -
            def  | 
| 11 | 
            -
             | 
| 12 | 
            -
             | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
             | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 19 |  | 
| 20 | 
            -
             | 
| 21 | 
            -
             | 
| 22 | 
            -
             | 
| 23 | 
            -
             | 
| 24 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 25 |  | 
| 26 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 27 |  | 
| 28 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 29 |  | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
             | 
| 33 | 
            -
                    stream=True,
         | 
| 34 | 
            -
                    temperature=temperature,
         | 
| 35 | 
            -
                    top_p=top_p,
         | 
| 36 | 
            -
                ):
         | 
| 37 | 
            -
                    token = message.choices[0].delta.content
         | 
| 38 |  | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 41 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 42 |  | 
| 43 | 
            -
             | 
| 44 | 
            -
             | 
| 45 | 
            -
             | 
| 46 | 
            -
             | 
| 47 | 
            -
                 | 
| 48 | 
            -
             | 
| 49 | 
            -
             | 
| 50 | 
            -
                     | 
| 51 | 
            -
             | 
| 52 | 
            -
             | 
| 53 | 
            -
             | 
| 54 | 
            -
             | 
| 55 | 
            -
                         | 
| 56 | 
            -
             | 
| 57 | 
            -
                         | 
| 58 | 
            -
             | 
| 59 | 
            -
                 | 
| 60 | 
            -
             | 
|  | |
|  | |
| 61 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 62 |  | 
| 63 | 
             
            if __name__ == "__main__":
         | 
| 64 | 
            -
                 | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import sys
         | 
| 3 | 
            +
            import logging
         | 
| 4 | 
            +
            from pathlib import Path
         | 
| 5 | 
            +
            import json
         | 
| 6 | 
            +
            import hashlib
         | 
| 7 | 
            +
            from datetime import datetime
         | 
| 8 | 
            +
            import threading
         | 
| 9 | 
            +
            import queue
         | 
| 10 | 
            +
            from typing import List, Dict, Any, Tuple, Optional
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            # Configure logging
         | 
| 13 | 
            +
            logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
         | 
| 14 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            # Importing necessary libraries
         | 
| 17 | 
            +
            import torch
         | 
| 18 | 
            +
            import numpy as np
         | 
| 19 | 
            +
            from sentence_transformers import SentenceTransformer
         | 
| 20 | 
            +
            import chromadb
         | 
| 21 | 
            +
            from chromadb.utils import embedding_functions
         | 
| 22 | 
             
            import gradio as gr
         | 
| 23 | 
            +
            from openai import OpenAI
         | 
| 24 | 
            +
            import google.generativeai as genai
         | 
| 25 |  | 
| 26 | 
            +
            # Configuration class
         | 
| 27 | 
            +
            class Config:
         | 
| 28 | 
            +
                """Configuration for vector store and RAG"""
         | 
| 29 | 
            +
                def __init__(self, 
         | 
| 30 | 
            +
                             local_dir: str = "./chroma_data",
         | 
| 31 | 
            +
                             batch_size: int = 20,
         | 
| 32 | 
            +
                             max_workers: int = 4,
         | 
| 33 | 
            +
                             embedding_model: str = "all-MiniLM-L6-v2",
         | 
| 34 | 
            +
                             collection_name: str = "markdown_docs"):
         | 
| 35 | 
            +
                    self.local_dir = local_dir
         | 
| 36 | 
            +
                    self.batch_size = batch_size
         | 
| 37 | 
            +
                    self.max_workers = max_workers
         | 
| 38 | 
            +
                    self.checkpoint_file = Path(local_dir) / "checkpoint.json"
         | 
| 39 | 
            +
                    self.embedding_model = embedding_model
         | 
| 40 | 
            +
                    self.collection_name = collection_name
         | 
| 41 | 
            +
                    
         | 
| 42 | 
            +
                    # Create local directory for checkpoints and Chroma
         | 
| 43 | 
            +
                    Path(local_dir).mkdir(parents=True, exist_ok=True)
         | 
| 44 |  | 
| 45 | 
            +
            # Embedding engine
         | 
| 46 | 
            +
            class EmbeddingEngine:
         | 
| 47 | 
            +
                """Handle embeddings with a lightweight model"""
         | 
| 48 | 
            +
                
         | 
| 49 | 
            +
                def __init__(self, model_name="all-MiniLM-L6-v2"):
         | 
| 50 | 
            +
                    # Use GPU if available
         | 
| 51 | 
            +
                    self.device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 52 | 
            +
                    logger.info(f"Using device: {self.device}")
         | 
| 53 | 
            +
                    
         | 
| 54 | 
            +
                    # Try multiple model options in order of preference
         | 
| 55 | 
            +
                    model_options = [
         | 
| 56 | 
            +
                        model_name,
         | 
| 57 | 
            +
                        "all-MiniLM-L6-v2",
         | 
| 58 | 
            +
                        "paraphrase-MiniLM-L3-v2",
         | 
| 59 | 
            +
                        "all-mpnet-base-v2"  # Higher quality but larger model
         | 
| 60 | 
            +
                    ]
         | 
| 61 | 
            +
                    
         | 
| 62 | 
            +
                    self.model = None
         | 
| 63 | 
            +
                    
         | 
| 64 | 
            +
                    # Try each model in order until one works
         | 
| 65 | 
            +
                    for model_option in model_options:
         | 
| 66 | 
            +
                        try:
         | 
| 67 | 
            +
                            logger.info(f"Attempting to load model: {model_option}")
         | 
| 68 | 
            +
                            self.model = SentenceTransformer(model_option)
         | 
| 69 | 
            +
                            
         | 
| 70 | 
            +
                            # Move model to device
         | 
| 71 | 
            +
                            self.model.to(self.device)
         | 
| 72 | 
            +
                            
         | 
| 73 | 
            +
                            logger.info(f"Successfully loaded model: {model_option}")
         | 
| 74 | 
            +
                            self.model_name = model_option
         | 
| 75 | 
            +
                            self.vector_size = self.model.get_sentence_embedding_dimension()
         | 
| 76 | 
            +
                            break
         | 
| 77 | 
            +
                            
         | 
| 78 | 
            +
                        except Exception as e:
         | 
| 79 | 
            +
                            logger.warning(f"Failed to load model {model_option}: {str(e)}")
         | 
| 80 | 
            +
                    
         | 
| 81 | 
            +
                    if self.model is None:
         | 
| 82 | 
            +
                        logger.error("Failed to load any embedding model. Exiting.")
         | 
| 83 | 
            +
                        sys.exit(1)
         | 
| 84 |  | 
| 85 | 
            +
                def encode(self, text, batch_size=32):
         | 
| 86 | 
            +
                    """Get embedding for a text or list of texts"""
         | 
| 87 | 
            +
                    # Handle single text
         | 
| 88 | 
            +
                    if isinstance(text, str):
         | 
| 89 | 
            +
                        texts = [text]
         | 
| 90 | 
            +
                    else:
         | 
| 91 | 
            +
                        texts = text
         | 
| 92 | 
            +
                        
         | 
| 93 | 
            +
                    # Truncate texts if necessary to avoid tokenization issues
         | 
| 94 | 
            +
                    truncated_texts = [t[:50000] if len(t) > 50000 else t for t in texts]
         | 
| 95 | 
            +
                    
         | 
| 96 | 
            +
                    # Generate embeddings
         | 
| 97 | 
            +
                    try:
         | 
| 98 | 
            +
                        embeddings = self.model.encode(truncated_texts, batch_size=batch_size, 
         | 
| 99 | 
            +
                                                     show_progress_bar=False, convert_to_numpy=True)
         | 
| 100 | 
            +
                        return embeddings
         | 
| 101 | 
            +
                    except Exception as e:
         | 
| 102 | 
            +
                        logger.error(f"Error generating embeddings: {e}")
         | 
| 103 | 
            +
                        # Return zero embeddings as fallback
         | 
| 104 | 
            +
                        return np.zeros((len(truncated_texts), self.vector_size))
         | 
| 105 |  | 
| 106 | 
            +
            class VectorStoreManager:
         | 
| 107 | 
            +
                """Manage Chroma vector store operations - upload, query, etc."""
         | 
| 108 | 
            +
                
         | 
| 109 | 
            +
                def __init__(self, config: Config):
         | 
| 110 | 
            +
                    self.config = config
         | 
| 111 | 
            +
                        
         | 
| 112 | 
            +
                    # Initialize Chroma client (local persistence)
         | 
| 113 | 
            +
                    logger.info(f"Initializing Chroma at {config.local_dir}")
         | 
| 114 | 
            +
                    self.client = chromadb.PersistentClient(path=config.local_dir)
         | 
| 115 | 
            +
                    
         | 
| 116 | 
            +
                    # Get or create collection
         | 
| 117 | 
            +
                    try:
         | 
| 118 | 
            +
                        # Initialize embedding model
         | 
| 119 | 
            +
                        logger.info("Loading embedding model...")
         | 
| 120 | 
            +
                        self.embedding_engine = EmbeddingEngine(config.embedding_model)
         | 
| 121 | 
            +
                        logger.info(f"Using model: {self.embedding_engine.model_name}")
         | 
| 122 | 
            +
                        
         | 
| 123 | 
            +
                        # Create embedding function
         | 
| 124 | 
            +
                        sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
         | 
| 125 | 
            +
                            model_name=self.embedding_engine.model_name
         | 
| 126 | 
            +
                        )
         | 
| 127 | 
            +
                        
         | 
| 128 | 
            +
                        # Try to get existing collection
         | 
| 129 | 
            +
                        try:
         | 
| 130 | 
            +
                            self.collection = self.client.get_collection(
         | 
| 131 | 
            +
                                name=config.collection_name,
         | 
| 132 | 
            +
                                embedding_function=sentence_transformer_ef
         | 
| 133 | 
            +
                            )
         | 
| 134 | 
            +
                            logger.info(f"Using existing collection: {config.collection_name}")
         | 
| 135 | 
            +
                        except:
         | 
| 136 | 
            +
                            # Create new collection if it doesn't exist
         | 
| 137 | 
            +
                            self.collection = self.client.create_collection(
         | 
| 138 | 
            +
                                name=config.collection_name,
         | 
| 139 | 
            +
                                embedding_function=sentence_transformer_ef,
         | 
| 140 | 
            +
                                metadata={"hnsw:space": "cosine"}
         | 
| 141 | 
            +
                            )
         | 
| 142 | 
            +
                            logger.info(f"Created new collection: {config.collection_name}")
         | 
| 143 | 
            +
                            
         | 
| 144 | 
            +
                    except Exception as e:
         | 
| 145 | 
            +
                        logger.error(f"Error initializing Chroma collection: {e}")
         | 
| 146 | 
            +
                        sys.exit(1)
         | 
| 147 | 
            +
                
         | 
| 148 | 
            +
                def query(self, query_text: str, n_results: int = 5) -> List[Dict]:
         | 
| 149 | 
            +
                    """
         | 
| 150 | 
            +
                    Query the vector store with a text query
         | 
| 151 | 
            +
                    """
         | 
| 152 | 
            +
                    try:
         | 
| 153 | 
            +
                        # Query the collection
         | 
| 154 | 
            +
                        search_results = self.collection.query(
         | 
| 155 | 
            +
                            query_texts=[query_text],
         | 
| 156 | 
            +
                            n_results=n_results,
         | 
| 157 | 
            +
                            include=["documents", "metadatas", "distances"]
         | 
| 158 | 
            +
                        )
         | 
| 159 | 
            +
                        
         | 
| 160 | 
            +
                        # Format results
         | 
| 161 | 
            +
                        results = []
         | 
| 162 | 
            +
                        if search_results["documents"] and len(search_results["documents"][0]) > 0:
         | 
| 163 | 
            +
                            for i in range(len(search_results["documents"][0])):
         | 
| 164 | 
            +
                                results.append({
         | 
| 165 | 
            +
                                    'document': search_results["documents"][0][i],
         | 
| 166 | 
            +
                                    'metadata': search_results["metadatas"][0][i],
         | 
| 167 | 
            +
                                    'score': 1.0 - search_results["distances"][0][i]  # Convert distance to similarity
         | 
| 168 | 
            +
                                })
         | 
| 169 | 
            +
                        
         | 
| 170 | 
            +
                        return results
         | 
| 171 | 
            +
                    except Exception as e:
         | 
| 172 | 
            +
                        logger.error(f"Error querying collection: {e}")
         | 
| 173 | 
            +
                        return []
         | 
| 174 |  | 
| 175 | 
            +
                def get_statistics(self) -> Dict[str, Any]:
         | 
| 176 | 
            +
                    """Get statistics about the vector store"""
         | 
| 177 | 
            +
                    stats = {}
         | 
| 178 | 
            +
                    
         | 
| 179 | 
            +
                    try:
         | 
| 180 | 
            +
                        # Get collection count
         | 
| 181 | 
            +
                        collection_info = self.collection.count()
         | 
| 182 | 
            +
                        stats['total_documents'] = collection_info
         | 
| 183 | 
            +
                        
         | 
| 184 | 
            +
                        # Estimate unique files - with no chunking, each document is a file
         | 
| 185 | 
            +
                        stats['unique_files'] = collection_info
         | 
| 186 | 
            +
                    except Exception as e:
         | 
| 187 | 
            +
                        logger.error(f"Error getting statistics: {e}")
         | 
| 188 | 
            +
                        stats['error'] = str(e)
         | 
| 189 | 
            +
                    
         | 
| 190 | 
            +
                    return stats
         | 
| 191 |  | 
| 192 | 
            +
            class RAGSystem:
         | 
| 193 | 
            +
                """Retrieval-Augmented Generation with multiple LLM providers"""
         | 
| 194 | 
            +
                
         | 
| 195 | 
            +
                def __init__(self, vector_store: VectorStoreManager):
         | 
| 196 | 
            +
                    self.vector_store = vector_store
         | 
| 197 | 
            +
                    self.openai_client = None
         | 
| 198 | 
            +
                    self.gemini_configured = False
         | 
| 199 | 
            +
                
         | 
| 200 | 
            +
                def setup_openai(self, api_key: str):
         | 
| 201 | 
            +
                    """Set up OpenAI client with API key"""
         | 
| 202 | 
            +
                    try:
         | 
| 203 | 
            +
                        self.openai_client = OpenAI(api_key=api_key)
         | 
| 204 | 
            +
                        return True
         | 
| 205 | 
            +
                    except Exception as e:
         | 
| 206 | 
            +
                        logger.error(f"Error initializing OpenAI client: {e}")
         | 
| 207 | 
            +
                        return False
         | 
| 208 | 
            +
                
         | 
| 209 | 
            +
                def setup_gemini(self, api_key: str):
         | 
| 210 | 
            +
                    """Set up Gemini with API key"""
         | 
| 211 | 
            +
                    try:
         | 
| 212 | 
            +
                        genai.configure(api_key=api_key)
         | 
| 213 | 
            +
                        self.gemini_configured = True
         | 
| 214 | 
            +
                        return True
         | 
| 215 | 
            +
                    except Exception as e:
         | 
| 216 | 
            +
                        logger.error(f"Error configuring Gemini: {e}")
         | 
| 217 | 
            +
                        return False
         | 
| 218 | 
            +
                
         | 
| 219 | 
            +
                def format_context(self, documents: List[Dict]) -> str:
         | 
| 220 | 
            +
                    """Format retrieved documents into context for the LLM"""
         | 
| 221 | 
            +
                    if not documents:
         | 
| 222 | 
            +
                        return "No relevant documents found."
         | 
| 223 | 
            +
                    
         | 
| 224 | 
            +
                    context_parts = []
         | 
| 225 | 
            +
                    for i, doc in enumerate(documents):
         | 
| 226 | 
            +
                        metadata = doc['metadata']
         | 
| 227 | 
            +
                        title = metadata.get('title', metadata.get('filename', 'Unknown document'))
         | 
| 228 | 
            +
                        
         | 
| 229 | 
            +
                        # For readability, limit length of context document
         | 
| 230 | 
            +
                        doc_text = doc['document']
         | 
| 231 | 
            +
                        if len(doc_text) > 10000:  # Limit long documents in context
         | 
| 232 | 
            +
                            doc_text = doc_text[:10000] + "... [Document truncated for context]"
         | 
| 233 | 
            +
                            
         | 
| 234 | 
            +
                        context_parts.append(f"Document {i+1} - {title}:\n{doc_text}\n")
         | 
| 235 | 
            +
                    
         | 
| 236 | 
            +
                    return "\n".join(context_parts)
         | 
| 237 | 
            +
                
         | 
| 238 | 
            +
                def generate_response_openai(self, query: str, context: str) -> str:
         | 
| 239 | 
            +
                    """Generate a response using OpenAI model with context"""
         | 
| 240 | 
            +
                    if not self.openai_client:
         | 
| 241 | 
            +
                        return "Error: OpenAI API key not configured. Please enter an API key in the settings tab."
         | 
| 242 | 
            +
                    
         | 
| 243 | 
            +
                    system_prompt = """
         | 
| 244 | 
            +
                    You are a helpful assistant that answers questions based on the context provided.
         | 
| 245 | 
            +
                    Use the information from the context to answer the user's question.
         | 
| 246 | 
            +
                    If the context doesn't contain the information needed, say so clearly.
         | 
| 247 | 
            +
                    Always cite the specific sections from the context that you used in your answer.
         | 
| 248 | 
            +
                    """
         | 
| 249 | 
            +
                    
         | 
| 250 | 
            +
                    try:
         | 
| 251 | 
            +
                        response = self.openai_client.chat.completions.create(
         | 
| 252 | 
            +
                            model="gpt-4o-mini",  # Use GPT-4o mini
         | 
| 253 | 
            +
                            messages=[
         | 
| 254 | 
            +
                                {"role": "system", "content": system_prompt},
         | 
| 255 | 
            +
                                {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
         | 
| 256 | 
            +
                            ],
         | 
| 257 | 
            +
                            temperature=0.3,  # Lower temperature for more factual responses
         | 
| 258 | 
            +
                            max_tokens=1000,
         | 
| 259 | 
            +
                        )
         | 
| 260 | 
            +
                        return response.choices[0].message.content
         | 
| 261 | 
            +
                    except Exception as e:
         | 
| 262 | 
            +
                        logger.error(f"Error generating response with OpenAI: {e}")
         | 
| 263 | 
            +
                        return f"Error generating response with OpenAI: {str(e)}"
         | 
| 264 | 
            +
                
         | 
| 265 | 
            +
                def generate_response_gemini(self, query: str, context: str) -> str:
         | 
| 266 | 
            +
                    """Generate a response using Gemini with context"""
         | 
| 267 | 
            +
                    if not self.gemini_configured:
         | 
| 268 | 
            +
                        return "Error: Google AI API key not configured. Please enter an API key in the settings tab."
         | 
| 269 | 
            +
                    
         | 
| 270 | 
            +
                    prompt = f"""
         | 
| 271 | 
            +
                    You are a helpful assistant that answers questions based on the context provided.
         | 
| 272 | 
            +
                    Use the information from the context to answer the user's question.
         | 
| 273 | 
            +
                    If the context doesn't contain the information needed, say so clearly.
         | 
| 274 | 
            +
                    Always cite the specific sections from the context that you used in your answer.
         | 
| 275 | 
            +
                    
         | 
| 276 | 
            +
                    Context:
         | 
| 277 | 
            +
                    {context}
         | 
| 278 | 
            +
                    
         | 
| 279 | 
            +
                    Question: {query}
         | 
| 280 | 
            +
                    """
         | 
| 281 | 
            +
                    
         | 
| 282 | 
            +
                    try:
         | 
| 283 | 
            +
                        model = genai.GenerativeModel('gemini-1.5-flash')
         | 
| 284 | 
            +
                        response = model.generate_content(prompt)
         | 
| 285 | 
            +
                        return response.text
         | 
| 286 | 
            +
                    except Exception as e:
         | 
| 287 | 
            +
                        logger.error(f"Error generating response with Gemini: {e}")
         | 
| 288 | 
            +
                        return f"Error generating response with Gemini: {str(e)}"
         | 
| 289 | 
            +
                
         | 
| 290 | 
            +
                def query_and_generate(self, query: str, n_results: int = 5, model: str = "openai") -> str:
         | 
| 291 | 
            +
                    """Retrieve relevant documents and generate a response using the specified model"""
         | 
| 292 | 
            +
                    # Query vector store
         | 
| 293 | 
            +
                    documents = self.vector_store.query(query, n_results=n_results)
         | 
| 294 | 
            +
                    
         | 
| 295 | 
            +
                    if not documents:
         | 
| 296 | 
            +
                        return "No relevant documents found to answer your question."
         | 
| 297 | 
            +
                    
         | 
| 298 | 
            +
                    # Format context
         | 
| 299 | 
            +
                    context = self.format_context(documents)
         | 
| 300 | 
            +
                    
         | 
| 301 | 
            +
                    # Generate response with the appropriate model
         | 
| 302 | 
            +
                    if model == "openai":
         | 
| 303 | 
            +
                        return self.generate_response_openai(query, context)
         | 
| 304 | 
            +
                    elif model == "gemini":
         | 
| 305 | 
            +
                        return self.generate_response_gemini(query, context)
         | 
| 306 | 
            +
                    else:
         | 
| 307 | 
            +
                        return f"Unknown model: {model}"
         | 
| 308 |  | 
| 309 | 
            +
            def rag_chat(query, n_results, model_choice, rag_system):
         | 
| 310 | 
            +
                """Function to handle RAG chat queries"""
         | 
| 311 | 
            +
                return rag_system.query_and_generate(query, n_results=int(n_results), model=model_choice)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 312 |  | 
| 313 | 
            +
            def simple_query(query, n_results, vector_store):
         | 
| 314 | 
            +
                """Function to handle simple vector store queries"""
         | 
| 315 | 
            +
                results = vector_store.query(query, n_results=int(n_results))
         | 
| 316 | 
            +
                
         | 
| 317 | 
            +
                # Format results for display
         | 
| 318 | 
            +
                formatted = []
         | 
| 319 | 
            +
                for i, res in enumerate(results):
         | 
| 320 | 
            +
                    metadata = res['metadata']
         | 
| 321 | 
            +
                    title = metadata.get('title', metadata.get('filename', 'Unknown'))
         | 
| 322 | 
            +
                    # Limit preview text for display
         | 
| 323 | 
            +
                    preview = res['document'][:800] + '...' if len(res['document']) > 800 else res['document']
         | 
| 324 | 
            +
                    formatted.append(f"**Result {i+1}** (Similarity: {res['score']:.2f})\n\n"
         | 
| 325 | 
            +
                                   f"**Source:** {title}\n\n"
         | 
| 326 | 
            +
                                   f"**Content:**\n{preview}\n\n"
         | 
| 327 | 
            +
                                   f"---\n")
         | 
| 328 | 
            +
                
         | 
| 329 | 
            +
                return "\n".join(formatted) if formatted else "No results found."
         | 
| 330 |  | 
| 331 | 
            +
            def get_db_stats(vector_store):
         | 
| 332 | 
            +
                """Function to get vector store statistics"""
         | 
| 333 | 
            +
                stats = vector_store.get_statistics()
         | 
| 334 | 
            +
                return (f"Total documents: {stats.get('total_documents', 0)}\n"
         | 
| 335 | 
            +
                       f"Unique files: {stats.get('unique_files', 0)}")
         | 
| 336 |  | 
| 337 | 
            +
            def update_api_keys(openai_key, gemini_key, rag_system):
         | 
| 338 | 
            +
                """Update API keys for the RAG system"""
         | 
| 339 | 
            +
                success_msg = []
         | 
| 340 | 
            +
                
         | 
| 341 | 
            +
                if openai_key:
         | 
| 342 | 
            +
                    if rag_system.setup_openai(openai_key):
         | 
| 343 | 
            +
                        success_msg.append("✅ OpenAI API key configured successfully")
         | 
| 344 | 
            +
                    else:
         | 
| 345 | 
            +
                        success_msg.append("❌ Failed to configure OpenAI API key")
         | 
| 346 | 
            +
                
         | 
| 347 | 
            +
                if gemini_key:
         | 
| 348 | 
            +
                    if rag_system.setup_gemini(gemini_key):
         | 
| 349 | 
            +
                        success_msg.append("✅ Google AI API key configured successfully")
         | 
| 350 | 
            +
                    else:
         | 
| 351 | 
            +
                        success_msg.append("❌ Failed to configure Google AI API key")
         | 
| 352 | 
            +
                
         | 
| 353 | 
            +
                if not success_msg:
         | 
| 354 | 
            +
                    return "Please enter at least one API key"
         | 
| 355 | 
            +
                
         | 
| 356 | 
            +
                return "\n".join(success_msg)
         | 
| 357 |  | 
| 358 | 
            +
            # Main function to run the application
         | 
| 359 | 
            +
            def main():
         | 
| 360 | 
            +
                # Set up paths for existing Chroma database
         | 
| 361 | 
            +
                chroma_dir = Path("./chroma_data")
         | 
| 362 | 
            +
                
         | 
| 363 | 
            +
                # Initialize the system
         | 
| 364 | 
            +
                config = Config(
         | 
| 365 | 
            +
                    local_dir=str(chroma_dir),
         | 
| 366 | 
            +
                    collection_name="markdown_docs"
         | 
| 367 | 
            +
                )
         | 
| 368 | 
            +
                
         | 
| 369 | 
            +
                # Initialize vector store manager with existing collection
         | 
| 370 | 
            +
                vector_store = VectorStoreManager(config)
         | 
| 371 | 
            +
                
         | 
| 372 | 
            +
                # Initialize RAG system without API keys initially
         | 
| 373 | 
            +
                rag_system = RAGSystem(vector_store)
         | 
| 374 | 
            +
                
         | 
| 375 | 
            +
                # Define Gradio app
         | 
| 376 | 
            +
                def rag_chat_wrapper(query, n_results, model_choice):
         | 
| 377 | 
            +
                    return rag_chat(query, n_results, model_choice, rag_system)
         | 
| 378 | 
            +
                
         | 
| 379 | 
            +
                def simple_query_wrapper(query, n_results):
         | 
| 380 | 
            +
                    return simple_query(query, n_results, vector_store)
         | 
| 381 | 
            +
                
         | 
| 382 | 
            +
                def update_api_keys_wrapper(openai_key, gemini_key):
         | 
| 383 | 
            +
                    return update_api_keys(openai_key, gemini_key, rag_system)
         | 
| 384 | 
            +
                
         | 
| 385 | 
            +
                # Create the Gradio interface
         | 
| 386 | 
            +
                with gr.Blocks(title="Markdown RAG System") as app:
         | 
| 387 | 
            +
                    gr.Markdown("# RAG System with Multiple LLM Providers")
         | 
| 388 | 
            +
                    
         | 
| 389 | 
            +
                    with gr.Tab("Chat with Documents"):
         | 
| 390 | 
            +
                        with gr.Row():
         | 
| 391 | 
            +
                            with gr.Column(scale=3):
         | 
| 392 | 
            +
                                query_input = gr.Textbox(label="Question", placeholder="Ask a question about your documents...")
         | 
| 393 | 
            +
                                num_results = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of documents to retrieve")
         | 
| 394 | 
            +
                                model_choice = gr.Radio(
         | 
| 395 | 
            +
                                    choices=["openai", "gemini"], 
         | 
| 396 | 
            +
                                    value="openai", 
         | 
| 397 | 
            +
                                    label="Choose LLM Provider",
         | 
| 398 | 
            +
                                    info="Select which model to use for generating answers"
         | 
| 399 | 
            +
                                )
         | 
| 400 | 
            +
                                query_button = gr.Button("Ask", variant="primary")
         | 
| 401 | 
            +
                            
         | 
| 402 | 
            +
                            with gr.Column(scale=7):
         | 
| 403 | 
            +
                                response_output = gr.Markdown(label="Response")
         | 
| 404 | 
            +
                        
         | 
| 405 | 
            +
                        # Database stats
         | 
| 406 | 
            +
                        stats_display = gr.Textbox(label="Database Statistics", value=get_db_stats(vector_store))
         | 
| 407 | 
            +
                        refresh_button = gr.Button("Refresh Statistics")
         | 
| 408 | 
            +
                    
         | 
| 409 | 
            +
                    with gr.Tab("Document Search"):
         | 
| 410 | 
            +
                        search_input = gr.Textbox(label="Search Query", placeholder="Search your documents...")
         | 
| 411 | 
            +
                        search_num = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of results")
         | 
| 412 | 
            +
                        search_button = gr.Button("Search", variant="primary")
         | 
| 413 | 
            +
                        search_output = gr.Markdown(label="Search Results")
         | 
| 414 | 
            +
                    
         | 
| 415 | 
            +
                    with gr.Tab("Settings"):
         | 
| 416 | 
            +
                        gr.Markdown("""
         | 
| 417 | 
            +
                        ## API Keys Configuration
         | 
| 418 | 
            +
                        
         | 
| 419 | 
            +
                        This application can use either OpenAI's GPT-4o-mini or Google's Gemini 1.5 Flash for generating responses.
         | 
| 420 | 
            +
                        You need to provide at least one API key to use the chat functionality.
         | 
| 421 | 
            +
                        """)
         | 
| 422 | 
            +
                        
         | 
| 423 | 
            +
                        openai_key_input = gr.Textbox(
         | 
| 424 | 
            +
                            label="OpenAI API Key",
         | 
| 425 | 
            +
                            placeholder="Enter your OpenAI API key here...",
         | 
| 426 | 
            +
                            type="password"
         | 
| 427 | 
            +
                        )
         | 
| 428 | 
            +
                        
         | 
| 429 | 
            +
                        gemini_key_input = gr.Textbox(
         | 
| 430 | 
            +
                            label="Google AI API Key",
         | 
| 431 | 
            +
                            placeholder="Enter your Google AI API key here...",
         | 
| 432 | 
            +
                            type="password"
         | 
| 433 | 
            +
                        )
         | 
| 434 | 
            +
                        
         | 
| 435 | 
            +
                        save_keys_button = gr.Button("Save API Keys", variant="primary")
         | 
| 436 | 
            +
                        api_status = gr.Markdown("")
         | 
| 437 | 
            +
                    
         | 
| 438 | 
            +
                    # Set up events
         | 
| 439 | 
            +
                    query_button.click(
         | 
| 440 | 
            +
                        fn=rag_chat_wrapper,
         | 
| 441 | 
            +
                        inputs=[query_input, num_results, model_choice],
         | 
| 442 | 
            +
                        outputs=response_output
         | 
| 443 | 
            +
                    )
         | 
| 444 | 
            +
                    
         | 
| 445 | 
            +
                    refresh_button.click(
         | 
| 446 | 
            +
                        fn=lambda: get_db_stats(vector_store),
         | 
| 447 | 
            +
                        inputs=None,
         | 
| 448 | 
            +
                        outputs=stats_display
         | 
| 449 | 
            +
                    )
         | 
| 450 | 
            +
                    
         | 
| 451 | 
            +
                    search_button.click(
         | 
| 452 | 
            +
                        fn=simple_query_wrapper,
         | 
| 453 | 
            +
                        inputs=[search_input, search_num],
         | 
| 454 | 
            +
                        outputs=search_output
         | 
| 455 | 
            +
                    )
         | 
| 456 | 
            +
                    
         | 
| 457 | 
            +
                    save_keys_button.click(
         | 
| 458 | 
            +
                        fn=update_api_keys_wrapper,
         | 
| 459 | 
            +
                        inputs=[openai_key_input, gemini_key_input],
         | 
| 460 | 
            +
                        outputs=api_status
         | 
| 461 | 
            +
                    )
         | 
| 462 | 
            +
                
         | 
| 463 | 
            +
                # Launch the interface
         | 
| 464 | 
            +
                app.launch()
         | 
| 465 |  | 
| 466 | 
             
            if __name__ == "__main__":
         | 
| 467 | 
            +
                main()
         | 
