jzou19950715 committed on
Commit 894a951 · verified · 1 Parent(s): dcf7268

Update app.py

Files changed (1)
  1. app.py +226 -938
app.py CHANGED
@@ -4,193 +4,48 @@ import logging
 from pathlib import Path
 import json
 from datetime import datetime
-from typing import List, Dict, Any, Optional, Tuple, Union
-import traceback
 
-# Configure detailed logging with file output
-LOG_DIR = "logs"
-os.makedirs(LOG_DIR, exist_ok=True)
-log_file = os.path.join(LOG_DIR, f"rag_system_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
 
-# Set up root logger with both file and console handlers
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-    handlers=[
-        logging.FileHandler(log_file),
-        logging.StreamHandler(sys.stdout)
-    ]
-)
-logger = logging.getLogger("rag_system")
-logger.info(f"Starting RAG system. Log file: {log_file}")
-
-# Importing necessary libraries with error handling
-try:
-    import torch
-    import numpy as np
-    from sentence_transformers import SentenceTransformer
-    import chromadb
-    from chromadb.utils import embedding_functions
-    import gradio as gr
-    from openai import OpenAI
-    import google.generativeai as genai
-    logger.info("All required libraries successfully imported")
-except ImportError as e:
-    logger.critical(f"Failed to import required libraries: {e}")
-    print(f"ERROR: Missing required libraries. Please install with: pip install -r requirements.txt")
-    print(f"Specific error: {e}")
-    sys.exit(1)
-
-# Version info for tracking
-VERSION = "1.1.0"
-logger.info(f"RAG System Version: {VERSION}")
-
-# Custom CSS for better UI
-custom_css = """
-.gradio-container {
-    max-width: 1200px;
-    margin: auto;
-}
-.gr-prose h1 {
-    font-size: 2.5rem;
-    margin-bottom: 1rem;
-    color: #1a5276;
-}
-.gr-prose h3 {
-    font-size: 1.25rem;
-    font-weight: 600;
-    margin-top: 1rem;
-    margin-bottom: 0.5rem;
-    color: #2874a6;
-}
-.container {
-    margin: 0 auto;
-    padding: 2rem;
-}
-.gr-box {
-    border-radius: 8px;
-    box-shadow: 0 1px 3px rgba(0,0,0,0.12), 0 1px 2px rgba(0,0,0,0.24);
-    padding: 1rem;
-    margin-bottom: 1rem;
-    background-color: #f9f9f9;
-}
-.footer {
-    text-align: center;
-    font-size: 0.8rem;
-    color: #666;
-    margin-top: 2rem;
-}
-"""
 
 class Config:
-    """
-    Configuration for vector store and RAG system.
-
-    This class centralizes all configuration parameters for the application,
-    making it easier to modify settings and ensure consistency.
-
-    Attributes:
-        local_dir (str): Directory for ChromaDB persistence
-        embedding_model (str): Name of the embedding model to use
-        collection_name (str): Name of the ChromaDB collection
-        default_top_k (int): Default number of results to return
-        openai_model (str): Default OpenAI model to use
-        gemini_model (str): Default Gemini model to use
-        temperature (float): Temperature setting for LLM generation
-        max_tokens (int): Maximum tokens for LLM response
-        system_name (str): Name of the system for UI
-        context_limit (int): Maximum characters to include in context
-    """
-
     def __init__(self,
-                 local_dir: str = "./chroma_db",
                  embedding_model: str = "all-MiniLM-L6-v2",
-                 collection_name: str = "markdown_docs",
-                 default_top_k: int = 8,  # Increased from 5 to 8 for more context
-                 openai_model: str = "gpt-4o-mini",
-                 gemini_model: str = "gemini-1.5-flash",
-                 temperature: float = 0.3,
-                 max_tokens: int = 2000,  # Increased from 1000 to 2000 for more comprehensive responses
-                 system_name: str = "Document Knowledge Assistant",
-                 context_limit: int = 16000):  # Increased context limit for more comprehensive context
         self.local_dir = local_dir
         self.embedding_model = embedding_model
         self.collection_name = collection_name
-        self.default_top_k = default_top_k
-        self.openai_model = openai_model
-        self.gemini_model = gemini_model
-        self.temperature = temperature
-        self.max_tokens = max_tokens
-        self.system_name = system_name
-        self.context_limit = context_limit
-
-        # Create local directory if it doesn't exist
-        os.makedirs(local_dir, exist_ok=True)
-
-        logger.info(f"Initialized configuration: {self.__dict__}")
-
-    def to_dict(self) -> Dict[str, Any]:
-        """Convert configuration to dictionary for serialization"""
-        return self.__dict__
-
-    @classmethod
-    def from_file(cls, config_path: str) -> 'Config':
-        """Load configuration from JSON file"""
-        try:
-            with open(config_path, 'r') as f:
-                config_dict = json.load(f)
-            logger.info(f"Loaded configuration from {config_path}")
-            return cls(**config_dict)
-        except Exception as e:
-            logger.error(f"Failed to load configuration from {config_path}: {e}")
-            logger.info("Using default configuration")
-            return cls()
-
-    def save_to_file(self, config_path: str) -> bool:
-        """Save configuration to JSON file"""
-        try:
-            with open(config_path, 'w') as f:
-                json.dump(self.to_dict(), f, indent=2)
-            logger.info(f"Saved configuration to {config_path}")
-            return True
-        except Exception as e:
-            logger.error(f"Failed to save configuration to {config_path}: {e}")
-            return False
 
 class EmbeddingEngine:
-    """
-    Handle embeddings with a lightweight model.
-
-    This class manages the embedding model used to convert text to vector
-    representations for semantic search.
-
-    Attributes:
-        model (SentenceTransformer): The loaded embedding model
-        model_name (str): Name of the successfully loaded model
-        vector_size (int): Dimension of the embedding vectors
-        device (str): Device used for inference ('cuda' or 'cpu')
-    """
 
     def __init__(self, model_name="all-MiniLM-L6-v2"):
-        """
-        Initialize the embedding engine with the specified model.
-
-        Args:
-            model_name (str): Name of the embedding model to load
-
-        Raises:
-            SystemExit: If no embedding model could be loaded
-        """
         # Use GPU if available
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        logger.info(f"Using device for embeddings: {self.device}")
 
         # Try multiple model options in order of preference
         model_options = [
             model_name,
-            "all-MiniLM-L6-v2",  # Good balance of speed and quality
-            "paraphrase-MiniLM-L3-v2",  # Faster but less accurate
-            "all-mpnet-base-v2"  # Higher quality but larger model
         ]
 
         self.model = None
@@ -198,99 +53,47 @@ class EmbeddingEngine:
         # Try each model in order until one works
         for model_option in model_options:
             try:
-                logger.info(f"Attempting to load embedding model: {model_option}")
                 self.model = SentenceTransformer(model_option)
 
                 # Move model to device
                 self.model.to(self.device)
 
-                logger.info(f"Successfully loaded embedding model: {model_option}")
                 self.model_name = model_option
                 self.vector_size = self.model.get_sentence_embedding_dimension()
-                logger.info(f"Embedding vector size: {self.vector_size}")
                 break
 
             except Exception as e:
-                logger.warning(f"Failed to load embedding model {model_option}: {str(e)}")
 
         if self.model is None:
-            error_msg = "Failed to load any embedding model. Please check your internet connection or install models locally."
-            logger.critical(error_msg)
-            raise SystemExit(error_msg)
-
-    def embed(self, texts: List[str]) -> np.ndarray:
-        """
-        Generate embeddings for a list of texts.
-
-        Args:
-            texts (List[str]): List of texts to embed
-
-        Returns:
-            np.ndarray: Array of embeddings
-
-        Raises:
-            ValueError: If the input is invalid
-            RuntimeError: If embedding fails
-        """
-        if not texts:
-            raise ValueError("Cannot embed empty list of texts")
-
-        try:
-            embeddings = self.model.encode(texts, convert_to_numpy=True)
-            return embeddings
-        except Exception as e:
-            logger.error(f"Error generating embeddings: {e}")
-            raise RuntimeError(f"Failed to generate embeddings: {e}")
 
 class VectorStoreManager:
-    """
-    Manage Chroma vector store operations - upload, query, etc.
-
-    This class provides an interface to the ChromaDB vector database,
-    handling document storage, retrieval, and management.
-
-    Attributes:
-        config (Config): Configuration parameters
-        client (chromadb.PersistentClient): ChromaDB client
-        collection (chromadb.Collection): The active ChromaDB collection
-        embedding_engine (EmbeddingEngine): Engine for generating embeddings
-    """
 
     def __init__(self, config: Config):
-        """
-        Initialize the vector store manager.
-
-        Args:
-            config (Config): Configuration parameters
-
-        Raises:
-            SystemExit: If the vector store cannot be initialized
-        """
         self.config = config
 
         # Initialize Chroma client (local persistence)
         logger.info(f"Initializing Chroma at {config.local_dir}")
-        try:
-            self.client = chromadb.PersistentClient(path=config.local_dir)
-            logger.info("ChromaDB client initialized successfully")
-        except Exception as e:
-            error_msg = f"Failed to initialize ChromaDB client: {e}"
-            logger.critical(error_msg)
-            raise SystemExit(error_msg)
 
         # Get or create collection
         try:
             # Initialize embedding model
             logger.info("Loading embedding model...")
             self.embedding_engine = EmbeddingEngine(config.embedding_model)
-            logger.info(f"Using embedding model: {self.embedding_engine.model_name}")
 
             # Create embedding function
             sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
                 model_name=self.embedding_engine.model_name
             )
 
-            # Try to get existing collection or create a new one
             try:
                 self.collection = self.client.get_collection(
                     name=config.collection_name,
@@ -298,7 +101,7 @@ class VectorStoreManager:
                 )
                 logger.info(f"Using existing collection: {config.collection_name}")
             except Exception as e:
-                logger.warning(f"Error getting collection: {e}")
                 # Attempt to get a list of available collections
                 collections = self.client.list_collections()
                 if collections:
@@ -319,28 +122,14 @@ class VectorStoreManager:
             logger.info(f"Created new collection: {config.collection_name}")
 
         except Exception as e:
-            error_msg = f"Error initializing Chroma collection: {e}"
-            logger.critical(error_msg)
-            raise SystemExit(error_msg)
 
     def query(self, query_text: str, n_results: int = 5) -> List[Dict]:
         """
-        Query the vector store with a text query.
-
-        Args:
-            query_text (str): The query text
-            n_results (int): Number of results to return
-
-        Returns:
-            List[Dict]: List of results with document text, metadata, and similarity score
         """
-        if not query_text.strip():
-            logger.warning("Empty query received")
-            return []
-
         try:
-            logger.info(f"Querying vector store with: '{query_text[:50]}...' (top {n_results})")
-
             # Query the collection
             search_results = self.collection.query(
                 query_texts=[query_text],
@@ -354,106 +143,26 @@ class VectorStoreManager:
                 for i in range(len(search_results["documents"][0])):
                     results.append({
                         'document': search_results["documents"][0][i],
-                        'metadata': search_results["metadatas"][0][i] if search_results["metadatas"] else {},
-                        'score': 1.0 - search_results["distances"][0][i],  # Convert distance to similarity
-                        'distance': search_results["distances"][0][i]
                     })
-
-                logger.info(f"Found {len(results)} results for query")
-            else:
-                logger.info("No results found for query")
 
             return results
         except Exception as e:
             logger.error(f"Error querying collection: {e}")
-            logger.debug(traceback.format_exc())
             return []
 
-    def add_document(self,
-                     document: str,
-                     doc_id: str,
-                     metadata: Dict[str, Any]) -> bool:
-        """
-        Add a document to the vector store.
-
-        Args:
-            document (str): The document text
-            doc_id (str): Unique identifier for the document
-            metadata (Dict[str, Any]): Metadata about the document
-
-        Returns:
-            bool: True if successful, False otherwise
-        """
-        try:
-            logger.info(f"Adding document '{doc_id}' to vector store")
-
-            # Add the document to the collection
-            self.collection.add(
-                documents=[document],
-                ids=[doc_id],
-                metadatas=[metadata]
-            )
-
-            logger.info(f"Successfully added document '{doc_id}'")
-            return True
-        except Exception as e:
-            logger.error(f"Error adding document to collection: {e}")
-            return False
-
-    def delete_document(self, doc_id: str) -> bool:
-        """
-        Delete a document from the vector store.
-
-        Args:
-            doc_id (str): ID of the document to delete
-
-        Returns:
-            bool: True if successful, False otherwise
-        """
-        try:
-            logger.info(f"Deleting document '{doc_id}' from vector store")
-            self.collection.delete(ids=[doc_id])
-            logger.info(f"Successfully deleted document '{doc_id}'")
-            return True
-        except Exception as e:
-            logger.error(f"Error deleting document from collection: {e}")
-            return False
-
     def get_statistics(self) -> Dict[str, Any]:
-        """
-        Get statistics about the vector store.
-
-        Returns:
-            Dict[str, Any]: Statistics about the vector store
-        """
-        stats = {
-            'collection_name': self.config.collection_name,
-            'embedding_model': self.embedding_engine.model_name,
-            'embedding_dimensions': self.embedding_engine.vector_size,
-            'device': self.embedding_engine.device
-        }
 
         try:
             # Get collection count
-            collection_count = self.collection.count()
-            stats['total_documents'] = collection_count
 
-            # Get unique metadata values
-            if collection_count > 0:
-                try:
-                    # Get a sample of document metadata
-                    sample_results = self.collection.get(limit=min(collection_count, 100))
-                    if sample_results and 'metadatas' in sample_results and sample_results['metadatas']:
-                        # Count unique files if filename exists in metadata
-                        filenames = set()
-                        for metadata in sample_results['metadatas']:
-                            if 'filename' in metadata:
-                                filenames.add(metadata['filename'])
-                        stats['unique_files'] = len(filenames)
-                except Exception as e:
-                    logger.warning(f"Error getting metadata statistics: {e}")
-
-            logger.info(f"Vector store statistics: {stats}")
        except Exception as e:
            logger.error(f"Error getting statistics: {e}")
            stats['error'] = str(e)
@@ -461,635 +170,274 @@ class VectorStoreManager:
         return stats
 
 class RAGSystem:
-    """
-    Retrieval-Augmented Generation with multiple LLM providers.
-
-    This class handles the RAG workflow: retrieval of relevant documents,
-    formatting context, and generating responses with different LLM providers.
-
-    Attributes:
-        vector_store (VectorStoreManager): Manager for vector store operations
-        openai_client (Optional[OpenAI]): OpenAI client
-        gemini_configured (bool): Whether Gemini API is configured
-        config (Config): Configuration parameters
-    """
 
-    def __init__(self, vector_store: VectorStoreManager, config: Config):
-        """
-        Initialize the RAG system.
-
-        Args:
-            vector_store (VectorStoreManager): Vector store manager
-            config (Config): Configuration parameters
-        """
         self.vector_store = vector_store
-        self.config = config
         self.openai_client = None
         self.gemini_configured = False
-
-        logger.info("Initialized RAG system")
 
-    def setup_openai(self, api_key: str) -> bool:
-        """
-        Set up OpenAI client with API key.
-
-        Args:
-            api_key (str): OpenAI API key
-
-        Returns:
-            bool: True if successful, False otherwise
-        """
-        if not api_key.strip():
-            logger.warning("Empty OpenAI API key provided")
-            return False
-
         try:
-            logger.info("Setting up OpenAI client")
             self.openai_client = OpenAI(api_key=api_key)
-            # Test the API key with a simple request
-            response = self.openai_client.chat.completions.create(
-                model=self.config.openai_model,
-                messages=[
-                    {"role": "system", "content": "You are a helpful assistant."},
-                    {"role": "user", "content": "Test connection"}
-                ],
-                max_tokens=10
-            )
-            logger.info("OpenAI client configured successfully")
             return True
         except Exception as e:
             logger.error(f"Error initializing OpenAI client: {e}")
-            self.openai_client = None
             return False
 
-    def setup_gemini(self, api_key: str) -> bool:
-        """
-        Set up Gemini with API key.
-
-        Args:
-            api_key (str): Google AI API key
-
-        Returns:
-            bool: True if successful, False otherwise
-        """
-        if not api_key.strip():
-            logger.warning("Empty Gemini API key provided")
-            return False
-
         try:
-            logger.info("Setting up Gemini client")
             genai.configure(api_key=api_key)
-
-            # Test the API key with a simple request
-            model = genai.GenerativeModel(self.config.gemini_model)
-            response = model.generate_content("Test connection")
-
             self.gemini_configured = True
-            logger.info("Gemini client configured successfully")
             return True
         except Exception as e:
             logger.error(f"Error configuring Gemini: {e}")
-            self.gemini_configured = False
             return False
 
     def format_context(self, documents: List[Dict]) -> str:
-        """
-        Format retrieved documents into context for the LLM.
-
-        Args:
-            documents (List[Dict]): List of retrieved documents
-
-        Returns:
-            str: Formatted context for the LLM
-        """
         if not documents:
-            logger.warning("No documents provided for context formatting")
             return "No relevant documents found."
 
-        logger.info(f"Formatting {len(documents)} documents for context")
         context_parts = []
-
         for i, doc in enumerate(documents):
             metadata = doc['metadata']
-            # Extract document metadata in a robust way
             title = metadata.get('title', metadata.get('filename', 'Unknown document'))
 
-            # Format header with just essential metadata for cleaner context
-            header = f"Document {i+1} - {title}"
-
-            # For readability, limit length of context document (using config value)
             doc_text = doc['document']
-            if len(doc_text) > (self.config.context_limit // len(documents)):
-                # Divide context limit among the documents
-                max_length = self.config.context_limit // len(documents)
-                doc_text = doc_text[:max_length] + "... [Document truncated for brevity]"
 
-            context_parts.append(f"{header}:\n{doc_text}\n")
 
-        full_context = "\n".join(context_parts)
-        logger.info(f"Created context with {len(full_context)} characters")
-
-        return full_context
 
     def generate_response_openai(self, query: str, context: str) -> str:
-        """
-        Generate a response using OpenAI model with context.
-
-        Args:
-            query (str): User query
-            context (str): Formatted document context
-
-        Returns:
-            str: Generated response
-        """
         if not self.openai_client:
-            logger.warning("OpenAI API key not configured for response generation")
-            return "Please configure an OpenAI API key to use this feature. Enter your API key in the field and click 'Save API Key'."
 
-        # Improved system prompt for better, more comprehensive responses
         system_prompt = """
-        You are an exceptionally helpful, clear, and friendly AI research assistant. Your goal is to provide comprehensive, well-structured, and insightful answers based on the provided document context.
-
-        Guidelines for your response:
-
-        1. USE ONLY the information contained in the provided context documents to form your answer. If the context doesn't contain enough information to provide a complete answer, acknowledge this limitation clearly.
-
-        2. Always provide well-structured, detailed responses between 300-500 words that thoroughly address the user's question.
-
-        3. Format your response with clear headings, bullet points, or numbered lists when appropriate to enhance readability.
-
-        4. Cite your sources by referring to the document numbers (e.g., "According to Document 1...") to support your claims.
-
-        5. Use a friendly, conversational, and supportive tone that makes complex information accessible.
-
-        6. If different documents offer conflicting information, acknowledge these differences and present both perspectives without bias.
-
-        7. When appropriate, organize information into logical categories or chronological order to improve clarity.
-
-        8. Use examples from the documents to illustrate key points when available.
-
-        9. Conclude with a brief summary of the main points if the answer is complex.
-
-        10. Remember to stay focused on the user's specific question while providing sufficient context for complete understanding.
         """
 
         try:
-            logger.info(f"Generating response with OpenAI ({self.config.openai_model})")
-
-            start_time = datetime.now()
             response = self.openai_client.chat.completions.create(
-                model=self.config.openai_model,
                 messages=[
                     {"role": "system", "content": system_prompt},
                     {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
                 ],
-                temperature=self.config.temperature,
-                max_tokens=self.config.max_tokens,
             )
-
-            generation_time = (datetime.now() - start_time).total_seconds()
-            response_text = response.choices[0].message.content
-
-            logger.info(f"Generated response with OpenAI in {generation_time:.2f} seconds")
-            return response_text
         except Exception as e:
-            error_msg = f"Error generating response with OpenAI: {str(e)}"
-            logger.error(error_msg)
-            return f"I encountered an error while generating your response. Please try again or check your API key. Error details: {str(e)}"
 
    def generate_response_gemini(self, query: str, context: str) -> str:
-        """
-        Generate a response using Gemini with context.
-
-        Args:
-            query (str): User query
-            context (str): Formatted document context
-
-        Returns:
-            str: Generated response
-        """
         if not self.gemini_configured:
-            logger.warning("Gemini API key not configured for response generation")
-            return "Please configure a Google AI API key to use this feature. Enter your API key in the field and click 'Save API Key'."
 
-        # Improved Gemini prompt for more comprehensive and user-friendly responses
         prompt = f"""
-        You are a knowledgeable and friendly research assistant who excels at providing clear, comprehensive, and well-structured responses. Your goal is to help users understand complex information from documents in an accessible way.
-
-        **Guidelines for Your Response:**
 
-        - Create a detailed, well-organized response of approximately 300-500 words that thoroughly addresses the user's question.
-        - Use ONLY information from the provided context documents.
-        - Structure your answer with clear paragraphs, and use headings, bullet points, or numbered lists when appropriate.
-        - Maintain a friendly, conversational tone that makes information accessible and engaging.
-        - When citing information, reference specific documents by number (e.g., "As mentioned in Document 2...").
-        - If the context doesn't contain enough information for a complete answer, acknowledge this limitation while providing what you can from the available context.
-        - If documents contain conflicting information, present both perspectives fairly.
-        - Conclude with a brief summary if the topic is complex.
-
-        **Context Documents:**
         {context}
 
-        **User's Question:**
-        {query}
-
-        **Your Response:**
         """
-
         try:
-            logger.info(f"Generating response with Gemini ({self.config.gemini_model})")
-
-            start_time = datetime.now()
-            model = genai.GenerativeModel(self.config.gemini_model)
-
-            generation_config = {
-                "temperature": self.config.temperature,
-                "max_output_tokens": self.config.max_tokens,
-                "top_p": 0.9,
-                "top_k": 40
-            }
-
-            response = model.generate_content(
-                prompt,
-                generation_config=generation_config
-            )
-
-            generation_time = (datetime.now() - start_time).total_seconds()
-            response_text = response.text
-
-            logger.info(f"Generated response with Gemini in {generation_time:.2f} seconds")
-            return response_text
         except Exception as e:
-            error_msg = f"Error generating response with Gemini: {str(e)}"
-            logger.error(error_msg)
-            return f"I encountered an error while generating your response. Please try again or check your API key. Error details: {str(e)}"
 
-    def query_and_generate(self,
-                           query: str,
-                           n_results: int = 5,
-                           model: str = "openai") -> Tuple[str, str]:
-        """
-        Retrieve relevant documents and generate a response using the specified model.
-
-        Args:
-            query (str): User query
-            n_results (int): Number of documents to retrieve
-            model (str): Model provider to use ('openai' or 'gemini')
-
-        Returns:
-            Tuple[str, str]: (Generated response, Search results)
-        """
-        if not query.strip():
-            logger.warning("Empty query received")
-            return "Please enter a question to get a response.", "No search performed."
-
-        logger.info(f"Processing query: '{query[:50]}...' with {model} model")
-
        # Query vector store
        documents = self.vector_store.query(query, n_results=n_results)
 
-        # Format search results (for logs and hidden UI component)
-        # We'll format this in a way that's more useful for reference but not shown in UI
-        formatted_results = []
-        for i, res in enumerate(documents):
-            metadata = res['metadata']
-            title = metadata.get('title', metadata.get('filename', 'Unknown'))
-            score = res['score']
-
-            # Only include a very brief preview for reference
-            preview = res['document'][:100] + '...' if len(res['document']) > 100 else res['document']
-            formatted_results.append(f"Document {i+1}: {title} (Relevance: {score:.2f})")
-
-        search_output_text = "\n".join(formatted_results) if formatted_results else "No relevant documents found."
-
        if not documents:
-            logger.warning("No relevant documents found")
-            return "I couldn't find relevant information in the knowledge base to answer your question. Could you try rephrasing your question or ask about a different topic?", search_output_text
 
        # Format context
        context = self.format_context(documents)
 
        # Generate response with the appropriate model
        if model == "openai":
-            response = self.generate_response_openai(query, context)
        elif model == "gemini":
-            response = self.generate_response_gemini(query, context)
        else:
-            error_msg = f"Unknown model: {model}"
-            logger.error(error_msg)
-            return error_msg, search_output_text
-
-        return response, search_output_text
-
-def get_db_stats(vector_store: VectorStoreManager) -> str:
-    """
-    Function to get vector store statistics.
-
-    Args:
-        vector_store (VectorStoreManager): Vector store manager
-
-    Returns:
-        str: Formatted statistics string
-    """
-    try:
-        stats = vector_store.get_statistics()
-        total_docs = stats.get('total_documents', 0)
-
-        stats_text = f"Documents in knowledge base: {total_docs}"
-        return stats_text
-    except Exception as e:
-        logger.error(f"Error getting statistics: {e}")
-        return "Error getting database statistics"
-
-# Helper function for loading documents (can be expanded in future versions)
-def load_document(file_path: str, chunk_size: int = 2000, chunk_overlap: int = 200) -> bool:
-    """
-    Load a document into the vector store.
-
-    Args:
-        file_path (str): Path to the document
-        chunk_size (int): Size of chunks to split the document into
-        chunk_overlap (int): Overlap between chunks
-
-    Returns:
-        bool: True if successful, False otherwise
-    """
-    try:
-        try:
-            logger.info(f"Loading document: {file_path}")
-
-            # Initialize components
-            config = Config()
-            vector_store = VectorStoreManager(config)
-
-            # Read the file with different encodings if needed
-            content = None
-            encodings = ['utf-8', 'latin-1', 'cp1252']
-
-            for encoding in encodings:
-                try:
-                    with open(file_path, 'r', encoding=encoding) as f:
-                        content = f.read()
-                    logger.info(f"Successfully read file with {encoding} encoding")
-                    break
-                except UnicodeDecodeError:
-                    logger.warning(f"Failed to read with {encoding} encoding, trying next...")
-
-            if content is None:
-                logger.error(f"Failed to read file with any encoding: {file_path}")
-                return False
-
-            # Extract metadata
-            file_name = os.path.basename(file_path)
-            file_ext = os.path.splitext(file_name)[1].lower()
-            file_size = os.path.getsize(file_path)
-            file_mtime = os.path.getmtime(file_path)
-
-            # Try to extract title from content for better reference
-            title = file_name
-            try:
-                # Simple heuristic to find a title (first non-empty line)
-                lines = content.split('\n')
-                for line in lines:
-                    line = line.strip()
-                    if line and len(line) < 100:  # Reasonable title length
-                        title = line
-                        break
-            except:
-                pass
-
-            # Create metadata
-            metadata = {
-                'filename': file_name,
-                'title': title,
-                'path': file_path,
-                'extension': file_ext,
-                'size': file_size,
-                'modified': datetime.fromtimestamp(file_mtime).isoformat(),
-                'created_at': datetime.now().isoformat()
-            }
-
-            # Generate a unique ID for the document
-            doc_id = f"{file_name}_{hash(content)}"
-
-            # Add to vector store
-            success = vector_store.add_document(content, doc_id, metadata)
-
-            logger.info(f"Document loaded successfully: {file_path}" if success else f"Failed to load document: {file_path}")
-            return success
-
-        except Exception as e:
-            logger.error(f"Error loading document {file_path}: {e}")
-            logger.error(traceback.format_exc())
-            return False
 
 def main():
-    """Main function to run the RAG application"""
-    # Path for configuration file
-    CONFIG_FILE_PATH = "rag_config.json"
 
    try:
-        # Try to load configuration from file, or use defaults
-        if os.path.exists(CONFIG_FILE_PATH):
-            config = Config.from_file(CONFIG_FILE_PATH)
-        else:
-            config = Config(
-                local_dir="./chroma_db",  # Store Chroma files in dedicated directory
-                collection_name="markdown_docs"
-            )
-            # Save default configuration
-            config.save_to_file(CONFIG_FILE_PATH)
-
-        print(f"Starting Document Knowledge Assistant v{VERSION}")
-        print(f"Log file: {log_file}")
-
        # Initialize vector store manager with existing collection
        vector_store = VectorStoreManager(config)
 
        # Initialize RAG system without API keys initially
-        rag_system = RAGSystem(vector_store, config)
 
-        # Create the Gradio interface with custom CSS
-        with gr.Blocks(title="Document Knowledge Assistant", css=custom_css) as app:
-            gr.Markdown(f"# Document Knowledge Assistant v{VERSION}")
-            gr.Markdown("Ask questions about your documents and get comprehensive AI-powered answers")
 
-            # Main layout
            with gr.Row():
-                # Left column for asking questions
-                with gr.Column(scale=3):
-                    with gr.Box():
-                        gr.Markdown("### Ask Your Question")
-                        query_input = gr.Textbox(
-                            label="",
-                            placeholder="What would you like to know about your documents?",
-                            lines=3
-                        )
-
-                        with gr.Row():
-                            query_button = gr.Button("Ask Question", variant="primary", scale=3)
-                            clear_button = gr.Button("Clear", variant="secondary", scale=1)
-
-                    with gr.Box():
-                        gr.Markdown("### Answer")
-                        response_output = gr.Markdown()
-
-                # Right column for settings
                with gr.Column(scale=1):
                    # API Keys and model selection
-                    with gr.Accordion("AI Model Settings", open=True):
-                        gr.Markdown("### AI Configuration")
-                        model_choice = gr.Radio(
-                            choices=["openai", "gemini"],
-                            value="openai",
-                            label="AI Provider",
-                            info=f"Select your preferred AI model"
-                        )
-
-                        api_key_input = gr.Textbox(
-                            label="API Key",
-                            placeholder="Enter your API key here...",
-                            type="password",
-                            info="Your key is not stored between sessions"
-                        )
-
-                        save_key_button = gr.Button("Save API Key", variant="primary")
-                        api_status = gr.Markdown("")
 
-                    # Advanced search controls
-                    with gr.Accordion("Advanced Settings", open=False):
-                        gr.Markdown("### Search & Response Settings")
-                        num_results = gr.Slider(
-                            minimum=3,
-                            maximum=15,
-                            value=config.default_top_k,
-                            step=1,
-                            label="Documents to search",
-                            info="Higher values provide more context"
-                        )
-
-                        temperature_slider = gr.Slider(
-                            minimum=0.0,
-                            maximum=1.0,
-                            value=config.temperature,
-                            step=0.05,
-                            label="Creativity",
-                            info="Lower = more factual, Higher = more creative"
-                        )
-
-                        max_tokens_slider = gr.Slider(
-                            minimum=500,
-                            maximum=4000,
-                            value=config.max_tokens,
-                            step=100,
-                            label="Response Length",
-                            info="Maximum words in response"
-                        )
 
-                    # Database stats - simplified
-                    with gr.Accordion("System Info", open=False):
-                        stats_display = gr.Markdown(get_db_stats(vector_store))
-
-                        gr.Markdown(f"""
-                        **System Details:**
-                        - Version: {VERSION}
-                        - Embedding: {vector_store.embedding_engine.model_name}
-                        - Device: {vector_store.embedding_engine.device}
-                        """)
-                        refresh_button = gr.Button("Refresh", variant="secondary", size="sm")
-
-            # Hidden element for search results (not visible to user)
-            with gr.Accordion("Debug Information", open=False, visible=False):
-                search_output = gr.Markdown()
-
-            # Query history at the bottom (optional section)
-            with gr.Accordion("Recent Questions", open=False):
-                history_list = gr.Dataframe(
-                    headers=["Time", "Question", "Model"],
-                    datatype=["str", "str", "str"],
-                    row_count=5,
-                    col_count=(3, "fixed"),
-                    interactive=False
-                )
 
-            # Footer
-            gr.Markdown(
-                """<div class="footer">Document Knowledge Assistant helps you get insights from your documents using AI.
-                Powered by Retrieval Augmented Generation.</div>"""
-            )
-
-            # Query history storage
-            query_history = []
 
            # Function to update API key based on selected model
            def update_api_key(api_key, model):
-                if not api_key.strip():
-                    return "❌ API key cannot be empty"
-
                if model == "openai":
                    success = rag_system.setup_openai(api_key)
-                    model_name = f"OpenAI {config.openai_model}"
                else:
                    success = rag_system.setup_gemini(api_key)
-                    model_name = f"Google {config.gemini_model}"
 
                if success:
-                    return f"✅ {model_name} connected successfully"
                else:
-                    return f"❌ Connection failed. Please check your API key and try again."
 
            # Query function that returns both response and search results
-            def query_and_search(query, n_results, model, temperature, max_tokens):
-                # Update configuration with current UI values
-                config.temperature = float(temperature)
-                config.max_tokens = int(max_tokens)
 
-                start_time = datetime.now()
 
-                if not query.strip():
-                    return "Please enter a question to get an answer.", "", query_history[-5:] if query_history else []
 
-                try:
-                    # Verify that API keys are configured
-                    if (model == "openai" and rag_system.openai_client is None) or \
-                       (model == "gemini" and not rag_system.gemini_configured):
-                        return "Please configure your API key first. Enter your API key in the settings panel and click 'Save API Key'.", "", query_history[-5:] if query_history else []
-
-                    # Call the RAG system's query and generate function
-                    response, search_output_text = rag_system.query_and_generate(
-                        query=query,
-                        n_results=int(n_results),
-                        model=model
-                    )
-
-                    # Add to history
-                    timestamp = datetime.now().strftime("%H:%M")
-                    query_history.append([timestamp, query, model])
-
-                    # Keep only the last 100 queries
-                    if len(query_history) > 100:
-                        query_history.pop(0)
-
-                    # Update the history display with the most recent entries (reverse chronological)
-                    recent_history = list(reversed(query_history[-5:])) if len(query_history) >= 5 else list(reversed(query_history))
-
-                    # Calculate elapsed time
-                    elapsed_time = (datetime.now() - start_time).total_seconds()
-
-                    # Add subtle timing information to the response
-                    response_with_timing = f"{response}\n\n<small>Answered in {elapsed_time:.1f}s</small>"
-
-                    return response_with_timing, search_output_text, recent_history
 
-                except Exception as e:
-                    error_msg = f"Error processing query: {str(e)}"
-                    logger.error(error_msg)
-                    logger.error(traceback.format_exc())
-                    return "I encountered an error while processing your question. Please try again or check your API key settings.", "", query_history[-5:] if query_history else []
-
-            # Function to clear the input and results
-            def clear_inputs():
-                return "", "", "", query_history[-5:] if query_history else []
 
            # Set up events
            save_key_button.click(
@@ -1100,8 +448,8 @@ def main():
 
            query_button.click(
                fn=query_and_search,
-                inputs=[query_input, num_results, model_choice, temperature_slider, max_tokens_slider],
-                outputs=[response_output, search_output, history_list]
            )
 
            refresh_button.click(
@@ -1109,84 +457,24 @@ def main():
                inputs=None,
                outputs=stats_display
            )
-
-            clear_button.click(
-                fn=clear_inputs,
-                inputs=None,
-                outputs=[query_input, response_output, search_output, history_list]
-            )
-
-            # Handle Enter key in query input
-            query_input.submit(
-                fn=query_and_search,
-                inputs=[query_input, num_results, model_choice, temperature_slider, max_tokens_slider],
-                outputs=[response_output, search_output, history_list]
-            )
-
-            # Auto-fill examples
-            examples = [
-                ["What are the main features of this application?"],
-                ["How does the retrieval augmented generation work?"],
-                ["Can you explain the embedding models used in this system?"],
-            ]
-
-            gr.Examples(
-                examples=examples,
-                inputs=query_input,
-                outputs=[response_output, search_output, history_list],
-                fn=lambda q: query_and_search(q, num_results.value, model_choice.value, temperature_slider.value, max_tokens_slider.value),
-                cache_examples=False,
-            )
 
-        # Launch the interface with a nice theme
-        app.launch(
-            share=False,  # Set to True to create a public link
-            server_name="0.0.0.0",  # Listen on all interfaces
-            server_port=7860,  # Default Gradio port
-            debug=False,  # Set to True during development
-            auth=None,  # Add (username, password) tuple for basic auth
-            favicon_path="favicon.ico" if os.path.exists("favicon.ico") else None,
-            show_error=True
-        )
-
    except Exception as e:
-        logger.critical(f"Error starting application: {e}")
-        print(f"Error starting application: {e}")
        sys.exit(1)
 
if __name__ == "__main__":
-    # Parse command line arguments
-    if len(sys.argv) > 1:
-        if sys.argv[1] == "--load" and len(sys.argv) > 2:
-            # Load documents mode
-            print(f"Document Knowledge Assistant v{VERSION}")
-            print(f"Loading documents into knowledge base...")
-
-            success_count = 0
-            failed_count = 0
-
-            for file_path in sys.argv[2:]:
-                if os.path.exists(file_path):
-                    success = load_document(file_path)
-                    if success:
-                        success_count += 1
-                        print(f"✅ Successfully loaded: {file_path}")
-                    else:
-                        failed_count += 1
-                        print(f"❌ Failed to load: {file_path}")
-                else:
-                    failed_count += 1
-                    print(f"❌ File not found: {file_path}")
-
-            print(f"\nLoading complete: {success_count} documents loaded, {failed_count} failed")
-            sys.exit(0)
-        elif sys.argv[1] == "--help":
-            print(f"Document Knowledge Assistant v{VERSION}")
-            print("Usage:")
-            print("  python rag_system.py                     # Start the web UI")
-            print("  python rag_system.py --load file1 file2  # Load documents into the knowledge base")
-            print("  python rag_system.py --help              # Show this help message")
-            sys.exit(0)
-
-    # Start the web UI
    main()
@@ -4,193 +4,48 @@ import logging
 from pathlib import Path
 import json
 from datetime import datetime
+from typing import List, Dict, Any, Optional
 
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
 
+# Importing necessary libraries
+import torch
+import numpy as np
+from sentence_transformers import SentenceTransformer
+import chromadb
+from chromadb.utils import embedding_functions
+import gradio as gr
+from openai import OpenAI
+import google.generativeai as genai
 
+# Configuration class
 class Config:
+    """Configuration for vector store and RAG"""
     def __init__(self,
+                 local_dir: str = ".",
                  embedding_model: str = "all-MiniLM-L6-v2",
+                 collection_name: str = "markdown_docs"):
         self.local_dir = local_dir
         self.embedding_model = embedding_model
         self.collection_name = collection_name
 
+# Embedding engine
 class EmbeddingEngine:
+    """Handle embeddings with a lightweight model"""
 
     def __init__(self, model_name="all-MiniLM-L6-v2"):
        # Use GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {self.device}")
 
        # Try multiple model options in order of preference
        model_options = [
            model_name,
+            "all-MiniLM-L6-v2",
+            "paraphrase-MiniLM-L3-v2",
+            "all-mpnet-base-v2"  # Higher quality but larger model
        ]
 
        self.model = None
@@ -198,99 +53,47 @@ class EmbeddingEngine:
        # Try each model in order until one works
        for model_option in model_options:
            try:
+                logger.info(f"Attempting to load model: {model_option}")
                self.model = SentenceTransformer(model_option)
 
                # Move model to device
                self.model.to(self.device)
 
+                logger.info(f"Successfully loaded model: {model_option}")
                self.model_name = model_option
                self.vector_size = self.model.get_sentence_embedding_dimension()
                break
 
            except Exception as e:
+                logger.warning(f"Failed to load model {model_option}: {str(e)}")
 
        if self.model is None:
+            logger.error("Failed to load any embedding model. Exiting.")
+            sys.exit(1)
 
class VectorStoreManager:
+    """Manage Chroma vector store operations - upload, query, etc."""
 
    def __init__(self, config: Config):
        self.config = config
 
        # Initialize Chroma client (local persistence)
        logger.info(f"Initializing Chroma at {config.local_dir}")
+        self.client = chromadb.PersistentClient(path=config.local_dir)
 
        # Get or create collection
        try:
            # Initialize embedding model
            logger.info("Loading embedding model...")
            self.embedding_engine = EmbeddingEngine(config.embedding_model)
+            logger.info(f"Using model: {self.embedding_engine.model_name}")
 
            # Create embedding function
            sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
                model_name=self.embedding_engine.model_name
            )
 
+            # Try to get existing collection
            try:
                self.collection = self.client.get_collection(
                    name=config.collection_name,
@@ -298,7 +101,7 @@ class VectorStoreManager:
                )
                logger.info(f"Using existing collection: {config.collection_name}")
            except Exception as e:
+                logger.error(f"Error getting collection: {e}")
                # Attempt to get a list of available collections
                collections = self.client.list_collections()
                if collections:
@@ -319,28 +122,14 @@ class VectorStoreManager:
            logger.info(f"Created new collection: {config.collection_name}")
 
        except Exception as e:
+            logger.error(f"Error initializing Chroma collection: {e}")
+            sys.exit(1)
 
    def query(self, query_text: str, n_results: int = 5) -> List[Dict]:
        """
+        Query the vector store with a text query
        """
        try:
            # Query the collection
            search_results = self.collection.query(
                query_texts=[query_text],
@@ -354,106 +143,26 @@ class VectorStoreManager:
                for i in range(len(search_results["documents"][0])):
                    results.append({
                        'document': search_results["documents"][0][i],
+                        'metadata': search_results["metadatas"][0][i],
+                        'score': 1.0 - search_results["distances"][0][i]  # Convert distance to similarity
                    })
 
            return results
        except Exception as e:
            logger.error(f"Error querying collection: {e}")
            return []
 
    def get_statistics(self) -> Dict[str, Any]:
+        """Get statistics about the vector store"""
+        stats = {}
 
        try:
            # Get collection count
+            collection_info = self.collection.count()
+            stats['total_documents'] = collection_info
 
+            # Estimate unique files - with no chunking, each document is a file
+            stats['unique_files'] = collection_info
        except Exception as e:
            logger.error(f"Error getting statistics: {e}")
            stats['error'] = str(e)
@@ -461,635 +170,274 @@ class VectorStoreManager:
        return stats
 
class RAGSystem:
+    """Retrieval-Augmented Generation with multiple LLM providers"""
 
+    def __init__(self, vector_store: VectorStoreManager):
        self.vector_store = vector_store
        self.openai_client = None
        self.gemini_configured = False
 
+    def setup_openai(self, api_key: str):
+        """Set up OpenAI client with API key"""
        try:
            self.openai_client = OpenAI(api_key=api_key)
            return True
        except Exception as e:
            logger.error(f"Error initializing OpenAI client: {e}")
            return False
 
+    def setup_gemini(self, api_key: str):
+        """Set up Gemini with API key"""
        try:
            genai.configure(api_key=api_key)
            self.gemini_configured = True
            return True
        except Exception as e:
            logger.error(f"Error configuring Gemini: {e}")
            return False
 
    def format_context(self, documents: List[Dict]) -> str:
+        """Format retrieved documents into context for the LLM"""
        if not documents:
            return "No relevant documents found."
 
        context_parts = []
        for i, doc in enumerate(documents):
            metadata = doc['metadata']
            title = metadata.get('title', metadata.get('filename', 'Unknown document'))
 
+            # For readability, limit length of context document
            doc_text = doc['document']
+            if len(doc_text) > 10000:  # Limit long documents in context
+                doc_text = doc_text[:10000] + "... [Document truncated for context]"
 
+            context_parts.append(f"Document {i+1} - {title}:\n{doc_text}\n")
 
+        return "\n".join(context_parts)
 
    def generate_response_openai(self, query: str, context: str) -> str:
+        """Generate a response using OpenAI model with context"""
        if not self.openai_client:
+            return "Error: OpenAI API key not configured. Please enter an API key in the API key field."
 
        system_prompt = """
+        You are a helpful assistant that answers questions based on the context provided.
+        Use the information from the context to answer the user's question.
+        If the context doesn't contain the information needed, say so clearly.
+        Always cite the specific sections from the context that you used in your answer.
        """
 
        try:
            response = self.openai_client.chat.completions.create(
+                model="gpt-4o-mini",  # Use GPT-4o mini
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
                ],
+                temperature=0.3,  # Lower temperature for more factual responses
+                max_tokens=5000,
            )
+            return response.choices[0].message.content
        except Exception as e:
+            logger.error(f"Error generating response with OpenAI: {e}")
+            return f"Error generating response with OpenAI: {str(e)}"
 
    def generate_response_gemini(self, query: str, context: str) -> str:
+        """Generate a response using Gemini with context"""
        if not self.gemini_configured:
+            return "Error: Google AI API key not configured. Please enter an API key in the API key field."
 
        prompt = f"""
 
+        <prompt>
+        <system>
+            <name>Loss Dog</name>
+            <role>You are a highly intelligent AI specializing in labor market analysis, job trends, and skillset forecasting. You utilize a combination of structured data from sources like the Bureau of Labor Statistics (BLS) and the World Economic Forum (WEF), alongside advanced retrieval-augmented generation (RAG) techniques.</role>
+            <goal>Your mission is to provide insightful, data-driven, and comprehensive answers to users seeking career and job market intelligence. You must ensure clarity, depth, and practical relevance in all responses.</goal>
+            <personality>
+                <tone>Friendly, professional, and engaging</tone>
+                <depth>Detailed, nuanced, and well-explained</depth>
+                <clarity>Well-structured with headings, citations, and easy-to-follow breakdowns</clarity>
+            </personality>
+            <methodology>
+                <data_sources>
+                    <source>Bureau of Labor Statistics (BLS)</source>
+                    <source>World Economic Forum (WEF) reports</source>
+                    <source>Market research studies</source>
+                    <source>Industry whitepapers</source>
+                    <source>Company hiring trends</source>
+                </data_sources>
+                <reasoning_strategy>
+                    <if_data_available>
+                        <response>
+                            Use precise statistics, industry insights, and expert analyses from retrieved sources to craft an evidence-based answer.
+                        </response>
+                    </if_data_available>
+                    <if_data_unavailable>
+                        <response>
+                            Clearly state that the exact data is unavailable. However, provide a **comprehensive explanation** using logical deduction, adjacent industry trends, historical patterns, and economic principles.
+                        </response>
+                    </if_data_unavailable>
+                </reasoning_strategy>
+                <output_expectations>
+                    <length>100-500 words, depending on complexity and sources available</length>
+                    <structure>
+                        <section>Introduction (sets context and purpose)</section>
+                        <section>Data-backed analysis (citing retrieved sources)</section>
+                        <section>Logical deduction and reasoning (when necessary)</section>
+                        <section>Conclusion (summarizes insights and provides actionable takeaways)</section>
+                    </structure>
+                    <citation_style>Clearly cite data sources within the response (e.g., "According to BLS 2024 report...").</citation_style>
+                    <engagement>Encourage follow-up questions and deeper exploration where relevant.</engagement>
+                </output_expectations>
+            </methodology>
+        </system>
+        Context:
        {context}
 
+        Question: {query}
        """
+
        try:
+            model = genai.GenerativeModel('gemini-1.5-flash')
+            response = model.generate_content(prompt)
+            return response.text
        except Exception as e:
+            logger.error(f"Error generating response with Gemini: {e}")
+            return f"Error generating response with Gemini: {str(e)}"
 
+    def query_and_generate(self, query: str, n_results: int = 5, model: str = "openai") -> str:
+        """Retrieve relevant documents and generate a response using the specified model"""
        # Query vector store
        documents = self.vector_store.query(query, n_results=n_results)
 
        if not documents:
+            return "No relevant documents found to answer your question."
 
        # Format context
        context = self.format_context(documents)
 
        # Generate response with the appropriate model
        if model == "openai":
+            return self.generate_response_openai(query, context)
        elif model == "gemini":
+            return self.generate_response_gemini(query, context)
        else:
+            return f"Unknown model: {model}"
 
+# Main function to run the application
def main():
+    # Initialize the system with current directory as the Chroma location
+    config = Config(
+        local_dir=".",  # Look for Chroma files in current directory
+        collection_name="markdown_docs"
+    )
 
    try:
        # Initialize vector store manager with existing collection
        vector_store = VectorStoreManager(config)
 
        # Initialize RAG system without API keys initially
+        rag_system = RAGSystem(vector_store)
 
+        # Create the Gradio interface
+        with gr.Blocks(title="Document RAG System") as app:
+            gr.Markdown("# Document RAG System")
 
            with gr.Row():
                with gr.Column(scale=1):
                    # API Keys and model selection
+                    model_choice = gr.Radio(
+                        choices=["openai", "gemini"],
+                        value="openai",
+                        label="Choose LLM Provider",
+                        info="Select which model to use (GPT-4o mini or Gemini 1.5 Flash)"
+                    )
 
+                    api_key_input = gr.Textbox(
+                        label="API Key",
+                        placeholder="Enter your API key here...",
+                        type="password"
+                    )
 
+                    save_key_button = gr.Button("Save API Key", variant="primary")
+                    api_status = gr.Markdown("")
+
+                    # Search controls
+                    num_results = gr.Slider(
+                        minimum=1,
+                        maximum=10,
+                        value=10,
+                        step=1,
+                        label="Number of documents to retrieve"
+                    )
+
+                    # Database stats
+                    gr.Markdown("### Database Statistics")
+                    stats_display = gr.Textbox(
+                        label="",
+                        value=get_db_stats(vector_store),
+                        lines=2
+                    )
+                    refresh_button = gr.Button("Refresh Stats")
 
+                with gr.Column(scale=2):
+                    # Query and response
+                    query_input = gr.Textbox(
+                        label="Your Question",
+                        placeholder="Ask a question about your documents...",
+                        lines=2
+                    )
+
+                    query_button = gr.Button("Ask Question", variant="primary")
+
+                    gr.Markdown("### Response")
+                    response_output = gr.Markdown()
+
+                    gr.Markdown("### Document Search Results")
+                    search_output = gr.Markdown()
 
            # Function to update API key based on selected model
            def update_api_key(api_key, model):
                if model == "openai":
                    success = rag_system.setup_openai(api_key)
+                    model_name = "OpenAI GPT-4o mini"
                else:
                    success = rag_system.setup_gemini(api_key)
+                    model_name = "Google Gemini 1.5 Flash"
 
                if success:
+                    return f"✅ {model_name} API key configured successfully"
                else:
+                    return f"❌ Failed to configure {model_name} API key"
 
            # Query function that returns both response and search results
+            def query_and_search(query, n_results, model):
+                # Get search results first
+                results = vector_store.query(query, n_results=int(n_results))
 
+                # Format search results
+                formatted_results = []
+                for i, res in enumerate(results):
+                    metadata = res['metadata']
+                    title = metadata.get('title', metadata.get('filename', 'Unknown'))
+                    preview = res['document'][:500] + '...' if len(res['document']) > 500 else res['document']
+                    formatted_results.append(f"**Result {i+1}** (Similarity: {res['score']:.2f})\n"
+                                             f"**Source:** {title}\n"
+                                             f"**Preview:**\n{preview}\n\n---\n")
 
+                search_output_text = "\n".join(formatted_results) if formatted_results else "No results found."
 
+                # Generate response if we have results
+                response = "No documents found to answer your question."
+                if results:
+                    context = rag_system.format_context(results)
+                    if model == "openai":
+                        response = rag_system.generate_response_openai(query, context)
+                    else:
+                        response = rag_system.generate_response_gemini(query, context)
 
+                return response, search_output_text
 
            # Set up events
            save_key_button.click(
@@ -1100,8 +448,8 @@ def main():
 
            query_button.click(
                fn=query_and_search,
+                inputs=[query_input, num_results, model_choice],
+                outputs=[response_output, search_output]
            )
 
            refresh_button.click(
@@ -1109,84 +457,24 @@ def main():
                inputs=None,
                outputs=stats_display
            )
 
+        # Launch the interface
+        app.launch()
+
    except Exception as e:
+        logger.error(f"Error initializing application: {e}")
+        print(f"Error: {e}")
        sys.exit(1)
 
+# Helper function to get database stats
+def get_db_stats(vector_store):
+    """Function to get vector store statistics"""
+    try:
+        stats = vector_store.get_statistics()
+        return f"Total documents: {stats.get('total_documents', 0)}"
+    except Exception as e:
+        logger.error(f"Error getting statistics: {e}")
+        return "Error getting database statistics"
+
if __name__ == "__main__":
    main()