jzou19950715 committed on
Commit 34c8f16 · verified · 1 Parent(s): f447aec

Update app.py

Files changed (1):
  1. app.py +424 -341

app.py CHANGED
@@ -1,384 +1,467 @@
- import base64
- import io
  import os
- import gradio as gr
- import numpy as np
- import matplotlib.pyplot as plt
- from typing import Dict, List, Tuple, Any
- import json
- from litellm import completion
  import logging

  # Configure logging
- logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- CONVERSATION_PROMPT = """
- You are an engaging and insightful career advisor. Have natural conversations to learn about their career.
- Use an enthusiastic, supportive tone and show genuine interest in their journey.
- CONVERSATION STYLE:
- - Be warm and engaging
- - Show genuine interest in their experiences
- - Ask specific follow-up questions about details they mention
- - Keep the conversation flowing naturally
- - Use conversational language, not formal queries
- - Express enthusiasm about their achievements
- - Dig deeper into interesting points they make
- INFORMATION TO GATHER (through natural conversation):
- 1. Current Role Details:
- - Job title and responsibilities
- - Company size and industry
- - Team size and structure
- - Project scope and impact
- - Current compensation (base, bonus, equity)
- 2. Experience Deep-Dive:
- - Career progression story
- - Leadership experience
- - Major projects and achievements
- - Technical skills and expertise
- - Industry knowledge
- 3. Educational Background:
- - Degrees and certifications
- - Specialized training
- - Continuous learning
- 4. Work Environment:
- - Location and market
- - Remote/hybrid setup
- - Growth opportunities
- - Company culture
- CONVERSATION FLOW:
- 1. Start with: "Hi! I'd love to hear about your career journey. What kind of work are you doing currently?"
- 2. After each response:
- - Pick up on specific details they mentioned
- - Ask engaging follow-up questions
- - Show genuine interest in their experiences
- - Build on previous information shared
- 3. If they mention something interesting, probe deeper:
- - "That project sounds fascinating! What were some unique challenges you faced?"
- - "Leading a team must be exciting! How did you approach building and motivating your team?"
- - "Interesting technology stack! What made you choose those specific tools?"
- 4. When compensation is mentioned:
- - Be tactful and professional
- - Acknowledge their goals
- - Ask about their desired growth
- 5. Once you have enough information, say:
- "I've got a good understanding of your career profile now! Would you like to see your personalized salary growth projection? Just click 'Generate Analysis' and I'll create a detailed forecast based on our discussion."
- IMPORTANT:
- - Keep conversation flowing naturally
- - Don't rush to collect information
- - Show genuine interest in their story
- - Ask insightful follow-up questions
- - Build rapport through discussion
- """

- EXTRACTION_PROMPT = """
- Analyze the conversation and extract numerical scores from 0 to 1 based on salary growth potential.
- SCORING GUIDELINES:
- 1. Industry Score (0-1):
- Industry Type & Growth:
- - 1.0: Cutting-edge AI/ML companies
- - 0.9: High-growth tech (cloud, cybersecurity)
- - 0.8: Established tech companies
- - 0.7: Finance/Healthcare tech
- - 0.6: Traditional tech sectors
- - 0.5: Non-tech industries
- Company Position:
- +0.1: Market leader
- +0.1: High growth trajectory
- -0.1: Declining market position
- 2. Experience Score (0-1):
- Years and Level:
- - 1.0: 15+ years with executive experience
- - 0.9: 10-15 years, senior leadership
- - 0.8: 7-10 years, team leadership
- - 0.7: 4-6 years, senior individual
- - 0.6: 2-3 years, mid-level
- - 0.5: 0-1 years, entry-level
- Quality Indicators:
- +0.1: Rapid promotions
- +0.1: Significant achievements
- +0.1: High-impact projects
- 3. Education Score (0-1):
- Formal Education:
- - 1.0: PhD from top institution
- - 0.9: Masters from top institution
- - 0.8: Bachelors from top institution
- - 0.7: Advanced degree
- - 0.6: Bachelors degree
- - 0.5: Other education
- Additional Factors:
- +0.1: Relevant certifications
- +0.1: Continuous learning
- +0.1: Field-specific expertise
- 4. Skills Score (0-1):
- Technical Depth:
- - 1.0: Industry-leading expertise
- - 0.9: Advanced technical leadership
- - 0.8: Strong technical + leadership
- - 0.7: Solid technical skills
- - 0.6: Growing technical skills
- - 0.5: Basic skill set
- Breadth and Application:
- +0.1: Multiple in-demand skills
- +0.1: Proven implementation
- +0.1: Cross-functional expertise
- 5. Location Score (0-1):
- Market Strength:
- - 1.0: Major tech hubs (SF, NYC)
- - 0.9: Growing tech hubs
- - 0.8: Major cities
- - 0.7: Regional tech centers
- - 0.6: Smaller markets
- - 0.5: Remote locations
- Flexibility:
- +0.1: Remote work option
- +0.1: High-growth market
- +0.1: Strategic location
- Return a JSON object with exactly these fields:
- {
- "industry_score": float,
- "experience_score": float,
- "education_score": float,
- "skills_score": float,
- "location_score": float,
- "current_salary": float
- }
- Base scores on available information. Make reasonable assumptions for missing data based on context clues.
- """

- class CodeEnvironment:
-     """Environment for executing visualization code"""
-
-     def __init__(self):
-         self.globals = {'np': np, 'plt': plt}
-         self.locals = {}
-
-     def execute(self, code: str, paths: np.ndarray = None) -> Dict[str, Any]:
-         """Execute visualization code and return results"""
-         if paths is not None:
-             self.globals['paths'] = paths
-
-         result = {'figures': [], 'error': None}
          try:
-             exec(code, self.globals, self.locals)
-             buf = io.BytesIO()
-             plt.gcf().savefig(buf, format='png', dpi=300, bbox_inches='tight')
-             buf.seek(0)
-             result['figures'].append(buf.getvalue())
-             plt.close('all')
          except Exception as e:
-             result['error'] = f"Visualization failed: {str(e)}"
-             plt.close('all')
-         return result

- class SalarySimulator:
-     """Monte Carlo simulation for salary projections"""
-
-     def __init__(self, years: int = 5, num_paths: int = 1000):
-         self.years = years
-         self.num_paths = num_paths
-
-     def run_simulation(self, profile: Dict[str, float]) -> np.ndarray:
-         """Generate salary growth paths using Monte Carlo simulation"""
-         paths = np.zeros((self.num_paths, self.years + 1))
-         paths[:, 0] = profile['current_salary']
-
-         base_growth = 0.02 + (profile['industry_score'] * 0.04)
-         skill_premium = 0.01 + (profile['skills_score'] * 0.02)
-         exp_premium = 0.01 + (profile['experience_score'] * 0.02)
-         edu_premium = 0.005 + (profile['education_score'] * 0.015)
-         location_premium = 0.01 + (profile['location_score'] * 0.02)
-
-         volatility = 0.05 + (profile['industry_score'] * 0.05)
-         disruption_chance = 0.1
-         disruption_impact = 0.2
-
-         for path in range(self.num_paths):
-             salary = paths[path, 0]
-             for year in range(1, self.years + 1):
-                 growth = base_growth + skill_premium + exp_premium + edu_premium + location_premium
-                 growth += np.random.normal(0, volatility)
-                 if np.random.random() < disruption_chance:
-                     impact = disruption_impact * np.random.random()
-                     growth += impact if np.random.random() < 0.7 else -impact
-                 growth = max(min(growth, 0.25), -0.1)
-                 salary *= (1 + growth)
-                 paths[path, year] = salary
-         return paths

- class CareerAdvisor:
-     """Main career advisor system"""
-
-     def __init__(self, years: int = 5, num_paths: int = 1000):
-         self.chat_history = []
-         self.simulator = SalarySimulator(years, num_paths)
-         self.code_env = CodeEnvironment()
-
-     def reset(self):
-         """Reset conversation state"""
-         self.chat_history = []
-
-     def chat(self, message: str, api_key: str) -> str:
-         """Process user message and generate response"""
-         if not api_key.strip().startswith("sk-"):
-             return "Please enter a valid OpenAI API key starting with 'sk-'."
          try:
-             messages = [{"role": "system", "content": CONVERSATION_PROMPT}] + \
-                 self.chat_history + [{"role": "user", "content": message}]
-             response = completion(model="gpt-4o-mini", messages=messages, api_key=api_key)
-             self.chat_history.extend([
-                 {"role": "user", "content": message},
-                 {"role": "assistant", "content": response.choices[0].message.content}
-             ])
-             return response.choices[0].message.content
          except Exception as e:
-             return f"Chat error: {str(e)}. Please check your API key or try again."
-
-     def generate_analysis(self, api_key: str) -> Tuple[str, bytes]:
-         """Generate complete analysis with visualization"""
-         if not self.chat_history:
-             return "Please chat about your career first to generate an analysis.", None
          try:
-             profile = self._extract_profile(api_key)
-             paths = self.simulator.run_simulation(profile)
-
-             viz_code = """
- import matplotlib.pyplot as plt
- import numpy as np
- plt.style.use('dark_background')
- fig = plt.figure(figsize=(12, 16))
- ax1 = plt.subplot2grid((2, 1), (0, 0))
- for path in paths[::20]:
-     ax1.plot(range(paths.shape[1]), path, color='#4a90e2', alpha=0.1, linewidth=1)
- percentiles = [10, 25, 50, 75, 90]
- colors = ['#ff9999', '#ffcc99', '#ffffff', '#ffcc99', '#ff9999']
- labels = ['10th', '25th', 'Median', '75th', '90th']
- for p, color, label in zip(percentiles, colors, labels):
-     line = np.percentile(paths, p, axis=0)
-     ax1.plot(range(paths.shape[1]), line, color=color, linewidth=2, label=f'{label} percentile')
- ax1.set_title('Salary Growth Projections\n', fontsize=16, pad=20)
- ax1.set_xlabel('Years', fontsize=12)
- ax1.set_ylabel('Salary ($)', fontsize=12)
- ax1.grid(True, alpha=0.2)
- ax1.legend(fontsize=10)
- ax1.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'${x:,.0f}'))
- ax1.set_xticks(range(paths.shape[1]))
- ax1.set_xticklabels(['Current'] + [f'Year {i+1}' for i in range(paths.shape[1]-1)])
- ax2 = plt.subplot2grid((2, 1), (1, 0))
- final_salaries = paths[:, -1]
- ax2.hist(final_salaries, bins=50, color='#4a90e2', alpha=0.7)
- ax2.set_title('Final Salary Distribution\n', fontsize=16, pad=20)
- ax2.set_xlabel('Salary ($)', fontsize=12)
- ax2.set_ylabel('Frequency', fontsize=12)
- ax2.grid(True, alpha=0.2)
- ax2.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'${x:,.0f}'))
- for p, color in zip(percentiles, colors):
-     value = np.percentile(final_salaries, p)
-     ax2.axvline(x=value, color=color, linestyle='--', alpha=0.5)
- plt.tight_layout(pad=4)
- """
-             viz_result = self.code_env.execute(viz_code, paths)
-             if viz_result['error']:
-                 return f"Analysis generated, but {viz_result['error']}", None
-             summary = self._generate_summary(profile, paths)
-             return summary, viz_result['figures'][0]
          except Exception as e:
-             return f"Analysis error: {str(e)}. Please ensure sufficient chat history.", None
-
-     def _extract_profile(self, api_key: str) -> Dict[str, float]:
-         """Extract profile scores from conversation"""
-         conversation = "\n".join([f"{msg['role']}: {msg['content']}" for msg in self.chat_history])
-         messages = [
-             {"role": "system", "content": EXTRACTION_PROMPT},
-             {"role": "user", "content": f"Extract profile from:\n{conversation}"}
-         ]
-         response = completion(
-             model="gpt-4o-mini",
-             messages=messages,
-             api_key=api_key,
-             response_format={"type": "json_object"}
-         )
-         return json.loads(response.choices[0].message.content)

-     def _generate_summary(self, profile: Dict[str, float], paths: np.ndarray) -> str:
-         """Generate analysis summary"""
-         final_salaries = paths[:, -1]
-         initial_salary = paths[0, 0]
-         cagr = (np.median(final_salaries) / initial_salary) ** (1/self.simulator.years) - 1
-
-         return f"""
- Career Profile Analysis
- ======================
- Current Situation:
- • Salary: ${profile['current_salary']:,.2f}
- • Industry Position: {profile['industry_score']:.2f}/1.0
- • Experience Level: {profile['experience_score']:.2f}/1.0
- • Education Rating: {profile['education_score']:.2f}/1.0
- • Skills Assessment: {profile['skills_score']:.2f}/1.0
- • Location Impact: {profile['location_score']:.2f}/1.0
- {self.simulator.years}-Year Projection:
- • Conservative (25th percentile): ${np.percentile(final_salaries, 25):,.2f}
- • Most Likely (Median): ${np.percentile(final_salaries, 50):,.2f}
- • Optimistic (75th percentile): ${np.percentile(final_salaries, 75):,.2f}
- • Expected Annual Growth: {cagr*100:.1f}%
- Key Insights:
- • Your profile suggests {cagr*100:.1f}% annual growth potential
- • {profile['industry_score']:.2f} industry score indicates {'strong' if profile['industry_score'] > 0.7 else 'moderate' if profile['industry_score'] > 0.5 else 'challenging'} growth environment
- • Skills rating of {profile['skills_score']:.2f} suggests {'excellent' if profile['skills_score'] > 0.7 else 'good' if profile['skills_score'] > 0.5 else 'potential for'} career advancement
- • Location score {profile['location_score']:.2f} {'enhances' if profile['location_score'] > 0.7 else 'supports' if profile['location_score'] > 0.5 else 'may limit'} opportunities
- Based on {self.simulator.num_paths:,} simulated career paths
  """

- def create_interface():
-     """Create Gradio interface with configurable simulation parameters"""
-     advisor = None
-
-     def init_advisor(years: int, num_paths: int):
-         nonlocal advisor
-         advisor = CareerAdvisor(years=max(1, years), num_paths=max(100, num_paths))
-         advisor.reset()
-
-     def user_message(message: str, history: List, api_key: str) -> Tuple[str, List]:
-         if not message.strip():
-             return "", history
-         if not advisor:
-             return "Please set simulation parameters first.", history
-         response = advisor.chat(message, api_key)
-         return "", history + [(message, response)]
-
-     def generate_analysis(api_key: str, history: List) -> Tuple[str, gr.Image]:
-         if not advisor or not history:
-             return "Please chat about your career and set parameters first.", None
-         summary, figure_data = advisor.generate_analysis(api_key)
-         return summary, figure_data if figure_data else None
-
-     with gr.Blocks(title="Monte Carlo Salary Prediction", theme=gr.themes.Soft()) as demo:
-         gr.Markdown("# 💰 Monte Carlo Simulation of Salary Prediction\nChat about your career to see your growth potential!")
-
-         with gr.Row():
-             api_key = gr.Textbox(label="OpenAI API Key", placeholder="Enter your OpenAI API key", type="password")
-             years = gr.Number(label="Simulation Years", value=5, minimum=1, step=1)
-             num_paths = gr.Number(label="Number of Paths", value=1000, minimum=100, step=100)
-
-         chatbot = gr.Chatbot(value=[], height=400, show_copy_button=True)
-
-         with gr.Row():
-             msg = gr.Textbox(label="Your message", placeholder="Tell me about your career...", lines=2)
-             send = gr.Button("Send", variant="primary", scale=0)
-
-         analyze = gr.Button("Generate Analysis", variant="secondary", size="lg")
-
-         with gr.Row():
-             analysis = gr.Textbox(label="Analysis", lines=10, show_copy_button=True)
-             plot = gr.Image(label="Projections", show_download_button=True, height=600)
-
-         demo.load(lambda y, n: init_advisor(y, n), inputs=[years, num_paths], outputs=None)
-         msg.submit(user_message, inputs=[msg, chatbot, api_key], outputs=[msg, chatbot])
-         send.click(user_message, inputs=[msg, chatbot, api_key], outputs=[msg, chatbot])
-         analyze.click(generate_analysis, inputs=[api_key, chatbot], outputs=[analysis, plot])
-
-     return demo
-
  if __name__ == "__main__":
-     demo = create_interface()
-     demo.launch()

  import os
+ import sys
  import logging
+ from pathlib import Path
+ import json
+ import hashlib
+ from datetime import datetime
+ import threading
+ import queue
+ from typing import List, Dict, Any, Tuple, Optional

  # Configure logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  logger = logging.getLogger(__name__)

+ # Importing necessary libraries
+ import torch
+ import numpy as np
+ from sentence_transformers import SentenceTransformer
+ import chromadb
+ from chromadb.utils import embedding_functions
+ import gradio as gr
+ from openai import OpenAI
+ import google.generativeai as genai

+ # Configuration class
+ class Config:
+     """Configuration for vector store and RAG"""
+     def __init__(self,
+                  local_dir: str = "./chroma_data",
+                  batch_size: int = 20,
+                  max_workers: int = 4,
+                  embedding_model: str = "all-MiniLM-L6-v2",
+                  collection_name: str = "markdown_docs"):
+         self.local_dir = local_dir
+         self.batch_size = batch_size
+         self.max_workers = max_workers
+         self.checkpoint_file = Path(local_dir) / "checkpoint.json"
+         self.embedding_model = embedding_model
+         self.collection_name = collection_name
+
+         # Create local directory for checkpoints and Chroma
+         Path(local_dir).mkdir(parents=True, exist_ok=True)

+ # Embedding engine
+ class EmbeddingEngine:
+     """Handle embeddings with a lightweight model"""
+
+     def __init__(self, model_name="all-MiniLM-L6-v2"):
+         # Use GPU if available
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.info(f"Using device: {self.device}")
+
+         # Try multiple model options in order of preference
+         model_options = [
+             model_name,
+             "all-MiniLM-L6-v2",
+             "paraphrase-MiniLM-L3-v2",
+             "all-mpnet-base-v2"  # Higher quality but larger model
+         ]
+
+         self.model = None
+
+         # Try each model in order until one works
+         for model_option in model_options:
+             try:
+                 logger.info(f"Attempting to load model: {model_option}")
+                 self.model = SentenceTransformer(model_option)
+
+                 # Move model to device
+                 self.model.to(self.device)
+
+                 logger.info(f"Successfully loaded model: {model_option}")
+                 self.model_name = model_option
+                 self.vector_size = self.model.get_sentence_embedding_dimension()
+                 break
+
+             except Exception as e:
+                 logger.warning(f"Failed to load model {model_option}: {str(e)}")
+
+         if self.model is None:
+             logger.error("Failed to load any embedding model. Exiting.")
+             sys.exit(1)
+
+     def encode(self, text, batch_size=32):
+         """Get embedding for a text or list of texts"""
+         # Handle single text
+         if isinstance(text, str):
+             texts = [text]
+         else:
+             texts = text
+
+         # Truncate texts if necessary to avoid tokenization issues
+         truncated_texts = [t[:50000] if len(t) > 50000 else t for t in texts]
+
+         # Generate embeddings
          try:
+             embeddings = self.model.encode(truncated_texts, batch_size=batch_size,
+                                            show_progress_bar=False, convert_to_numpy=True)
+             return embeddings
          except Exception as e:
+             logger.error(f"Error generating embeddings: {e}")
+             # Return zero embeddings as fallback
+             return np.zeros((len(truncated_texts), self.vector_size))

+ class VectorStoreManager:
+     """Manage Chroma vector store operations - upload, query, etc."""
+
+     def __init__(self, config: Config):
+         self.config = config
+
+         # Initialize Chroma client (local persistence)
+         logger.info(f"Initializing Chroma at {config.local_dir}")
+         self.client = chromadb.PersistentClient(path=config.local_dir)
+
+         # Get or create collection
+         try:
+             # Initialize embedding model
+             logger.info("Loading embedding model...")
+             self.embedding_engine = EmbeddingEngine(config.embedding_model)
+             logger.info(f"Using model: {self.embedding_engine.model_name}")
+
+             # Create embedding function
+             sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+                 model_name=self.embedding_engine.model_name
+             )
+
+             # Try to get existing collection
+             try:
+                 self.collection = self.client.get_collection(
+                     name=config.collection_name,
+                     embedding_function=sentence_transformer_ef
+                 )
+                 logger.info(f"Using existing collection: {config.collection_name}")
+             except Exception:
+                 # Create new collection if it doesn't exist
+                 self.collection = self.client.create_collection(
+                     name=config.collection_name,
+                     embedding_function=sentence_transformer_ef,
+                     metadata={"hnsw:space": "cosine"}
+                 )
+                 logger.info(f"Created new collection: {config.collection_name}")
+
+         except Exception as e:
+             logger.error(f"Error initializing Chroma collection: {e}")
+             sys.exit(1)
+
+     def query(self, query_text: str, n_results: int = 5) -> List[Dict]:
+         """
+         Query the vector store with a text query
+         """
+         try:
+             # Query the collection
+             search_results = self.collection.query(
+                 query_texts=[query_text],
+                 n_results=n_results,
+                 include=["documents", "metadatas", "distances"]
+             )
+
+             # Format results
+             results = []
+             if search_results["documents"] and len(search_results["documents"][0]) > 0:
+                 for i in range(len(search_results["documents"][0])):
+                     results.append({
+                         'document': search_results["documents"][0][i],
+                         'metadata': search_results["metadatas"][0][i],
+                         'score': 1.0 - search_results["distances"][0][i]  # Convert distance to similarity
+                     })
+
+             return results
+         except Exception as e:
+             logger.error(f"Error querying collection: {e}")
+             return []
+
+     def get_statistics(self) -> Dict[str, Any]:
+         """Get statistics about the vector store"""
+         stats = {}
+
+         try:
+             # Get collection count
+             collection_info = self.collection.count()
+             stats['total_documents'] = collection_info
+
+             # Estimate unique files - with no chunking, each document is a file
+             stats['unique_files'] = collection_info
+         except Exception as e:
+             logger.error(f"Error getting statistics: {e}")
+             stats['error'] = str(e)
+
+         return stats

+ class RAGSystem:
+     """Retrieval-Augmented Generation with multiple LLM providers"""
+
+     def __init__(self, vector_store: VectorStoreManager):
+         self.vector_store = vector_store
+         self.openai_client = None
+         self.gemini_configured = False
+
+     def setup_openai(self, api_key: str):
+         """Set up OpenAI client with API key"""
          try:
+             self.openai_client = OpenAI(api_key=api_key)
+             return True
          except Exception as e:
+             logger.error(f"Error initializing OpenAI client: {e}")
+             return False
+
+     def setup_gemini(self, api_key: str):
+         """Set up Gemini with API key"""
          try:
+             genai.configure(api_key=api_key)
+             self.gemini_configured = True
+             return True
          except Exception as e:
+             logger.error(f"Error configuring Gemini: {e}")
+             return False
+
+     def format_context(self, documents: List[Dict]) -> str:
+         """Format retrieved documents into context for the LLM"""
+         if not documents:
+             return "No relevant documents found."
+
+         context_parts = []
+         for i, doc in enumerate(documents):
+             metadata = doc['metadata']
+             title = metadata.get('title', metadata.get('filename', 'Unknown document'))
+
+             # For readability, limit length of context document
+             doc_text = doc['document']
+             if len(doc_text) > 10000:  # Limit long documents in context
+                 doc_text = doc_text[:10000] + "... [Document truncated for context]"
+
+             context_parts.append(f"Document {i+1} - {title}:\n{doc_text}\n")
+
+         return "\n".join(context_parts)
+
+     def generate_response_openai(self, query: str, context: str) -> str:
+         """Generate a response using OpenAI model with context"""
+         if not self.openai_client:
+             return "Error: OpenAI API key not configured. Please enter an API key in the settings tab."
+
+         system_prompt = """
+         You are a helpful assistant that answers questions based on the context provided.
+         Use the information from the context to answer the user's question.
+         If the context doesn't contain the information needed, say so clearly.
+         Always cite the specific sections from the context that you used in your answer.
+         """
+
+         try:
+             response = self.openai_client.chat.completions.create(
+                 model="gpt-4o-mini",  # Use GPT-4o mini
+                 messages=[
+                     {"role": "system", "content": system_prompt},
+                     {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"}
+                 ],
+                 temperature=0.3,  # Lower temperature for more factual responses
+                 max_tokens=1000,
+             )
+             return response.choices[0].message.content
+         except Exception as e:
+             logger.error(f"Error generating response with OpenAI: {e}")
+             return f"Error generating response with OpenAI: {str(e)}"

+     def generate_response_gemini(self, query: str, context: str) -> str:
+         """Generate a response using Gemini with context"""
+         if not self.gemini_configured:
+             return "Error: Google AI API key not configured. Please enter an API key in the settings tab."
+
+         prompt = f"""
+         You are a helpful assistant that answers questions based on the context provided.
+         Use the information from the context to answer the user's question.
+         If the context doesn't contain the information needed, say so clearly.
+         Always cite the specific sections from the context that you used in your answer.
+
+         Context:
+         {context}
+
+         Question: {query}
  """
+
+         try:
+             model = genai.GenerativeModel('gemini-1.5-flash')
+             response = model.generate_content(prompt)
+             return response.text
+         except Exception as e:
+             logger.error(f"Error generating response with Gemini: {e}")
+             return f"Error generating response with Gemini: {str(e)}"
+
+     def query_and_generate(self, query: str, n_results: int = 5, model: str = "openai") -> str:
+         """Retrieve relevant documents and generate a response using the specified model"""
+         # Query vector store
+         documents = self.vector_store.query(query, n_results=n_results)
+
+         if not documents:
+             return "No relevant documents found to answer your question."
+
+         # Format context
+         context = self.format_context(documents)
+
+         # Generate response with the appropriate model
+         if model == "openai":
+             return self.generate_response_openai(query, context)
+         elif model == "gemini":
+             return self.generate_response_gemini(query, context)
+         else:
+             return f"Unknown model: {model}"
+
+ def rag_chat(query, n_results, model_choice, rag_system):
+     """Function to handle RAG chat queries"""
+     return rag_system.query_and_generate(query, n_results=int(n_results), model=model_choice)
+
+ def simple_query(query, n_results, vector_store):
+     """Function to handle simple vector store queries"""
+     results = vector_store.query(query, n_results=int(n_results))
+
+     # Format results for display
+     formatted = []
+     for i, res in enumerate(results):
+         metadata = res['metadata']
+         title = metadata.get('title', metadata.get('filename', 'Unknown'))
+         # Limit preview text for display
+         preview = res['document'][:800] + '...' if len(res['document']) > 800 else res['document']
+         formatted.append(f"**Result {i+1}** (Similarity: {res['score']:.2f})\n\n"
+                          f"**Source:** {title}\n\n"
+                          f"**Content:**\n{preview}\n\n"
+                          f"---\n")
+
+     return "\n".join(formatted) if formatted else "No results found."
+
+ def get_db_stats(vector_store):
+     """Function to get vector store statistics"""
+     stats = vector_store.get_statistics()
+     return (f"Total documents: {stats.get('total_documents', 0)}\n"
+             f"Unique files: {stats.get('unique_files', 0)}")
+
+ def update_api_keys(openai_key, gemini_key, rag_system):
+     """Update API keys for the RAG system"""
+     success_msg = []
+
+     if openai_key:
+         if rag_system.setup_openai(openai_key):
+             success_msg.append("✅ OpenAI API key configured successfully")
+         else:
+             success_msg.append("❌ Failed to configure OpenAI API key")
+
+     if gemini_key:
+         if rag_system.setup_gemini(gemini_key):
+             success_msg.append("✅ Google AI API key configured successfully")
+         else:
+             success_msg.append("❌ Failed to configure Google AI API key")
+
+     if not success_msg:
+         return "Please enter at least one API key"
+
+     return "\n".join(success_msg)

+ # Main function to run the application
+ def main():
+     # Set up paths for existing Chroma database
+     chroma_dir = Path("./chroma_data")
+
+     # Initialize the system
+     config = Config(
+         local_dir=str(chroma_dir),
+         collection_name="markdown_docs"
+     )
+
+     # Initialize vector store manager with existing collection
+     vector_store = VectorStoreManager(config)
+
+     # Initialize RAG system without API keys initially
+     rag_system = RAGSystem(vector_store)
+
+     # Define Gradio app
+     def rag_chat_wrapper(query, n_results, model_choice):
+         return rag_chat(query, n_results, model_choice, rag_system)
+
+     def simple_query_wrapper(query, n_results):
+         return simple_query(query, n_results, vector_store)
+
+     def update_api_keys_wrapper(openai_key, gemini_key):
+         return update_api_keys(openai_key, gemini_key, rag_system)
+
+     # Create the Gradio interface
+     with gr.Blocks(title="Markdown RAG System") as app:
+         gr.Markdown("# RAG System with Multiple LLM Providers")
+
+         with gr.Tab("Chat with Documents"):
+             with gr.Row():
+                 with gr.Column(scale=3):
+                     query_input = gr.Textbox(label="Question", placeholder="Ask a question about your documents...")
+                     num_results = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of documents to retrieve")
+                     model_choice = gr.Radio(
+                         choices=["openai", "gemini"],
+                         value="openai",
+                         label="Choose LLM Provider",
+                         info="Select which model to use for generating answers"
+                     )
+                     query_button = gr.Button("Ask", variant="primary")
+
+                 with gr.Column(scale=7):
+                     response_output = gr.Markdown(label="Response")
+
+             # Database stats
+             stats_display = gr.Textbox(label="Database Statistics", value=get_db_stats(vector_store))
+             refresh_button = gr.Button("Refresh Statistics")
+
+         with gr.Tab("Document Search"):
+             search_input = gr.Textbox(label="Search Query", placeholder="Search your documents...")
+             search_num = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of results")
+             search_button = gr.Button("Search", variant="primary")
+             search_output = gr.Markdown(label="Search Results")
+
+         with gr.Tab("Settings"):
+             gr.Markdown("""
+             ## API Keys Configuration
+
+             This application can use either OpenAI's GPT-4o-mini or Google's Gemini 1.5 Flash for generating responses.
+             You need to provide at least one API key to use the chat functionality.
+             """)
+
+             openai_key_input = gr.Textbox(
+                 label="OpenAI API Key",
+                 placeholder="Enter your OpenAI API key here...",
+                 type="password"
+             )
+
+             gemini_key_input = gr.Textbox(
+                 label="Google AI API Key",
+                 placeholder="Enter your Google AI API key here...",
+                 type="password"
+             )
+
+             save_keys_button = gr.Button("Save API Keys", variant="primary")
+             api_status = gr.Markdown("")
+
+         # Set up events
+         query_button.click(
+             fn=rag_chat_wrapper,
+             inputs=[query_input, num_results, model_choice],
+             outputs=response_output
+         )
+
+         refresh_button.click(
+             fn=lambda: get_db_stats(vector_store),
+             inputs=None,
+             outputs=stats_display
+         )
+
+         search_button.click(
+             fn=simple_query_wrapper,
+             inputs=[search_input, search_num],
+             outputs=search_output
+         )
+
+         save_keys_button.click(
+             fn=update_api_keys_wrapper,
+             inputs=[openai_key_input, gemini_key_input],
+             outputs=api_status
+         )
+
+     # Launch the interface
+     app.launch()

  if __name__ == "__main__":
+     main()
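
For reference, a minimal headless sketch of how the pieces added in this commit fit together, assuming the file above is importable as app and that ./chroma_data already holds an indexed collection (the module name, the sample queries, and the OPENAI_API_KEY environment variable are illustrative assumptions, not part of the commit):

# Hypothetical smoke test for the RAG pipeline defined in app.py.
import os
from app import Config, VectorStoreManager, RAGSystem

config = Config(local_dir="./chroma_data", collection_name="markdown_docs")
vector_store = VectorStoreManager(config)

# Raw similarity search against Chroma, no LLM involved.
for hit in vector_store.query("How do I configure logging?", n_results=3):
    print(f"{hit['score']:.2f}", hit['metadata'].get('filename', 'unknown'))

# Full retrieval-augmented answer via OpenAI, if a key is available.
rag = RAGSystem(vector_store)
if rag.setup_openai(os.environ.get("OPENAI_API_KEY", "")):
    print(rag.query_and_generate("Summarize the documentation.", n_results=3, model="openai"))

The printed score relies on the score = 1 - distance conversion in VectorStoreManager.query, which in turn assumes the collection was created with the cosine hnsw:space metadata set in this commit.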