jzou19950715 committed
Commit c29c31c · verified · 1 parent: 34c8f16

Delete app.py

Files changed (1): app.py +0 -467
app.py DELETED
@@ -1,467 +0,0 @@
- import os
- import sys
- import logging
- from pathlib import Path
- import json
- import hashlib
- from datetime import datetime
- import threading
- import queue
- from typing import List, Dict, Any, Tuple, Optional
-
- # Configure logging
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- logger = logging.getLogger(__name__)
-
- # Importing necessary libraries
- import torch
- import numpy as np
- from sentence_transformers import SentenceTransformer
- import chromadb
- from chromadb.utils import embedding_functions
- import gradio as gr
- from openai import OpenAI
- import google.generativeai as genai
-
- # Configuration class
- class Config:
-     """Configuration for vector store and RAG"""
-     def __init__(self,
-                  local_dir: str = "./chroma_data",
-                  batch_size: int = 20,
-                  max_workers: int = 4,
-                  embedding_model: str = "all-MiniLM-L6-v2",
-                  collection_name: str = "markdown_docs"):
-         self.local_dir = local_dir
-         self.batch_size = batch_size
-         self.max_workers = max_workers
-         self.checkpoint_file = Path(local_dir) / "checkpoint.json"
-         self.embedding_model = embedding_model
-         self.collection_name = collection_name
-
-         # Create local directory for checkpoints and Chroma
-         Path(local_dir).mkdir(parents=True, exist_ok=True)
-
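For reference, a minimal instantiation of the Config class above (a sketch using the signature's own defaults, so the values are illustrative only):

    config = Config(local_dir="./chroma_data", collection_name="markdown_docs")
    print(config.checkpoint_file)  # chroma_data/checkpoint.json on POSIX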
- # Embedding engine
- class EmbeddingEngine:
-     """Handle embeddings with a lightweight model"""
-
-     def __init__(self, model_name="all-MiniLM-L6-v2"):
-         # Use GPU if available
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-         logger.info(f"Using device: {self.device}")
-
-         # Try multiple model options in order of preference
-         model_options = [
-             model_name,
-             "all-MiniLM-L6-v2",
-             "paraphrase-MiniLM-L3-v2",
-             "all-mpnet-base-v2"  # Higher quality but larger model
-         ]
-
-         self.model = None
-
-         # Try each model in order until one works
-         for model_option in model_options:
-             try:
-                 logger.info(f"Attempting to load model: {model_option}")
-                 self.model = SentenceTransformer(model_option)
-
-                 # Move model to device
-                 self.model.to(self.device)
-
-                 logger.info(f"Successfully loaded model: {model_option}")
-                 self.model_name = model_option
-                 self.vector_size = self.model.get_sentence_embedding_dimension()
-                 break
-
-             except Exception as e:
-                 logger.warning(f"Failed to load model {model_option}: {str(e)}")
-
-         if self.model is None:
-             logger.error("Failed to load any embedding model. Exiting.")
-             sys.exit(1)
-
-     def encode(self, text, batch_size=32):
-         """Get embedding for a text or list of texts"""
-         # Handle single text
-         if isinstance(text, str):
-             texts = [text]
-         else:
-             texts = text
-
-         # Truncate texts if necessary to avoid tokenization issues
-         truncated_texts = [t[:50000] if len(t) > 50000 else t for t in texts]
-
-         # Generate embeddings
-         try:
-             embeddings = self.model.encode(truncated_texts, batch_size=batch_size,
-                                            show_progress_bar=False, convert_to_numpy=True)
-             return embeddings
-         except Exception as e:
-             logger.error(f"Error generating embeddings: {e}")
-             # Return zero embeddings as fallback
-             return np.zeros((len(truncated_texts), self.vector_size))
-
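A minimal usage sketch of the EmbeddingEngine above (model weights download on first use; the 384-dimensional shape assumes the default all-MiniLM-L6-v2 model loads):

    engine = EmbeddingEngine()                    # falls back through model_options
    vecs = engine.encode(["hello world", "markdown docs"])
    print(vecs.shape)                             # (2, 384) with all-MiniLM-L6-v2
    single = engine.encode("one string")          # a bare str is wrapped into a list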
- class VectorStoreManager:
-     """Manage Chroma vector store operations - upload, query, etc."""
-
-     def __init__(self, config: Config):
-         self.config = config
-
-         # Initialize Chroma client (local persistence)
-         logger.info(f"Initializing Chroma at {config.local_dir}")
-         self.client = chromadb.PersistentClient(path=config.local_dir)
-
-         # Get or create collection
-         try:
-             # Initialize embedding model
-             logger.info("Loading embedding model...")
-             self.embedding_engine = EmbeddingEngine(config.embedding_model)
-             logger.info(f"Using model: {self.embedding_engine.model_name}")
-
-             # Create embedding function
-             sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
-                 model_name=self.embedding_engine.model_name
-             )
-
-             # Try to get existing collection
-             try:
-                 self.collection = self.client.get_collection(
-                     name=config.collection_name,
-                     embedding_function=sentence_transformer_ef
-                 )
-                 logger.info(f"Using existing collection: {config.collection_name}")
-             except Exception:
-                 # Create new collection if it doesn't exist
-                 self.collection = self.client.create_collection(
-                     name=config.collection_name,
-                     embedding_function=sentence_transformer_ef,
-                     metadata={"hnsw:space": "cosine"}
-                 )
-                 logger.info(f"Created new collection: {config.collection_name}")
-
-         except Exception as e:
-             logger.error(f"Error initializing Chroma collection: {e}")
-             sys.exit(1)
-
-     def query(self, query_text: str, n_results: int = 5) -> List[Dict]:
-         """Query the vector store with a text query"""
-         try:
-             # Query the collection
-             search_results = self.collection.query(
-                 query_texts=[query_text],
-                 n_results=n_results,
-                 include=["documents", "metadatas", "distances"]
-             )
-
-             # Format results
-             results = []
-             if search_results["documents"] and len(search_results["documents"][0]) > 0:
-                 for i in range(len(search_results["documents"][0])):
-                     results.append({
-                         'document': search_results["documents"][0][i],
-                         'metadata': search_results["metadatas"][0][i],
-                         'score': 1.0 - search_results["distances"][0][i]  # Convert distance to similarity
-                     })
-
-             return results
-         except Exception as e:
-             logger.error(f"Error querying collection: {e}")
-             return []
-
-     def get_statistics(self) -> Dict[str, Any]:
-         """Get statistics about the vector store"""
-         stats = {}
-
-         try:
-             # Get collection count
-             collection_info = self.collection.count()
-             stats['total_documents'] = collection_info
-
-             # Estimate unique files - with no chunking, each document is a file
-             stats['unique_files'] = collection_info
-         except Exception as e:
-             logger.error(f"Error getting statistics: {e}")
-             stats['error'] = str(e)
-
-         return stats
-
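A query sketch against the manager above (assumes the Chroma collection under ./chroma_data is already populated; the score is 1 − cosine distance, as computed in query()):

    store = VectorStoreManager(Config())
    for hit in store.query("how is the index built?", n_results=3):
        print(f"{hit['score']:.2f}  {hit['metadata'].get('filename', '?')}")
    print(store.get_statistics())  # {'total_documents': ..., 'unique_files': ...}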
- class RAGSystem:
-     """Retrieval-Augmented Generation with multiple LLM providers"""
-
-     def __init__(self, vector_store: VectorStoreManager):
-         self.vector_store = vector_store
-         self.openai_client = None
-         self.gemini_configured = False
-
-     def setup_openai(self, api_key: str):
-         """Set up OpenAI client with API key"""
-         try:
-             self.openai_client = OpenAI(api_key=api_key)
-             return True
-         except Exception as e:
-             logger.error(f"Error initializing OpenAI client: {e}")
-             return False
-
-     def setup_gemini(self, api_key: str):
-         """Set up Gemini with API key"""
-         try:
-             genai.configure(api_key=api_key)
-             self.gemini_configured = True
-             return True
-         except Exception as e:
-             logger.error(f"Error configuring Gemini: {e}")
-             return False
-
-     def format_context(self, documents: List[Dict]) -> str:
-         """Format retrieved documents into context for the LLM"""
-         if not documents:
-             return "No relevant documents found."
-
-         context_parts = []
-         for i, doc in enumerate(documents):
-             metadata = doc['metadata']
-             title = metadata.get('title', metadata.get('filename', 'Unknown document'))
-
-             # For readability, limit length of context document
-             doc_text = doc['document']
-             if len(doc_text) > 10000:  # Limit long documents in context
-                 doc_text = doc_text[:10000] + "... [Document truncated for context]"
-
-             context_parts.append(f"Document {i+1} - {title}:\n{doc_text}\n")
-
-         return "\n".join(context_parts)
-
-     def generate_response_openai(self, query: str, context: str) -> str:
-         """Generate a response using OpenAI model with context"""
-         if not self.openai_client:
-             return "Error: OpenAI API key not configured. Please enter an API key in the settings tab."
-
-         system_prompt = """
-         You are a helpful assistant that answers questions based on the context provided.
-         Use the information from the context to answer the user's question.
-         If the context doesn't contain the information needed, say so clearly.
-         Always cite the specific sections from the context that you used in your answer.
-         """
-
-         try:
-             response = self.openai_client.chat.completions.create(
-                 model="gpt-4o-mini",  # Use GPT-4o mini
-                 messages=[
-                     {"role": "system", "content": system_prompt},
-                     {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
-                 ],
-                 temperature=0.3,  # Lower temperature for more factual responses
-                 max_tokens=1000,
-             )
-             return response.choices[0].message.content
-         except Exception as e:
-             logger.error(f"Error generating response with OpenAI: {e}")
-             return f"Error generating response with OpenAI: {str(e)}"
-
-     def generate_response_gemini(self, query: str, context: str) -> str:
-         """Generate a response using Gemini with context"""
-         if not self.gemini_configured:
-             return "Error: Google AI API key not configured. Please enter an API key in the settings tab."
-
-         prompt = f"""
-         You are a helpful assistant that answers questions based on the context provided.
-         Use the information from the context to answer the user's question.
-         If the context doesn't contain the information needed, say so clearly.
-         Always cite the specific sections from the context that you used in your answer.
-
-         Context:
-         {context}
-
-         Question: {query}
-         """
-
-         try:
-             model = genai.GenerativeModel('gemini-1.5-flash')
-             response = model.generate_content(prompt)
-             return response.text
-         except Exception as e:
-             logger.error(f"Error generating response with Gemini: {e}")
-             return f"Error generating response with Gemini: {str(e)}"
-
-     def query_and_generate(self, query: str, n_results: int = 5, model: str = "openai") -> str:
-         """Retrieve relevant documents and generate a response using the specified model"""
-         # Query vector store
-         documents = self.vector_store.query(query, n_results=n_results)
-
-         if not documents:
-             return "No relevant documents found to answer your question."
-
-         # Format context
-         context = self.format_context(documents)
-
-         # Generate response with the appropriate model
-         if model == "openai":
-             return self.generate_response_openai(query, context)
-         elif model == "gemini":
-             return self.generate_response_gemini(query, context)
-         else:
-             return f"Unknown model: {model}"
-
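An end-to-end sketch of the RAGSystem above (the environment variable name is illustrative, not part of the original code; at least one provider key must be configured before generating):

    rag = RAGSystem(VectorStoreManager(Config()))
    rag.setup_openai(os.environ["OPENAI_API_KEY"])  # hypothetical env var holding the key
    answer = rag.query_and_generate("What do the docs say about setup?", n_results=3, model="openai")
    print(answer)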
- def rag_chat(query, n_results, model_choice, rag_system):
-     """Function to handle RAG chat queries"""
-     return rag_system.query_and_generate(query, n_results=int(n_results), model=model_choice)
-
- def simple_query(query, n_results, vector_store):
-     """Function to handle simple vector store queries"""
-     results = vector_store.query(query, n_results=int(n_results))
-
-     # Format results for display
-     formatted = []
-     for i, res in enumerate(results):
-         metadata = res['metadata']
-         title = metadata.get('title', metadata.get('filename', 'Unknown'))
-         # Limit preview text for display
-         preview = res['document'][:800] + '...' if len(res['document']) > 800 else res['document']
-         formatted.append(f"**Result {i+1}** (Similarity: {res['score']:.2f})\n\n"
-                          f"**Source:** {title}\n\n"
-                          f"**Content:**\n{preview}\n\n"
-                          f"---\n")
-
-     return "\n".join(formatted) if formatted else "No results found."
-
- def get_db_stats(vector_store):
-     """Function to get vector store statistics"""
-     stats = vector_store.get_statistics()
-     return (f"Total documents: {stats.get('total_documents', 0)}\n"
-             f"Unique files: {stats.get('unique_files', 0)}")
-
- def update_api_keys(openai_key, gemini_key, rag_system):
-     """Update API keys for the RAG system"""
-     success_msg = []
-
-     if openai_key:
-         if rag_system.setup_openai(openai_key):
-             success_msg.append("✅ OpenAI API key configured successfully")
-         else:
-             success_msg.append("❌ Failed to configure OpenAI API key")
-
-     if gemini_key:
-         if rag_system.setup_gemini(gemini_key):
-             success_msg.append("✅ Google AI API key configured successfully")
-         else:
-             success_msg.append("❌ Failed to configure Google AI API key")
-
-     if not success_msg:
-         return "Please enter at least one API key"
-
-     return "\n".join(success_msg)
-
- # Main function to run the application
- def main():
-     # Set up paths for existing Chroma database
-     chroma_dir = Path("./chroma_data")
-
-     # Initialize the system
-     config = Config(
-         local_dir=str(chroma_dir),
-         collection_name="markdown_docs"
-     )
-
-     # Initialize vector store manager with existing collection
-     vector_store = VectorStoreManager(config)
-
-     # Initialize RAG system without API keys initially
-     rag_system = RAGSystem(vector_store)
-
-     # Define Gradio app
-     def rag_chat_wrapper(query, n_results, model_choice):
-         return rag_chat(query, n_results, model_choice, rag_system)
-
-     def simple_query_wrapper(query, n_results):
-         return simple_query(query, n_results, vector_store)
-
-     def update_api_keys_wrapper(openai_key, gemini_key):
-         return update_api_keys(openai_key, gemini_key, rag_system)
-
-     # Create the Gradio interface
-     with gr.Blocks(title="Markdown RAG System") as app:
-         gr.Markdown("# RAG System with Multiple LLM Providers")
-
-         with gr.Tab("Chat with Documents"):
-             with gr.Row():
-                 with gr.Column(scale=3):
-                     query_input = gr.Textbox(label="Question", placeholder="Ask a question about your documents...")
-                     num_results = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of documents to retrieve")
-                     model_choice = gr.Radio(
-                         choices=["openai", "gemini"],
-                         value="openai",
-                         label="Choose LLM Provider",
-                         info="Select which model to use for generating answers"
-                     )
-                     query_button = gr.Button("Ask", variant="primary")
-
-                 with gr.Column(scale=7):
-                     response_output = gr.Markdown(label="Response")
-
-             # Database stats
-             stats_display = gr.Textbox(label="Database Statistics", value=get_db_stats(vector_store))
-             refresh_button = gr.Button("Refresh Statistics")
-
-         with gr.Tab("Document Search"):
-             search_input = gr.Textbox(label="Search Query", placeholder="Search your documents...")
-             search_num = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of results")
-             search_button = gr.Button("Search", variant="primary")
-             search_output = gr.Markdown(label="Search Results")
-
-         with gr.Tab("Settings"):
-             gr.Markdown("""
-             ## API Keys Configuration
-
-             This application can use either OpenAI's GPT-4o-mini or Google's Gemini 1.5 Flash for generating responses.
-             You need to provide at least one API key to use the chat functionality.
-             """)
-
-             openai_key_input = gr.Textbox(
-                 label="OpenAI API Key",
-                 placeholder="Enter your OpenAI API key here...",
-                 type="password"
-             )
-
-             gemini_key_input = gr.Textbox(
-                 label="Google AI API Key",
-                 placeholder="Enter your Google AI API key here...",
-                 type="password"
-             )
-
-             save_keys_button = gr.Button("Save API Keys", variant="primary")
-             api_status = gr.Markdown("")
-
-         # Set up events
-         query_button.click(
-             fn=rag_chat_wrapper,
-             inputs=[query_input, num_results, model_choice],
-             outputs=response_output
-         )
-
-         refresh_button.click(
-             fn=lambda: get_db_stats(vector_store),
-             inputs=None,
-             outputs=stats_display
-         )
-
-         search_button.click(
-             fn=simple_query_wrapper,
-             inputs=[search_input, search_num],
-             outputs=search_output
-         )
-
-         save_keys_button.click(
-             fn=update_api_keys_wrapper,
-             inputs=[openai_key_input, gemini_key_input],
-             outputs=api_status
-         )
-
-     # Launch the interface
-     app.launch()
-
- if __name__ == "__main__":
-     main()
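For the record, the deleted app was self-contained; a sketch of how it was launched (assuming its dependencies, per the imports above, are installed: gradio, chromadb, sentence-transformers, torch, openai, google-generativeai):

    python app.py   # opens the ./chroma_data collection and starts the Gradio UI from main()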