AnilNiraula committed
Commit 1d02fc4 · verified · 1 parent: ecb3e47

Update app.py

Files changed (1): app.py (+316 −371)
app.py CHANGED
@@ -11,23 +11,9 @@ from huggingface_hub import hf_hub_download, login
  import logging
  import pandas as pd
  import torch
- import yfinance as yf
- from datetime import datetime, timedelta
- from math import sqrt
- import time
- import base64
- import io
- import numpy as np
- try:
-     import scipy.optimize as opt
- except ModuleNotFoundError:
-     subprocess.check_call([sys.executable, "-m", "pip", "install", "scipy"])
-     import scipy.optimize as opt
-
  # Set up logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)
-
  # Install llama-cpp-python with appropriate backend
  try:
      from llama_cpp import Llama
@@ -39,40 +25,34 @@ except ModuleNotFoundError:
  else:
      logger.info("Installing llama-cpp-python without additional flags.")
      subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-cpp-python", "--force-reinstall", "--upgrade", "--no-cache-dir"])
-     from llama_cpp import Llama
-
  # Install yfinance if not present (for CAGR calculations)
  try:
      import yfinance as yf
  except ModuleNotFoundError:
      subprocess.check_call([sys.executable, "-m", "pip", "install", "yfinance"])
-     import yfinance as yf
-
  # Import pandas for handling DataFrame column structures
  import pandas as pd
-
  # Additional imports for visualization and file handling
  try:
      import matplotlib.pyplot as plt
      from PIL import Image
      import io
  except ModuleNotFoundError:
-     subprocess.check_call([sys.executable, "-m", "pip", "install", "matplotlib", "pillow", "numpy"])
-     import matplotlib.pyplot as plt
-     from PIL import Image
-     import io
-     import numpy as np
-
  MAX_MAX_NEW_TOKENS = 512
  DEFAULT_MAX_NEW_TOKENS = 512
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "1024"))
-
  DESCRIPTION = """# FinChat: Investing Q&A (Optimized for Speed)
- This application delivers an interactive chat interface powered by a highly efficient, small AI model adapted for addressing investing and finance inquiries through specialized prompt engineering. It ensures rapid, reasoned responses to user queries. Duplicate this Space for customization or queue-free deployment.<p>Running on CPU or GPU if available. Using Phi-2 model for faster inference. Inference is heavily optimized for responses in under 10 seconds for simple queries, with output limited to 250 tokens maximum. For longer responses, increase 'Max New Tokens' in Advanced Settings. Brief delays may occur in free-tier environments due to shared resources, but typical generation speeds are improved with the smaller model.</p>"""
-
- LICENSE = """<p/>---
  This application employs the Phi-2 model, governed by Microsoft's Terms of Use. Refer to the [model card](https://huggingface.co/TheBloke/phi-2-GGUF) for details."""
-
  # Load the model (skip fine-tuning for faster startup)
  try:
      model_path = hf_hub_download(
@@ -83,10 +63,10 @@ try:
      llm = Llama(
          model_path=model_path,
          n_ctx=1024,
-         n_batch=1024,  # Increased for faster processing
          n_threads=multiprocessing.cpu_count(),
          n_gpu_layers=n_gpu_layers,
-         chat_format="chatml"  # Phi-2 uses ChatML format in llama.cpp
      )
      logger.info(f"Model loaded successfully with n_gpu_layers={n_gpu_layers}.")
      # Warm up the model for faster initial inference
@@ -95,361 +75,326 @@ try:
  except Exception as e:
      logger.error(f"Error loading model: {str(e)}")
      raise
-
  # Register explicit close for llm to avoid destructor error
  atexit.register(llm.close)
-
  DEFAULT_SYSTEM_PROMPT = """You are FinChat, a knowledgeable AI assistant specializing in investing and finance. Provide accurate, helpful, reasoned, and concise answers to investing questions. Always base responses on reliable information and advise users to consult professionals for personalized advice.
  Always respond exclusively in English. Use bullet points for clarity.
- Do not substitute or alter stock symbols provided in the user's query. Always use the exact tickers mentioned.
  Example:
- User: average return for AAPL between 2010 and 2020
  Assistant:
- - AAPL CAGR (2010-2020): ~27.24%
  - Represents average annual return with compounding
- - Past performance is not indicative of future results."""
-
- logs = []
-
- # Function to calculate CAGR using yfinance
- def calculate_cagr(ticker, start_date, end_date):
-     try:
-         data = yf.download(ticker, start=start_date, end=end_date)
-         if data.empty:
-             return None
-         start_price = data['Adj Close'].iloc[0]
-         end_price = data['Adj Close'].iloc[-1]
-         num_years = (data.index[-1] - data.index[0]).days / 365.25
-         cagr = (end_price / start_price) ** (1 / num_years) - 1
-         return cagr * 100  # Return as percentage
-     except Exception as e:
-         logger.error(f"Error calculating CAGR for {ticker}: {str(e)}")
-         return None
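For reference, the arithmetic this removed helper implements can be checked on round numbers (hypothetical prices, not market data):

    # CAGR on synthetic prices: $100 grows to $200 over 5 years.
    start_price, end_price, years = 100.0, 200.0, 5.0
    cagr = (end_price / start_price) ** (1 / years) - 1
    print(f"{cagr * 100:.2f}%")  # ~14.87% per year, since 1.1487 ** 5 ≈ 2.0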
-
- # Function to calculate risk metrics using yfinance
- def calculate_risk_metrics(ticker, years=5):
-     try:
-         end_date = datetime.now().strftime('%Y-%m-%d')
-         start_date = (datetime.now() - timedelta(days=365 * years)).strftime('%Y-%m-%d')
-         data = yf.download(ticker, start=start_date, end=end_date)
-         if data.empty:
-             return None, None
-         returns = data['Adj Close'].pct_change().dropna()
-         volatility = returns.std() * sqrt(252) * 100  # Annualized volatility in percent
-         mean_return = returns.mean() * 252  # Annualized mean return
-         risk_free_rate = 0.02  # Assumed risk-free rate (e.g., 2%)
-         sharpe = (mean_return - risk_free_rate) / (volatility / 100)  # Sharpe ratio
-         return volatility, sharpe
-     except Exception as e:
-         logger.error(f"Error calculating risk metrics for {ticker}: {str(e)}")
-         return None, None
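The annualization convention used above (√252 for volatility, 252 for mean return) in isolation, on synthetic daily returns; numpy is assumed available:

    import numpy as np

    rng = np.random.default_rng(0)
    daily = rng.normal(0.0004, 0.01, 252)              # one synthetic trading year
    volatility = daily.std() * np.sqrt(252) * 100      # annualized volatility, percent
    mean_return = daily.mean() * 252                   # annualized mean return, decimal
    sharpe = (mean_return - 0.02) / (volatility / 100) # assumed 2% risk-free rate
    print(f"vol={volatility:.1f}%  sharpe={sharpe:.2f}")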
-
- # Function for inline plot
- def generate_plot(ticker, period='5y'):
-     try:
-         data = yf.download(ticker, period=period)
-         if data.empty:
-             return "Unable to fetch data for plotting."
-         plt.figure(figsize=(10, 5))
-         plt.plot(data['Adj Close'], label='Adjusted Close')
-         plt.title(f'{ticker} Price History ({period})')
-         plt.xlabel('Date')
-         plt.ylabel('Price (USD)')
-         plt.legend()
-         plt.grid(True)
-         buf = io.BytesIO()
-         plt.savefig(buf, format='png', bbox_inches='tight')
-         buf.seek(0)
-         b64 = base64.b64encode(buf.read()).decode('utf-8')
-         plt.close()
-         return f"![{ticker} Price Chart](data:image/png;base64,{b64})"
-     except Exception as e:
-         logger.error(f"Error generating plot for {ticker}: {str(e)}")
-         return "Error generating plot."
-
- # Function for portfolio optimization using scipy
- def portfolio_optimization(tickers, target_return=None):
-     try:
-         data = yf.download(tickers, period='5y')['Adj Close']
-         returns = data.pct_change().dropna()
-         mean_returns = returns.mean() * 252
-         cov_matrix = returns.cov() * 252
-         num_assets = len(tickers)
-
-         def portfolio_volatility(weights):
-             return np.sqrt(np.dot(weights.T, np.dot(cov_matrix, weights)))
-
-         constraints = ({'type': 'eq', 'fun': lambda x: np.sum(x) - 1})
-         bounds = tuple((0, 1) for _ in range(num_assets))
-         initial_guess = np.array(num_assets * [1. / num_assets])
-
-         if target_return:
-             # Maximize Sharpe or min vol for target return
-             def objective(weights):
-                 ret = np.sum(mean_returns * weights)
-                 vol = portfolio_volatility(weights)
-                 return - (ret - 0.02) / vol if vol != 0 else np.inf  # Neg Sharpe
-             cons = [{'type': 'eq', 'fun': lambda x: np.sum(x) - 1},
-                     {'type': 'eq', 'fun': lambda x: np.sum(mean_returns * x) - target_return}]
-             result = opt.minimize(objective, initial_guess, method='SLSQP', bounds=bounds, constraints=cons)
-         else:
-             # Minimize volatility
-             result = opt.minimize(portfolio_volatility, initial_guess, method='SLSQP',
-                                   bounds=bounds, constraints=constraints)
-
-         if result.success:
-             weights = dict(zip(tickers, result.x))
-             return weights
-         else:
-             return {ticker: 1/len(tickers) for ticker in tickers}  # Fallback equal weights
-     except Exception as e:
-         logger.error(f"Error in portfolio optimization: {str(e)}")
-         return {ticker: 1/len(tickers) for ticker in tickers}
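The minimum-variance branch is standard SciPy SLSQP; a self-contained sketch on a made-up two-asset covariance matrix (not data the app uses):

    import numpy as np
    import scipy.optimize as opt

    cov = np.array([[0.04, 0.01],   # hypothetical annualized covariance matrix
                    [0.01, 0.09]])

    def vol(w):
        return np.sqrt(w @ cov @ w)  # portfolio volatility

    res = opt.minimize(vol, np.array([0.5, 0.5]), method="SLSQP",
                       bounds=((0, 1), (0, 1)),
                       constraints=({"type": "eq", "fun": lambda w: np.sum(w) - 1},))
    print(res.x)  # ≈ [0.73, 0.27]: weight tilts toward the lower-variance asset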
-
- # Assuming the generate function handles the chat logic (extended to include risk comparison)
  def generate(
      message: str,
-     history: list[tuple[str, str]],
-     system_prompt: str,
-     max_new_tokens: int,
-     temperature: float,
-     top_p: float,
-     top_k: int,
-     logs_state: list
- ) -> tuple[Iterator[str], list]:
-     start_time = time.time()
-     if not system_prompt:
-         system_prompt = DEFAULT_SYSTEM_PROMPT
-
-     full_response = ""
-     # Detect CAGR query
-     cagr_match = re.search(r'average return for (\w+) between (\d{4}) and (\d{4})', message.lower())
-     if cagr_match:
-         ticker = cagr_match.group(1).upper()
-         start_year = cagr_match.group(2)
-         end_year = cagr_match.group(3)
-         start_date = f"{start_year}-01-01"
-         end_date = f"{end_year}-12-31"
-         cagr = calculate_cagr(ticker, start_date, end_date)
-         if cagr is not None:
-             response = f"- {ticker} CAGR ({start_year}-{end_year}): ~{cagr:.2f}%\n- Represents average annual return with compounding\n- Past performance is not indicative of future results.\n- Consult a financial advisor for personalized advice."
-             yield response
-             full_response = response
-         else:
-             response = "Unable to calculate CAGR for the specified period."
-             yield response
-             full_response = response
-         end_time = time.time()
-         logs_state.append({
-             'timestamp': datetime.now().isoformat(),
-             'query': message,
-             'response': full_response,
-             'response_length': len(full_response.split()),
-             'generation_time': end_time - start_time,
-             'token_efficiency': len(full_response.split()) / max_new_tokens
-         })
-         return iter([]), logs_state  # No more yield
-
-     # Detect risk comparison query
-     risk_match = re.search(r'which stock is riskier (\w+) or (\w+)', message.lower())
-     if risk_match:
-         ticker1 = risk_match.group(1).upper()
-         ticker2 = risk_match.group(2).upper()
-         vol1, sharpe1 = calculate_risk_metrics(ticker1)
-         vol2, sharpe2 = calculate_risk_metrics(ticker2)
-         if vol1 is None or vol2 is None:
-             response = "Unable to fetch risk metrics for one or both tickers."
-             yield response
-             full_response = response
-         else:
-             if vol1 > vol2:
-                 riskier = ticker1
-                 less_risky = ticker2
-                 higher_vol = vol1
-                 lower_vol = vol2
-                 riskier_sharpe = sharpe1
-                 less_sharpe = sharpe2
-             else:
-                 riskier = ticker2
-                 less_risky = ticker1
-                 higher_vol = vol2
-                 lower_vol = vol1
-                 riskier_sharpe = sharpe2
-                 less_sharpe = sharpe1
-             response = f"- {riskier} is riskier compared to {less_risky}.\n- It has a higher annualized standard deviation ({higher_vol:.2f}% vs {lower_vol:.2f}%) and a lower Sharpe ratio ({riskier_sharpe:.2f} vs {less_sharpe:.2f}), indicating greater volatility and potentially lower risk-adjusted returns.\n- Calculations based on the past 5 years of data.\n- Past performance is not indicative of future results. Consult a financial advisor for personalized advice."
              yield response
-             full_response = response
-         end_time = time.time()
-         logs_state.append({
-             'timestamp': datetime.now().isoformat(),
-             'query': message,
-             'response': full_response,
-             'response_length': len(full_response.split()),
-             'generation_time': end_time - start_time,
-             'token_efficiency': len(full_response.split()) / max_new_tokens
-         })
-         return iter([]), logs_state
-
-     # Detect plot/chart query
-     plot_match = re.search(r'(plot|chart)\s+(\w+)(?:\s+(historical|price|volatility))?', message.lower())
-     if plot_match:
-         ticker = plot_match.group(2).upper()
-         plot_type = plot_match.group(3) if plot_match.group(3) else 'price'
-         if plot_type == 'volatility':
-             # Simple volatility plot (returns histogram)
              try:
-                 data = yf.download(ticker, period='1y')
-                 returns = data['Adj Close'].pct_change().dropna()
-                 plt.figure(figsize=(10, 5))
-                 plt.hist(returns, bins=50, alpha=0.7)
-                 plt.title(f'{ticker} Daily Returns Distribution (1Y)')
-                 plt.xlabel('Return')
-                 plt.ylabel('Frequency')
-             except:
-                 plot_type = 'price'  # Fallback
-         if plot_type != 'volatility':
-             plot_md = generate_plot(ticker)
-             response = f"Price chart for {ticker}:\n{plot_md}\n- This visualizes the historical adjusted close prices.\n- Past performance is not indicative of future results. Consult a financial advisor."
-             yield response
-             full_response = response
-         else:
-             # For volatility, similar
-             buf = io.BytesIO()
-             plt.savefig(buf, format='png', bbox_inches='tight')
-             buf.seek(0)
-             b64 = base64.b64encode(buf.read()).decode('utf-8')
-             plt.close()
-             plot_md = f"![{ticker} Volatility](data:image/png;base64,{b64})"
-             response = f"Volatility chart for {ticker}:\n{plot_md}\n- Histogram of daily returns over the past year."
-             yield response
-             full_response = response
-         end_time = time.time()
-         logs_state.append({
-             'timestamp': datetime.now().isoformat(),
-             'query': message,
-             'response': full_response,
-             'response_length': len(full_response.split()),
-             'generation_time': end_time - start_time,
-             'token_efficiency': len(full_response.split()) / max_new_tokens
-         })
-         return iter([]), logs_state
-
-     # Detect portfolio optimization query
-     port_match = re.search(r'optimize\s+portfolio\s+for\s+([\w,\s]+)(?:\s+with\s+(risk|return)\s+tolerance\s+([\d.]+))?', message.lower())
-     if port_match:
-         tickers_str = port_match.group(1).strip()
-         tickers = [t.strip().upper() for t in re.split(r'[,;]', tickers_str) if t.strip()]
-         target = None
-         if port_match.group(3):
-             target = float(port_match.group(3))
-             if port_match.group(2) == 'risk':
-                 # For risk tolerance, min vol with vol <= target (but simplify to min vol)
-                 pass  # Use default min vol
-             else:
-                 target_return = target
-         weights = portfolio_optimization(tickers, target_return=target if 'return' in (port_match.group(2) or '') else None)
-         df = pd.DataFrame(list(weights.items()), columns=['Ticker', 'Weight'])
-         df['Weight'] = df['Weight'].round(4)
-         table_md = df.to_markdown(index=False)
-         response = f"- Suggested portfolio weights for {', '.join(tickers)}:\n{table_md}\n- Based on minimum variance optimization (or target return if specified).\n- Assumes 5-year historical data for means and covariances.\n- Past performance is not indicative of future results. Consult a financial advisor for personalized advice."
-         yield response
-         full_response = response
-         end_time = time.time()
-         logs_state.append({
-             'timestamp': datetime.now().isoformat(),
-             'query': message,
-             'response': full_response,
-             'response_length': len(full_response.split()),
-             'generation_time': end_time - start_time,
-             'token_efficiency': len(full_response.split()) / max_new_tokens
-         })
-         return iter([]), logs_state
-
-     # For other queries, fall back to LLM generation
      conversation = [{"role": "system", "content": system_prompt}]
-     for user, assistant in history:
-         conversation.extend([
-             {"role": "user", "content": user},
-             {"role": "assistant", "content": assistant}
-         ])
      conversation.append({"role": "user", "content": message})
-
-     # Generate response using LLM (streamed)
-     response = llm.create_chat_completion(
-         messages=conversation,
-         max_tokens=max_new_tokens,
-         temperature=temperature,
-         top_p=top_p,
-         top_k=top_k,
-         stream=True
      )
-     partial_text = ""
-     for chunk in response:
-         if "content" in chunk["choices"][0]["delta"]:
-             partial_text += chunk["choices"][0]["delta"]["content"]
-             yield partial_text
-     full_response = partial_text
-     end_time = time.time()
-     logs_state.append({
-         'timestamp': datetime.now().isoformat(),
-         'query': message,
-         'response': full_response,
-         'response_length': len(full_response.split()),
-         'generation_time': end_time - start_time,
-         'token_efficiency': len(full_response.split()) / max_new_tokens
-     })
-     return iter([]), logs_state
-
- def update_logs(logs_state):
-     if logs_state:
-         df = pd.DataFrame(logs_state)
          return df
-     return pd.DataFrame()
-
  # Gradio interface setup
- with gr.Blocks(theme=themes.Default()) as demo:
      gr.Markdown(DESCRIPTION)
      gr.Markdown(LICENSE)
-
-     with gr.Tabs():
-         with gr.TabItem("Chat"):
-             chatbot = gr.Chatbot()
-             msg = gr.Textbox(label="Enter your question")
-             with gr.Row():
-                 submit = gr.Button("Submit")
-                 clear = gr.Button("Clear")
-             advanced = gr.Accordion("Advanced Settings", open=False)
-             with advanced:
-                 system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6)
-                 max_new_tokens = gr.Slider(minimum=1, maximum=MAX_MAX_NEW_TOKENS, value=DEFAULT_MAX_NEW_TOKENS, step=1, label="Max New Tokens")
-                 temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature")
-                 top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.1, label="Top P")
-                 top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top K")
-
-         with gr.TabItem("Metrics"):
-             metrics_df = gr.Dataframe(headers=['timestamp', 'query', 'response', 'response_length', 'generation_time', 'token_efficiency'])
-
-     logs_state = gr.State(logs)
-
-     def submit_fn(msg, history, system_prompt, max_new_tokens, temperature, top_p, top_k, logs_state):
-         gen, new_logs = generate(msg, history, system_prompt, max_new_tokens, temperature, top_p, top_k, logs_state)
-         history.append((msg, ""))
-         for partial in gen:
-             history[-1] = (history[-1][0], partial)
-             yield history, "", new_logs
-         return history, "", new_logs
-
-     submit.click(
-         submit_fn,
-         inputs=[msg, chatbot, system_prompt, max_new_tokens, temperature, top_p, top_k, logs_state],
-         outputs=[chatbot, msg, logs_state],
-         queue=False
-     ).then(
-         update_logs,
-         inputs=[logs_state],
-         outputs=[metrics_df]
      )
-
-     clear.click(lambda: ([], []), None, (chatbot, logs_state))
-
- demo.launch()

  import logging
  import pandas as pd
  import torch
  # Set up logging
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)
  # Install llama-cpp-python with appropriate backend
  try:
      from llama_cpp import Llama
  else:
      logger.info("Installing llama-cpp-python without additional flags.")
      subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-cpp-python", "--force-reinstall", "--upgrade", "--no-cache-dir"])
+     from llama_cpp import Llama
  # Install yfinance if not present (for CAGR calculations)
  try:
      import yfinance as yf
  except ModuleNotFoundError:
      subprocess.check_call([sys.executable, "-m", "pip", "install", "yfinance"])
+     import yfinance as yf
  # Import pandas for handling DataFrame column structures
  import pandas as pd
  # Additional imports for visualization and file handling
  try:
      import matplotlib.pyplot as plt
      from PIL import Image
      import io
  except ModuleNotFoundError:
+     subprocess.check_call([sys.executable, "-m", "pip", "install", "matplotlib", "pillow"])
+     import matplotlib.pyplot as plt
+     from PIL import Image
+     import io
  MAX_MAX_NEW_TOKENS = 512
  DEFAULT_MAX_NEW_TOKENS = 512
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "1024"))
  DESCRIPTION = """# FinChat: Investing Q&A (Optimized for Speed)
+ This application delivers an interactive chat interface powered by a highly efficient, small AI model adapted for addressing investing and finance inquiries through specialized prompt engineering. It ensures rapid, reasoned responses to user queries. Duplicate this Space for customization or queue-free deployment.
+ <p>Running on CPU or GPU if available. Using Phi-2 model for faster inference. Inference is heavily optimized for responses in under 10 seconds for simple queries, with output limited to 250 tokens maximum. For longer responses, increase 'Max New Tokens' in Advanced Settings. Brief delays may occur in free-tier environments due to shared resources, but typical generation speeds are improved with the smaller model.</p>"""
+ LICENSE = """<p/>
+ ---
  This application employs the Phi-2 model, governed by Microsoft's Terms of Use. Refer to the [model card](https://huggingface.co/TheBloke/phi-2-GGUF) for details."""
  # Load the model (skip fine-tuning for faster startup)
  try:
      model_path = hf_hub_download(
      llm = Llama(
          model_path=model_path,
          n_ctx=1024,
+         n_batch=1024,  # Increased for faster processing
          n_threads=multiprocessing.cpu_count(),
          n_gpu_layers=n_gpu_layers,
+         chat_format="chatml"  # Phi-2 uses ChatML format in llama.cpp
      )
      logger.info(f"Model loaded successfully with n_gpu_layers={n_gpu_layers}.")
      # Warm up the model for faster initial inference

  except Exception as e:
      logger.error(f"Error loading model: {str(e)}")
      raise

  # Register explicit close for llm to avoid destructor error
  atexit.register(llm.close)

  DEFAULT_SYSTEM_PROMPT = """You are FinChat, a knowledgeable AI assistant specializing in investing and finance. Provide accurate, helpful, reasoned, and concise answers to investing questions. Always base responses on reliable information and advise users to consult professionals for personalized advice.
  Always respond exclusively in English. Use bullet points for clarity.

  Example:
+ User: average return for TSLA between 2010 and 2020
  Assistant:
+ - TSLA CAGR (2010-2020): ~63.01%
  - Represents average annual return with compounding
+ - Past performance not indicative of future results
+ - Consult a financial advisor"""
+ # Company name to ticker mapping (expand as needed)
+ COMPANY_TO_TICKER = {
+     "opendoor": "OPEN",
+     "tesla": "TSLA",
+     "apple": "AAPL",
+     "amazon": "AMZN",
+     # Add more mappings for common companies
+ }
  def generate(
      message: str,
+     chat_history: list[dict],
+     system_prompt: str = DEFAULT_SYSTEM_PROMPT,
+     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     logger.info(f"Generating response for message: {message}")
+     lower_message = message.lower().strip()
+     if lower_message in ["hi", "hello"]:
+         response = "I'm FinChat, your financial advisor. Ask me anything finance-related!"
+         logger.info("Quick response for 'hi'/'hello' generated.")
+         yield response
+         return
+     if "what is cagr" in lower_message:
+         response = """- CAGR stands for Compound Annual Growth Rate.
+ - It measures the mean annual growth rate of an investment over a specified period longer than one year, accounting for compounding.
+ - Formula: CAGR = (Ending Value / Beginning Value)^(1 / Number of Years) - 1
+ - Useful for comparing investments over time.
+ - Past performance not indicative of future results. Consult a financial advisor."""
+         logger.info("Quick response for 'what is cagr' generated.")
+         yield response
+         return
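The formula in that canned answer is compounding run backwards; a quick round-trip check on hypothetical values:

    begin, end, years = 1000.0, 1500.0, 3
    cagr = (end / begin) ** (1 / years) - 1           # ≈ 0.1447, i.e. ~14.47%/yr
    assert abs(begin * (1 + cagr) ** years - end) < 1e-9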
+     # Check for compound interest queries
+     compound_match = re.search(r'(?:save|invest|deposit)\s*\$?([\d,]+(?:\.\d+)?)\s*(?:right now|today)?\s*(?:under|at)\s*([\d.]+)%\s*(?:interest|rate)?\s*(?:annually|per year)?\s*over\s*(\d+)\s*years?', lower_message)
+     if compound_match:
+         try:
+             principal_str = compound_match.group(1).replace(',', '')
+             principal = float(principal_str)
+             rate = float(compound_match.group(2)) / 100
+             years = int(compound_match.group(3))
+             if principal <= 0 or rate < 0 or years <= 0:
+                 yield "Invalid input values: principal, rate, and years must be positive."
+                 return
+             balance = principal * (1 + rate) ** years
+             response = (
+                 f"- Starting with ${principal:,.2f} at {rate*100:.2f}% annual interest, compounded annually over {years} years.\n"
+                 f"- Projected balance in year {years}: ${balance:,.2f}\n"
+                 f"- Assumptions: Annual compounding; no additional deposits or withdrawals.\n"
+                 f"- This is a projection; actual results may vary. Consult a financial advisor."
+             )
+             logger.info("Compound interest response generated.")
              yield response
+             return
+         except ValueError as ve:
+             logger.error(f"Error parsing compound interest query: {str(ve)}")
+             yield "Error parsing query: Please ensure amount, rate, and years are valid numbers."
+             return
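Checking the projection arithmetic on the kind of query the regex targets, e.g. "save $10,000 at 5% over 10 years":

    principal, rate, years = 10_000.0, 0.05, 10
    balance = principal * (1 + rate) ** years
    print(f"${balance:,.2f}")  # $16,288.95 with annual compounding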
+     # Check for CAGR/average return queries (use re.search for flexible matching)
+     match = re.search(r'(?:average return|cagr) for ([\w\s,]+(?:and [\w\s,]+)?) between (\d{4}) and (\d{4})', lower_message)
+     if match:
+         tickers_str, start_year, end_year = match.groups()
+         tickers = [t.strip().upper() for t in re.split(r',|\band\b', tickers_str) if t.strip()]
+         # Apply company-to-ticker mapping
+         for i in range(len(tickers)):
+             lower_ticker = tickers[i].lower()
+             if lower_ticker in COMPANY_TO_TICKER:
+                 tickers[i] = COMPANY_TO_TICKER[lower_ticker]
+         responses = []
+         if int(end_year) <= int(start_year):
+             yield "The specified time period is invalid (end year must be after start year)."
+             return
+         for ticker in tickers:
              try:
+                 # Download data with adjusted close prices
+                 data = yf.download(ticker, start=f"{start_year}-01-01", end=f"{end_year}-12-31", progress=False, auto_adjust=False)
+                 # Handle potential MultiIndex columns in newer yfinance versions
+                 if isinstance(data.columns, pd.MultiIndex):
+                     data.columns = data.columns.droplevel(1)
+                 if not data.empty:
+                     # Check if 'Adj Close' column exists
+                     if 'Adj Close' not in data.columns:
+                         responses.append(f"- {ticker}: Error - Adjusted Close price data not available.")
+                         logger.error(f"No 'Adj Close' column for {ticker}.")
+                         continue
+                     # Ensure data is not MultiIndex for single ticker (already handled)
+                     initial = data['Adj Close'].iloc[0]
+                     final = data['Adj Close'].iloc[-1]
+                     start_date = data.index[0]
+                     end_date = data.index[-1]
+                     days = (end_date - start_date).days
+                     years = days / 365.25
+                     if years > 0 and pd.notna(initial) and pd.notna(final):
+                         cagr = ((final / initial) ** (1 / years) - 1) * 100
+                         responses.append(f"- {ticker}: ~{cagr:.2f}%")
+                     else:
+                         responses.append(f"- {ticker}: Invalid period or missing price data.")
+                 else:
+                     responses.append(f"- {ticker}: No historical data available between {start_year} and {end_year}.")
+             except Exception as e:
+                 logger.error(f"Error calculating CAGR for {ticker}: {str(e)}")
+                 responses.append(f"- {ticker}: Error calculating CAGR - {str(e)}")
+         full_response = f"CAGR for the requested stocks from {start_year} to {end_year}:\n" + "\n".join(responses) + "\n- Represents average annual returns with compounding\n- Past performance not indicative of future results\n- Consult a financial advisor"
+         full_response = re.sub(r'<\|(?:im_start|im_end|system|user|assistant)\|>|</s>|\[END\]', '', full_response).strip()  # Clean any trailing tokens
+         # Estimate token count to ensure response fits within max_new_tokens
+         response_tokens = len(llm.tokenize(full_response.encode("utf-8"), add_bos=False))
+         if response_tokens > max_new_tokens:
+             logger.warning(f"CAGR response tokens ({response_tokens}) exceed max_new_tokens ({max_new_tokens}). Truncating to first complete sentence.")
+             sentence_endings = ['.', '!', '?']
+             first_sentence_end = min([full_response.find(ending) + 1 for ending in sentence_endings if full_response.find(ending) != -1], default=len(full_response))
+             full_response = full_response[:first_sentence_end] if first_sentence_end > 0 else "Response truncated due to length; please increase Max New Tokens."
+         logger.info("CAGR response generated.")
+         yield full_response
+         return
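The droplevel guard handles yfinance builds that return (field, ticker) MultiIndex columns even for a single symbol; the pattern in isolation, on a simulated frame rather than a live download:

    import pandas as pd

    cols = pd.MultiIndex.from_product([["Adj Close", "Close"], ["AAPL"]])
    data = pd.DataFrame([[10.0, 10.1], [10.2, 10.3]], columns=cols)
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.droplevel(1)  # keep only the field names
    print(data["Adj Close"].iloc[-1])             # 10.2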
+     # Build conversation messages (limit history to last 3 for speed)
      conversation = [{"role": "system", "content": system_prompt}]
+     for msg in chat_history[-3:]:  # Reduced from 5 to 3 for faster processing
+         if msg["role"] == "user":
+             conversation.append({"role": "user", "content": msg["content"]})
+         elif msg["role"] == "assistant":
+             conversation.append({"role": "assistant", "content": msg["content"]})
      conversation.append({"role": "user", "content": message})
+     # Approximate token length check and truncate if necessary
+     prompt_text = "\n".join(d["content"] for d in conversation)
+     input_tokens = llm.tokenize(prompt_text.encode("utf-8"), add_bos=False)
+     while len(input_tokens) > MAX_INPUT_TOKEN_LENGTH:
+         logger.warning(f"Input tokens ({len(input_tokens)}) exceed limit ({MAX_INPUT_TOKEN_LENGTH}). Truncating history.")
+         if len(conversation) > 2:  # Preserve system prompt and current user message
+             conversation.pop(1)  # Remove oldest user/assistant pair
+             prompt_text = "\n".join(d["content"] for d in conversation)
+             input_tokens = llm.tokenize(prompt_text.encode("utf-8"), add_bos=False)
+         else:
+             yield "Error: Input is too long even after truncation. Please shorten your query."
+             return
+     # Generate response with sentence boundary checking and token cleanup
+     try:
+         response = ""
+         sentence_buffer = ""
+         token_count = 0
+         stream = llm.create_chat_completion(
+             messages=conversation,
+             max_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             top_k=top_k,
+             repeat_penalty=repetition_penalty,
+             stream=True
+         )
+         sentence_endings = ['.', '!', '?']
+         for chunk in stream:
+             delta = chunk["choices"][0]["delta"]
+             if "content" in delta and delta["content"] is not None:
+                 # Clean the chunk by removing ChatML tokens or similar
+                 cleaned_chunk = re.sub(r'<\|(?:im_start|im_end|system|user|assistant)\|>|</s>|\[END\]', '', delta["content"])
+                 if not cleaned_chunk:
+                     continue
+                 sentence_buffer += cleaned_chunk
+                 response += cleaned_chunk
+                 # Approximate token count for the chunk
+                 chunk_tokens = len(llm.tokenize(cleaned_chunk.encode("utf-8"), add_bos=False))
+                 token_count += chunk_tokens
+                 # Check for sentence boundary
+                 if any(sentence_buffer.strip().endswith(ending) for ending in sentence_endings):
+                     yield response
+                     sentence_buffer = ""  # Clear buffer after yielding a complete sentence
+             # Removed early truncation to allow full token utilization
+             if chunk["choices"][0]["finish_reason"] is not None:
+                 # Yield any remaining complete sentence in the buffer
+                 if sentence_buffer.strip():
+                     last_sentence_end = max([sentence_buffer.rfind(ending) for ending in sentence_endings if sentence_buffer.rfind(ending) != -1], default=-1)
+                     if last_sentence_end != -1:
+                         response = response[:response.rfind(sentence_buffer) + last_sentence_end + 1]
+                         yield response
+                     else:
+                         yield response
+                 else:
+                     yield response
+                 break
+         logger.info("Response generation completed.")
+     except ValueError as ve:
+         if "exceed context window" in str(ve):
+             yield "Error: Prompt too long for context window. Please try a shorter query or clear history."
+         else:
+             logger.error(f"Error during response generation: {str(ve)}")
+             yield f"Error generating response: {str(ve)}"
+     except Exception as e:
+         logger.error(f"Error during response generation: {str(e)}")
+         yield f"Error generating response: {str(e)}"
+ def process_portfolio(df, growth_rate):
+     if df is None or len(df) == 0:
+         return "", None
+     # Convert to DataFrame if needed
+     if not isinstance(df, pd.DataFrame):
+         df = pd.DataFrame(df, columns=["Ticker", "Shares", "Avg Cost", "Current Price"])
+     df = df.dropna(subset=["Ticker"])
+     portfolio = {}
+     for _, row in df.iterrows():
+         ticker = row["Ticker"].upper() if pd.notna(row["Ticker"]) else None
+         if not ticker:
+             continue
+         shares = float(row["Shares"]) if pd.notna(row["Shares"]) else 0
+         cost = float(row["Avg Cost"]) if pd.notna(row["Avg Cost"]) else 0
+         price = float(row["Current Price"]) if pd.notna(row["Current Price"]) else 0
+         value = shares * price
+         portfolio[ticker] = {'shares': shares, 'cost': cost, 'price': price, 'value': value}
+     if not portfolio:
+         return "", None
+     total_value_now = sum(v['value'] for v in portfolio.values())
+     allocations = {k: v['value'] / total_value_now for k, v in portfolio.items()} if total_value_now > 0 else {}
+     fig_alloc, ax_alloc = plt.subplots()
+     ax_alloc.pie(allocations.values(), labels=allocations.keys(), autopct='%1.1f%%')
+     ax_alloc.set_title('Portfolio Allocation')
+     buf_alloc = io.BytesIO()
+     fig_alloc.savefig(buf_alloc, format='png')
+     buf_alloc.seek(0)
+     chart_alloc = Image.open(buf_alloc)
+     plt.close(fig_alloc)  # Close the figure to free memory
+     def project_value(value, years, rate):
+         return value * (1 + rate / 100) ** years
+     total_value_1yr = sum(project_value(v['value'], 1, growth_rate) for v in portfolio.values())
+     total_value_2yr = sum(project_value(v['value'], 2, growth_rate) for v in portfolio.values())
+     total_value_5yr = sum(project_value(v['value'], 5, growth_rate) for v in portfolio.values())
+     total_value_10yr = sum(project_value(v['value'], 10, growth_rate) for v in portfolio.values())
+     data_str = (
+         "User portfolio:\n" +
+         "\n".join(f"- {k}: {v['shares']} shares, avg cost {v['cost']}, current price {v['price']}, value ${v['value']:,.2f}" for k, v in portfolio.items()) +
+         f"\nTotal value now: ${total_value_now:,.2f}\nProjected (at {growth_rate}% annual growth):\n" +
+         f"- 1 year: ${total_value_1yr:,.2f}\n- 2 years: ${total_value_2yr:,.2f}\n- 5 years: ${total_value_5yr:,.2f}\n- 10 years: ${total_value_10yr:,.2f}"
      )
+     return data_str, chart_alloc
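The figure-to-PIL round trip used for chart_alloc (in place of the base64 data URI the previous version emitted) is a small reusable pattern; a minimal form, assuming matplotlib and Pillow are installed:

    import io
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from PIL import Image

    fig, ax = plt.subplots()
    ax.pie([0.6, 0.4], labels=["AAA", "BBB"], autopct="%1.1f%%")
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    chart = Image.open(buf)  # a PIL image, usable in a Gradio chat message
    plt.close(fig)           # free the figure's memory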
+ def fetch_current_prices(df):
+     if df is None or len(df) == 0:
          return df
+     # Convert to DataFrame if needed
+     if not isinstance(df, pd.DataFrame):
+         df = pd.DataFrame(df, columns=["Ticker", "Shares", "Avg Cost", "Current Price"])
+     for i in df.index:
+         ticker = df.at[i, "Ticker"]
+         if pd.notna(ticker) and ticker.strip():
+             try:
+                 price = yf.Ticker(ticker.upper()).info.get('currentPrice', None)
+                 if price is not None:
+                     df.at[i, "Current Price"] = price
+             except Exception as e:
+                 logger.warning(f"Failed to fetch price for {ticker}: {str(e)}")
+     return df
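One caveat on the lookup above: Ticker().info triggers a full metadata fetch and 'currentPrice' is not guaranteed to be present; recent yfinance versions also expose a lighter fast_info. A hedged sketch of a fallback (fast_info.last_price is assumed from yfinance ≥ 0.2; verify against your installed version):

    import yfinance as yf

    t = yf.Ticker("AAPL")
    price = t.info.get("currentPrice")                    # what the app uses
    if price is None:
        price = getattr(t.fast_info, "last_price", None)  # cheaper fallback
    print(price)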
  # Gradio interface setup
+ with gr.Blocks(theme=themes.Soft(), css="""#chatbot {height: 800px; overflow: auto;}""") as demo:
      gr.Markdown(DESCRIPTION)
+     chatbot = gr.Chatbot(label="FinChat", type="messages")
+     msg = gr.Textbox(label="Ask a finance question", placeholder="e.g., 'What is CAGR?' or 'Average return for AAPL between 2010 and 2020'", info="Enter your query here. Portfolio data will be appended if provided.")
+     with gr.Row():
+         submit = gr.Button("Submit", variant="primary")
+         clear = gr.Button("Clear")
+     gr.Examples(
+         examples=["What is CAGR?", "Average return for AAPL between 2010 and 2020", "Hi", "Explain compound interest"],
+         inputs=msg,
+         label="Example Queries"
+     )
+     with gr.Accordion("Enter Portfolio for Projections", open=False):
+         portfolio_df = gr.Dataframe(
+             headers=["Ticker", "Shares", "Avg Cost", "Current Price"],
+             datatype=["str", "number", "number", "number"],
+             row_count=3,
+             col_count=(4, "fixed"),
+             label="Portfolio Data",
+             interactive=True
+         )
+         gr.Markdown("Enter your stocks here. You can add more rows by editing the table.")
+         fetch_button = gr.Button("Fetch Current Prices", variant="secondary")
+         fetch_button.click(fetch_current_prices, inputs=portfolio_df, outputs=portfolio_df)
+         growth_rate = gr.Slider(minimum=5, maximum=50, step=5, value=10, label="Annual Growth Rate (%)", interactive=True, info="Select the assumed annual growth rate for projections.")
+         growth_rate_label = gr.Markdown("**Selected Growth Rate: 10%**")
+     with gr.Accordion("Advanced Settings", open=False):
+         system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6, info="Customize the AI's system prompt.")
+         temperature = gr.Slider(label="Temperature", value=0.6, minimum=0.0, maximum=1.0, step=0.05, info="Controls randomness: lower is more deterministic.")
+         top_p = gr.Slider(label="Top P", value=0.9, minimum=0.0, maximum=1.0, step=0.05, info="Nucleus sampling: higher includes more diverse tokens.")
+         top_k = gr.Slider(label="Top K", value=50, minimum=1, maximum=100, step=1, info="Top-K sampling: limits to top K tokens.")
+         repetition_penalty = gr.Slider(label="Repetition Penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, info="Penalizes repeated tokens.")
+         max_new_tokens = gr.Slider(label="Max New Tokens", value=DEFAULT_MAX_NEW_TOKENS, minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, info="Maximum length of generated response.")
      gr.Markdown(LICENSE)
+     def update_growth_rate_label(growth_rate):
+         return f"**Selected Growth Rate: {growth_rate}%**"
+     def user(message, history):
+         if not message:
+             return "", history
+         return "", history + [{"role": "user", "content": message}]
+     def bot(history, sys_prompt, temp, tp, tk, rp, mnt, portfolio_df, growth_rate):
+         if not history:
+             logger.warning("History is empty, initializing with user message.")
+             history = [{"role": "user", "content": ""}]
+         message = history[-1]["content"]
+         portfolio_data, chart_alloc = process_portfolio(portfolio_df, growth_rate)
+         message += "\n" + portfolio_data
+         history[-1]["content"] = message
+         history.append({"role": "assistant", "content": ""})
+         for new_text in generate(message, history[:-1], sys_prompt, mnt, temp, tp, tk, rp):
+             history[-1]["content"] = new_text
+             yield history, f"**Selected Growth Rate: {growth_rate}%**"
+         if chart_alloc:
+             history.append({"role": "assistant", "content": "", "image": chart_alloc})
+             yield history, f"**Selected Growth Rate: {growth_rate}%**"
+     growth_rate.change(update_growth_rate_label, inputs=growth_rate, outputs=growth_rate_label)
+     submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot, [chatbot, system_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, portfolio_df, growth_rate], [chatbot, growth_rate_label]
+     )
+     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+         bot, [chatbot, system_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, portfolio_df, growth_rate], [chatbot, growth_rate_label]
      )
+     clear.click(lambda: [], None, chatbot, queue=False)
+ demo.queue(max_size=128).launch()