AnilNiraula committed on
Commit
df993b7
·
verified ·
1 Parent(s): 1d02fc4

Update app.py

Files changed (1)
  1. app.py +541 -231
app.py CHANGED
@@ -5,15 +5,23 @@ import re
5
  import multiprocessing
6
  import atexit
7
  from collections.abc import Iterator
8
  import gradio as gr
9
  import gradio.themes as themes
10
  from huggingface_hub import hf_hub_download, login
11
  import logging
12
  import pandas as pd
13
  import torch
14
- # Set up logging
15
- logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
 
17
  # Install llama-cpp-python with appropriate backend
18
  try:
19
  from llama_cpp import Llama
@@ -26,15 +34,15 @@ except ModuleNotFoundError:
26
  logger.info("Installing llama-cpp-python without additional flags.")
27
  subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-cpp-python", "--force-reinstall", "--upgrade", "--no-cache-dir"])
28
  from llama_cpp import Llama
29
- # Install yfinance if not present (for CAGR calculations)
 
30
  try:
31
  import yfinance as yf
32
  except ModuleNotFoundError:
33
  subprocess.check_call([sys.executable, "-m", "pip", "install", "yfinance"])
34
  import yfinance as yf
35
- # Import pandas for handling DataFrame column structures
36
- import pandas as pd
37
- # Additional imports for visualization and file handling
38
  try:
39
  import matplotlib.pyplot as plt
40
  from PIL import Image
@@ -44,16 +52,58 @@ except ModuleNotFoundError:
44
  import matplotlib.pyplot as plt
45
  from PIL import Image
46
  import io
 
 
47
  MAX_MAX_NEW_TOKENS = 512
48
- DEFAULT_MAX_NEW_TOKENS = 512
49
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "1024"))
50
  DESCRIPTION = """# FinChat: Investing Q&A (Optimized for Speed)
51
- This application delivers an interactive chat interface powered by a highly efficient, small AI model adapted for addressing investing and finance inquiries through specialized prompt engineering. It ensures rapid, reasoned responses to user queries. Duplicate this Space for customization or queue-free deployment.
52
- <p>Running on CPU or GPU if available. Using Phi-2 model for faster inference. Inference is heavily optimized for responses in under 10 seconds for simple queries, with output limited to 250 tokens maximum. For longer responses, increase 'Max New Tokens' in Advanced Settings. Brief delays may occur in free-tier environments due to shared resources, but typical generation speeds are improved with the smaller model.</p>"""
 
53
  LICENSE = """<p/>
54
  ---
55
  This application employs the Phi-2 model, governed by Microsoft's Terms of Use. Refer to the [model card](https://huggingface.co/TheBloke/phi-2-GGUF) for details."""
56
- # Load the model (skip fine-tuning for faster startup)
57
  try:
58
  model_path = hf_hub_download(
59
  repo_id="TheBloke/phi-2-GGUF",
@@ -63,37 +113,143 @@ try:
63
  llm = Llama(
64
  model_path=model_path,
65
  n_ctx=1024,
66
- n_batch=1024, # Increased for faster processing
67
  n_threads=multiprocessing.cpu_count(),
68
  n_gpu_layers=n_gpu_layers,
69
- chat_format="chatml" # Phi-2 uses ChatML format in llama.cpp
70
  )
71
  logger.info(f"Model loaded successfully with n_gpu_layers={n_gpu_layers}.")
72
- # Warm up the model for faster initial inference
 
73
  llm("Warm-up prompt", max_tokens=1, echo=False)
74
  logger.info("Model warm-up completed.")
75
  except Exception as e:
76
  logger.error(f"Error loading model: {str(e)}")
77
  raise
78
- # Register explicit close for llm to avoid destructor error
79
  atexit.register(llm.close)
80
- DEFAULT_SYSTEM_PROMPT = """You are FinChat, a knowledgeable AI assistant specializing in investing and finance. Provide accurate, helpful, reasoned, and concise answers to investing questions. Always base responses on reliable information and advise users to consult professionals for personalized advice.
81
- Always respond exclusively in English. Use bullet points for clarity.
82
- Example:
83
- User: average return for TSLA between 2010 and 2020
84
- Assistant:
85
- - TSLA CAGR (2010-2020): ~63.01%
86
- - Represents average annual return with compounding
87
- - Past performance not indicative of future results
88
- - Consult a financial advisor"""
89
- # Company name to ticker mapping (expand as needed)
90
- COMPANY_TO_TICKER = {
91
- "opendoor": "OPEN",
92
- "tesla": "TSLA",
93
- "apple": "AAPL",
94
- "amazon": "AMZN",
95
- # Add more mappings for common companies
96
- }
97
  def generate(
98
  message: str,
99
  chat_history: list[dict],
@@ -104,128 +260,143 @@ def generate(
104
  top_k: int = 50,
105
  repetition_penalty: float = 1.2,
106
  ) -> Iterator[str]:
107
- logger.info(f"Generating response for message: {message}")
 
 
108
  lower_message = message.lower().strip()
109
- if lower_message in ["hi", "hello"]:
110
- response = "I'm FinChat, your financial advisor. Ask me anything finance-related!"
111
- logger.info("Quick response for 'hi'/'hello' generated.")
 
 
112
  yield response
113
  return
 
114
  if "what is cagr" in lower_message:
115
  response = """- CAGR stands for Compound Annual Growth Rate.
116
  - It measures the mean annual growth rate of an investment over a specified period longer than one year, accounting for compounding.
117
  - Formula: CAGR = (Ending Value / Beginning Value)^(1 / Number of Years) - 1
118
  - Useful for comparing investments over time.
119
- - Past performance not indicative of future results. Consult a financial advisor."""
120
- logger.info("Quick response for 'what is cagr' generated.")
121
  yield response
122
  return
123
- # Check for compound interest queries
124
- compound_match = re.search(r'(?:save|invest|deposit)\s*\$?([\d,]+(?:\.\d+)?)\s*(?:right now|today)?\s*(?:under|at)\s*([\d.]+)%\s*(?:interest|rate)?\s*(?:annually|per year)?\s*over\s*(\d+)\s*years?', lower_message)
 
125
  if compound_match:
126
  try:
127
  principal_str = compound_match.group(1).replace(',', '')
128
  principal = float(principal_str)
129
  rate = float(compound_match.group(2)) / 100
130
  years = int(compound_match.group(3))
131
- if principal <= 0 or rate < 0 or years <= 0:
132
- yield "Invalid input values: principal, rate, and years must be positive."
133
  return
 
134
  balance = principal * (1 + rate) ** years
 
 
135
  response = (
136
- f"- Starting with ${principal:,.2f} at {rate*100:.2f}% annual interest, compounded annually over {years} years.\n"
137
- f"- Projected balance in year {years}: ${balance:,.2f}\n"
138
- f"- Assumptions: Annual compounding; no additional deposits or withdrawals.\n"
139
- f"- This is a projection; actual results may vary. Consult a financial advisor."
140
  )
141
- logger.info("Compound interest response generated.")
142
  yield response
143
  return
144
  except ValueError as ve:
145
  logger.error(f"Error parsing compound interest query: {str(ve)}")
146
- yield "Error parsing query: Please ensure amount, rate, and years are valid numbers."
147
  return
148
- # Check for CAGR/average return queries (use re.search for flexible matching)
149
- match = re.search(r'(?:average return|cagr) for ([\w\s,]+(?:and [\w\s,]+)?) between (\d{4}) and (\d{4})', lower_message)
150
- if match:
151
- tickers_str, start_year, end_year = match.groups()
152
- tickers = [t.strip().upper() for t in re.split(r',|\band\b', tickers_str) if t.strip()]
153
  # Apply company-to-ticker mapping
154
- for i in range(len(tickers)):
155
- lower_ticker = tickers[i].lower()
 
156
  if lower_ticker in COMPANY_TO_TICKER:
157
- tickers[i] = COMPANY_TO_TICKER[lower_ticker]
158
- responses = []
159
- if int(end_year) <= int(start_year):
160
- yield "The specified time period is invalid (end year must be after start year)."
161
  return
162
- for ticker in tickers:
163
- try:
164
- # Download data with adjusted close prices
165
- data = yf.download(ticker, start=f"{start_year}-01-01", end=f"{end_year}-12-31", progress=False, auto_adjust=False)
166
- # Handle potential MultiIndex columns in newer yfinance versions
167
- if isinstance(data.columns, pd.MultiIndex):
168
- data.columns = data.columns.droplevel(1)
169
- if not data.empty:
170
- # Check if 'Adj Close' column exists
171
- if 'Adj Close' not in data.columns:
172
- responses.append(f"- {ticker}: Error - Adjusted Close price data not available.")
173
- logger.error(f"No 'Adj Close' column for {ticker}.")
174
- continue
175
- # Ensure data is not MultiIndex for single ticker (already handled)
176
- initial = data['Adj Close'].iloc[0]
177
- final = data['Adj Close'].iloc[-1]
178
- start_date = data.index[0]
179
- end_date = data.index[-1]
180
- days = (end_date - start_date).days
181
- years = days / 365.25
182
- if years > 0 and pd.notna(initial) and pd.notna(final):
183
- cagr = ((final / initial) ** (1 / years) - 1) * 100
184
- responses.append(f"- {ticker}: ~{cagr:.2f}%")
185
- else:
186
- responses.append(f"- {ticker}: Invalid period or missing price data.")
187
- else:
188
- responses.append(f"- {ticker}: No historical data available between {start_year} and {end_year}.")
189
- except Exception as e:
190
- logger.error(f"Error calculating CAGR for {ticker}: {str(e)}")
191
- responses.append(f"- {ticker}: Error calculating CAGR - {str(e)}")
192
- full_response = f"CAGR for the requested stocks from {start_year} to {end_year}:\n" + "\n".join(responses) + "\n- Represents average annual returns with compounding\n- Past performance not indicative of future results\n- Consult a financial advisor"
193
- full_response = re.sub(r'<\|(?:im_start|im_end|system|user|assistant)\|>|</s>|\[END\]', '', full_response).strip() # Clean any trailing tokens
194
- # Estimate token count to ensure response fits within max_new_tokens
195
- response_tokens = len(llm.tokenize(full_response.encode("utf-8"), add_bos=False))
196
- if response_tokens > max_new_tokens:
197
- logger.warning(f"CAGR response tokens ({response_tokens}) exceed max_new_tokens ({max_new_tokens}). Truncating to first complete sentence.")
198
- sentence_endings = ['.', '!', '?']
199
- first_sentence_end = min([full_response.find(ending) + 1 for ending in sentence_endings if full_response.find(ending) != -1], default=len(full_response))
200
- full_response = full_response[:first_sentence_end] if first_sentence_end > 0 else "Response truncated due to length; please increase Max New Tokens."
201
- logger.info("CAGR response generated.")
202
  yield full_response
203
  return
204
- # Build conversation messages (limit history to last 3 for speed)
 
205
  conversation = [{"role": "system", "content": system_prompt}]
206
- for msg in chat_history[-3:]: # Reduced from 5 to 3 for faster processing
207
- if msg["role"] == "user":
208
- conversation.append({"role": "user", "content": msg["content"]})
209
- elif msg["role"] == "assistant":
210
- conversation.append({"role": "assistant", "content": msg["content"]})
211
  conversation.append({"role": "user", "content": message})
212
- # Approximate token length check and truncate if necessary
 
213
  prompt_text = "\n".join(d["content"] for d in conversation)
214
  input_tokens = llm.tokenize(prompt_text.encode("utf-8"), add_bos=False)
 
215
  while len(input_tokens) > MAX_INPUT_TOKEN_LENGTH:
216
- logger.warning(f"Input tokens ({len(input_tokens)}) exceed limit ({MAX_INPUT_TOKEN_LENGTH}). Truncating history.")
217
- if len(conversation) > 2: # Preserve system prompt and current user message
218
- conversation.pop(1) # Remove oldest user/assistant pair
219
  prompt_text = "\n".join(d["content"] for d in conversation)
220
  input_tokens = llm.tokenize(prompt_text.encode("utf-8"), add_bos=False)
221
  else:
222
- yield "Error: Input is too long even after truncation. Please shorten your query."
223
  return
224
- # Generate response with sentence boundary checking and token cleanup
 
225
  try:
226
  response = ""
227
  sentence_buffer = ""
228
- token_count = 0
229
  stream = llm.create_chat_completion(
230
  messages=conversation,
231
  max_tokens=max_new_tokens,
@@ -235,166 +406,305 @@ def generate(
235
  repeat_penalty=repetition_penalty,
236
  stream=True
237
  )
 
238
  sentence_endings = ['.', '!', '?']
 
239
  for chunk in stream:
240
  delta = chunk["choices"][0]["delta"]
241
  if "content" in delta and delta["content"] is not None:
242
- # Clean the chunk by removing ChatML tokens or similar
243
- cleaned_chunk = re.sub(r'<\|(?:im_start|im_end|system|user|assistant)\|>|</s>|\[END\]', '', delta["content"])
244
  if not cleaned_chunk:
245
  continue
 
246
  sentence_buffer += cleaned_chunk
247
  response += cleaned_chunk
248
- # Approximate token count for the chunk
249
- chunk_tokens = len(llm.tokenize(cleaned_chunk.encode("utf-8"), add_bos=False))
250
- token_count += chunk_tokens
251
- # Check for sentence boundary
252
  if any(sentence_buffer.strip().endswith(ending) for ending in sentence_endings):
253
  yield response
254
- sentence_buffer = "" # Clear buffer after yielding a complete sentence
255
- # Removed early truncation to allow full token utilization
256
  if chunk["choices"][0]["finish_reason"] is not None:
257
- # Yield any remaining complete sentence in the buffer
258
  if sentence_buffer.strip():
259
- last_sentence_end = max([sentence_buffer.rfind(ending) for ending in sentence_endings if sentence_buffer.rfind(ending) != -1], default=-1)
260
- if last_sentence_end != -1:
261
- response = response[:response.rfind(sentence_buffer) + last_sentence_end + 1]
262
- yield response
263
- else:
264
- yield response
265
- else:
266
  yield response
267
  break
268
- logger.info("Response generation completed.")
269
  except ValueError as ve:
270
  if "exceed context window" in str(ve):
271
- yield "Error: Prompt too long for context window. Please try a shorter query or clear history."
272
  else:
273
- logger.error(f"Error during response generation: {str(ve)}")
274
- yield f"Error generating response: {str(ve)}"
275
  except Exception as e:
276
  logger.error(f"Error during response generation: {str(e)}")
277
- yield f"Error generating response: {str(e)}"
 
278
  def process_portfolio(df, growth_rate):
 
279
  if df is None or len(df) == 0:
280
  return "", None
281
- # Convert to DataFrame if needed
282
- if not isinstance(df, pd.DataFrame):
283
- df = pd.DataFrame(df, columns=["Ticker", "Shares", "Avg Cost", "Current Price"])
284
- df = df.dropna(subset=["Ticker"])
285
- portfolio = {}
286
- for _, row in df.iterrows():
287
- ticker = row["Ticker"].upper() if pd.notna(row["Ticker"]) else None
288
- if not ticker:
289
- continue
290
- shares = float(row["Shares"]) if pd.notna(row["Shares"]) else 0
291
- cost = float(row["Avg Cost"]) if pd.notna(row["Avg Cost"]) else 0
292
- price = float(row["Current Price"]) if pd.notna(row["Current Price"]) else 0
293
- value = shares * price
294
- portfolio[ticker] = {'shares': shares, 'cost': cost, 'price': price, 'value': value}
295
- if not portfolio:
296
- return "", None
297
- total_value_now = sum(v['value'] for v in portfolio.values())
298
- allocations = {k: v['value'] / total_value_now for k, v in portfolio.items()} if total_value_now > 0 else {}
299
- fig_alloc, ax_alloc = plt.subplots()
300
- ax_alloc.pie(allocations.values(), labels=allocations.keys(), autopct='%1.1f%%')
301
- ax_alloc.set_title('Portfolio Allocation')
302
- buf_alloc = io.BytesIO()
303
- fig_alloc.savefig(buf_alloc, format='png')
304
- buf_alloc.seek(0)
305
- chart_alloc = Image.open(buf_alloc)
306
- plt.close(fig_alloc) # Close the figure to free memory
307
- def project_value(value, years, rate):
308
- return value * (1 + rate / 100) ** years
309
- total_value_1yr = sum(project_value(v['value'], 1, growth_rate) for v in portfolio.values())
310
- total_value_2yr = sum(project_value(v['value'], 2, growth_rate) for v in portfolio.values())
311
- total_value_5yr = sum(project_value(v['value'], 5, growth_rate) for v in portfolio.values())
312
- total_value_10yr = sum(project_value(v['value'], 10, growth_rate) for v in portfolio.values())
313
- data_str = (
314
- "User portfolio:\n" +
315
- "\n".join(f"- {k}: {v['shares']} shares, avg cost {v['cost']}, current price {v['price']}, value ${v['value']:,.2f}" for k, v in portfolio.items()) +
316
- f"\nTotal value now: ${total_value_now:,.2f}\nProjected (at {growth_rate}% annual growth):\n" +
317
- f"- 1 year: ${total_value_1yr:,.2f}\n- 2 years: ${total_value_2yr:,.2f}\n- 5 years: ${total_value_5yr:,.2f}\n- 10 years: ${total_value_10yr:,.2f}"
318
- )
319
- return data_str, chart_alloc
320
  def fetch_current_prices(df):
 
321
  if df is None or len(df) == 0:
322
  return df
323
- # Convert to DataFrame if needed
324
- if not isinstance(df, pd.DataFrame):
325
- df = pd.DataFrame(df, columns=["Ticker", "Shares", "Avg Cost", "Current Price"])
326
- for i in df.index:
327
- ticker = df.at[i, "Ticker"]
328
- if pd.notna(ticker) and ticker.strip():
329
- try:
330
- price = yf.Ticker(ticker.upper()).info.get('currentPrice', None)
331
- if price is not None:
332
- df.at[i, "Current Price"] = price
333
- except Exception as e:
334
- logger.warning(f"Failed to fetch price for {ticker}: {str(e)}")
335
- return df
336
- # Gradio interface setup
337
- with gr.Blocks(theme=themes.Soft(), css="""#chatbot {height: 800px; overflow: auto;}""") as demo:
 
338
  gr.Markdown(DESCRIPTION)
 
339
  chatbot = gr.Chatbot(label="FinChat", type="messages")
340
- msg = gr.Textbox(label="Ask a finance question", placeholder="e.g., 'What is CAGR?' or 'Average return for AAPL between 2010 and 2020'", info="Enter your query here. Portfolio data will be appended if provided.")
341
  with gr.Row():
342
  submit = gr.Button("Submit", variant="primary")
343
  clear = gr.Button("Clear")
 
344
  gr.Examples(
345
- examples=["What is CAGR?", "Average return for AAPL between 2010 and 2020", "Hi", "Explain compound interest"],
  inputs=msg,
347
  label="Example Queries"
348
  )
349
- with gr.Accordion("Enter Portfolio for Projections", open=False):
 
350
  portfolio_df = gr.Dataframe(
351
  headers=["Ticker", "Shares", "Avg Cost", "Current Price"],
352
  datatype=["str", "number", "number", "number"],
353
- row_count=3,
354
  col_count=(4, "fixed"),
355
  label="Portfolio Data",
356
  interactive=True
357
  )
358
- gr.Markdown("Enter your stocks here. You can add more rows by editing the table.")
359
- fetch_button = gr.Button("Fetch Current Prices", variant="secondary")
 
  fetch_button.click(fetch_current_prices, inputs=portfolio_df, outputs=portfolio_df)
361
- growth_rate = gr.Slider(minimum=5, maximum=50, step=5, value=10, label="Annual Growth Rate (%)", interactive=True, info="Select the assumed annual growth rate for projections.")
 
  growth_rate_label = gr.Markdown("**Selected Growth Rate: 10%**")
363
- with gr.Accordion("Advanced Settings", open=False):
364
- system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=6, info="Customize the AI's system prompt.")
365
- temperature = gr.Slider(label="Temperature", value=0.6, minimum=0.0, maximum=1.0, step=0.05, info="Controls randomness: lower is more deterministic.")
366
- top_p = gr.Slider(label="Top P", value=0.9, minimum=0.0, maximum=1.0, step=0.05, info="Nucleus sampling: higher includes more diverse tokens.")
367
- top_k = gr.Slider(label="Top K", value=50, minimum=1, maximum=100, step=1, info="Top-K sampling: limits to top K tokens.")
368
- repetition_penalty = gr.Slider(label="Repetition Penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, info="Penalizes repeated tokens.")
369
- max_new_tokens = gr.Slider(label="Max New Tokens", value=DEFAULT_MAX_NEW_TOKENS, minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, info="Maximum length of generated response.")
370
- gr.Markdown(LICENSE)
371
- def update_growth_rate_label(growth_rate):
372
- return f"**Selected Growth Rate: {growth_rate}%**"
373
- def user(message, history):
374
- if not message:
375
- return "", history
376
- return "", history + [{"role": "user", "content": message}]
377
- def bot(history, sys_prompt, temp, tp, tk, rp, mnt, portfolio_df, growth_rate):
378
- if not history:
379
- logger.warning("History is empty, initializing with user message.")
380
- history = [{"role": "user", "content": ""}]
381
- message = history[-1]["content"]
382
- portfolio_data, chart_alloc = process_portfolio(portfolio_df, growth_rate)
383
- message += "\n" + portfolio_data
384
- history[-1]["content"] = message
385
- history.append({"role": "assistant", "content": ""})
386
- for new_text in generate(message, history[:-1], sys_prompt, mnt, temp, tp, tk, rp):
387
- history[-1]["content"] = new_text
388
- yield history, f"**Selected Growth Rate: {growth_rate}%**"
389
- if chart_alloc:
390
- history.append({"role": "assistant", "content": "", "image": chart_alloc})
391
- yield history, f"**Selected Growth Rate: {growth_rate}%**"
392
- growth_rate.change(update_growth_rate_label, inputs=growth_rate, outputs=growth_rate_label)
393
- submit.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
394
- bot, [chatbot, system_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, portfolio_df, growth_rate], [chatbot, growth_rate_label]
395
- )
396
- msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
397
- bot, [chatbot, system_prompt, temperature, top_p, top_k, repetition_penalty, max_new_tokens, portfolio_df, growth_rate], [chatbot, growth_rate_label]
398
- )
399
- clear.click(lambda: [], None, chatbot, queue=False)
400
- demo.queue(max_size=128).launch()
 
5
  import multiprocessing
6
  import atexit
7
  from collections.abc import Iterator
8
+ from functools import lru_cache
9
+ import datetime
10
+ import time
11
  import gradio as gr
12
  import gradio.themes as themes
13
  from huggingface_hub import hf_hub_download, login
14
  import logging
15
  import pandas as pd
16
  import torch
17
+
18
+ # Set up logging with more detail
19
+ logging.basicConfig(
20
+ level=logging.INFO,
21
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
22
+ )
23
  logger = logging.getLogger(__name__)
24
+
25
  # Install llama-cpp-python with appropriate backend
26
  try:
27
  from llama_cpp import Llama
 
34
  logger.info("Installing llama-cpp-python without additional flags.")
35
  subprocess.check_call([sys.executable, "-m", "pip", "install", "llama-cpp-python", "--force-reinstall", "--upgrade", "--no-cache-dir"])
36
  from llama_cpp import Llama
37
+
38
+ # Install yfinance if not present
39
  try:
40
  import yfinance as yf
41
  except ModuleNotFoundError:
42
  subprocess.check_call([sys.executable, "-m", "pip", "install", "yfinance"])
43
  import yfinance as yf
44
+
45
+ # Additional imports for visualization
 
46
  try:
47
  import matplotlib.pyplot as plt
48
  from PIL import Image
 
52
  import matplotlib.pyplot as plt
53
  from PIL import Image
54
  import io
55
+
56
+ # Constants
57
  MAX_MAX_NEW_TOKENS = 512
58
+ DEFAULT_MAX_NEW_TOKENS = 256
59
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "1024"))
60
+ YFINANCE_TIMEOUT = 10
61
+ CACHE_EXPIRY_HOURS = 1
62
+
63
  DESCRIPTION = """# FinChat: Investing Q&A (Optimized for Speed)
64
+ This application delivers an interactive chat interface powered by a highly efficient, small AI model adapted for addressing investing and finance inquiries through specialized prompt engineering. It ensures rapid, reasoned responses to user queries.
65
+ <p>Running on CPU or GPU if available. Using Phi-2 model for faster inference. Includes performance optimizations with caching and improved error handling.</p>"""
66
+
67
  LICENSE = """<p/>
68
  ---
69
  This application employs the Phi-2 model, governed by Microsoft's Terms of Use. Refer to the [model card](https://huggingface.co/TheBloke/phi-2-GGUF) for details."""
70
+
71
+ DEFAULT_SYSTEM_PROMPT = """You are FinChat, a knowledgeable AI assistant specializing in investing and finance. Provide accurate, helpful, reasoned, and concise answers to investing questions. Always base responses on reliable information and advise users to consult professionals for personalized advice.
72
+ Always respond exclusively in English. Use bullet points for clarity.
73
+ Example:
74
+ User: average return for TSLA between 2010 and 2020
75
+ Assistant:
76
+ - TSLA CAGR (2010-2020): ~63.01%
77
+ - Represents average annual return with compounding
78
+ - Past performance not indicative of future results
79
+ - Consult a financial advisor"""
80
+
81
+ # Company name to ticker mapping
82
+ COMPANY_TO_TICKER = {
83
+ "opendoor": "OPEN",
84
+ "tesla": "TSLA",
85
+ "apple": "AAPL",
86
+ "amazon": "AMZN",
87
+ "microsoft": "MSFT",
88
+ "google": "GOOGL",
89
+ "facebook": "META",
90
+ "meta": "META",
91
+ "nvidia": "NVDA",
92
+ "netflix": "NFLX",
93
+ }
94
+
95
+ # Compiled regex patterns for better performance
96
+ CAGR_PATTERN = re.compile(
97
+ r'(?:average\s+return|cagr)\s+(?:for\s+)?([\w\s,]+(?:and\s+[\w\s,]+)?)\s+(?:between|from)\s+(\d{4})\s+(?:and|to)\s+(\d{4})',
98
+ re.IGNORECASE
99
+ )
100
+
101
+ COMPOUND_INTEREST_PATTERN = re.compile(
102
+ r'(?:save|invest|deposit)\s*\$?([\d,]+(?:\.\d+)?)\s*(?:right now|today)?\s*(?:under|at)\s*([\d.]+)%\s*(?:interest|rate)?\s*(?:annually|per year)?\s*(?:over|for)\s*(\d+)\s*years?',
103
+ re.IGNORECASE
104
+ )
105
+
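As an aside (not part of the commit), a minimal check of what the two compiled patterns above are meant to capture; the sample queries are invented:

import re

CAGR_PATTERN = re.compile(
    r'(?:average\s+return|cagr)\s+(?:for\s+)?([\w\s,]+(?:and\s+[\w\s,]+)?)\s+(?:between|from)\s+(\d{4})\s+(?:and|to)\s+(\d{4})',
    re.IGNORECASE,
)
COMPOUND_INTEREST_PATTERN = re.compile(
    r'(?:save|invest|deposit)\s*\$?([\d,]+(?:\.\d+)?)\s*(?:right now|today)?\s*(?:under|at)\s*([\d.]+)%\s*(?:interest|rate)?\s*(?:annually|per year)?\s*(?:over|for)\s*(\d+)\s*years?',
    re.IGNORECASE,
)

# Hypothetical user queries and the groups the handlers receive:
print(CAGR_PATTERN.search("average return for AAPL and MSFT between 2015 and 2023").groups())
# ('AAPL and MSFT', '2015', '2023')
print(COMPOUND_INTEREST_PATTERN.search("if i save $10,000 at 5% interest over 10 years").groups())
# ('10,000', '5', '10')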
106
+ # Load the model
107
  try:
108
  model_path = hf_hub_download(
109
  repo_id="TheBloke/phi-2-GGUF",
 
113
  llm = Llama(
114
  model_path=model_path,
115
  n_ctx=1024,
116
+ n_batch=1024,
117
  n_threads=multiprocessing.cpu_count(),
118
  n_gpu_layers=n_gpu_layers,
119
+ chat_format="chatml"
120
  )
121
  logger.info(f"Model loaded successfully with n_gpu_layers={n_gpu_layers}.")
122
+
123
+ # Warm up the model
124
  llm("Warm-up prompt", max_tokens=1, echo=False)
125
  logger.info("Model warm-up completed.")
126
  except Exception as e:
127
  logger.error(f"Error loading model: {str(e)}")
128
  raise
129
+
130
  atexit.register(llm.close)
131
+
132
+ # Cache for stock data with timestamp
133
+ _stock_cache = {}
134
+ _cache_timestamps = {}
135
+
136
+ def sanitize_ticker(ticker):
137
+ """Sanitize ticker input to prevent injection and validate format"""
138
+ if not ticker or not isinstance(ticker, str):
139
+ return None
140
+ cleaned = re.sub(r'[^A-Z0-9.-]', '', ticker.upper().strip())
141
+ if len(cleaned) > 10 or len(cleaned) == 0:
142
+ return None
143
+ return cleaned
144
+
145
+ def validate_year_range(start_year, end_year):
146
+ """Validate year inputs"""
147
+ try:
148
+ start = int(start_year)
149
+ end = int(end_year)
150
+ current_year = datetime.datetime.now().year
151
+
152
+ if not (1900 <= start <= current_year):
153
+ return False, f"Start year must be between 1900 and {current_year}"
154
+ if not (1900 <= end <= current_year + 1):
155
+ return False, f"End year must be between 1900 and {current_year + 1}"
156
+ if end <= start:
157
+ return False, "End year must be after start year"
158
+
159
+ return True, "Valid"
160
+ except ValueError:
161
+ return False, "Years must be valid integers"
162
+
163
+ @lru_cache(maxsize=100)
164
+ def get_stock_data_cached(ticker, start_date, end_date, cache_key):
165
+ """Cache stock data to avoid repeated API calls. cache_key includes hour for expiry."""
166
+ try:
167
+ logger.info(f"Fetching data for {ticker} from {start_date} to {end_date}")
168
+ data = yf.download(
169
+ ticker,
170
+ start=start_date,
171
+ end=end_date,
172
+ progress=False,
173
+ auto_adjust=False,
174
+ timeout=YFINANCE_TIMEOUT
175
+ )
176
+ return data
177
+ except Exception as e:
178
+ logger.error(f"Error fetching data for {ticker}: {str(e)}")
179
+ return None
180
+
181
+ def get_current_cache_key():
182
+ """Generate cache key that expires every hour"""
183
+ now = datetime.datetime.now()
184
+ return f"{now.year}{now.month}{now.day}{now.hour}"
185
+
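The cache_key argument is what gives the lru_cache entries an effective one-hour lifetime: the key changes every hour, so the next hour produces a cache miss and a fresh download. A stripped-down sketch of the same idiom, with a hypothetical fetch function standing in for the yfinance call:

import datetime
from functools import lru_cache

def hour_bucket() -> str:
    # Key that changes once per hour, forcing a cache miss after the hour rolls over.
    now = datetime.datetime.now()
    return f"{now.year}{now.month}{now.day}{now.hour}"

@lru_cache(maxsize=100)
def fetch_quote(symbol: str, bucket: str) -> float:
    # Placeholder for an expensive network call; the value is made up.
    print(f"fetching {symbol} for bucket {bucket}")
    return 123.45

fetch_quote("AAPL", hour_bucket())  # first call in the hour: fetches
fetch_quote("AAPL", hour_bucket())  # same hour: served from the cache

Old buckets stay in the cache until maxsize evicts them, but they are never looked up again, so the hourly key behaves like a simple expiry.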
186
+ def calculate_cagr(ticker, start_year, end_year):
187
+ """Calculate CAGR for a ticker with error handling"""
188
+ try:
189
+ # Sanitize ticker
190
+ clean_ticker = sanitize_ticker(ticker)
191
+ if not clean_ticker:
192
+ return f"- {ticker}: Invalid ticker format"
193
+
194
+ # Validate years
195
+ valid, msg = validate_year_range(start_year, end_year)
196
+ if not valid:
197
+ return f"- {clean_ticker}: {msg}"
198
+
199
+ # Get cached data
200
+ cache_key = get_current_cache_key()
201
+ data = get_stock_data_cached(
202
+ clean_ticker,
203
+ f"{start_year}-01-01",
204
+ f"{end_year}-12-31",
205
+ cache_key
206
+ )
207
+
208
+ if data is None:
209
+ return f"- {clean_ticker}: Error fetching data (API timeout or network issue)"
210
+
211
+ if data.empty:
212
+ return f"- {clean_ticker}: No historical data available between {start_year} and {end_year}"
213
+
214
+ # Handle MultiIndex columns
215
+ if isinstance(data.columns, pd.MultiIndex):
216
+ data.columns = data.columns.droplevel(1)
217
+
218
+ # Check for Adj Close column
219
+ if 'Adj Close' not in data.columns:
220
+ return f"- {clean_ticker}: Adjusted Close price data not available"
221
+
222
+ # Calculate CAGR
223
+ initial = data['Adj Close'].iloc[0]
224
+ final = data['Adj Close'].iloc[-1]
225
+
226
+ if pd.isna(initial) or pd.isna(final):
227
+ return f"- {clean_ticker}: Missing price data"
228
+
229
+ if initial <= 0 or final <= 0:
230
+ return f"- {clean_ticker}: Invalid price data (negative or zero values)"
231
+
232
+ start_date = data.index[0]
233
+ end_date = data.index[-1]
234
+ days = (end_date - start_date).days
235
+ years = days / 365.25
236
+
237
+ if years <= 0:
238
+ return f"- {clean_ticker}: Invalid time period"
239
+
240
+ cagr = ((final / initial) ** (1 / years) - 1) * 100
241
+
242
+ # Add context about data quality
243
+ actual_start = start_date.strftime('%Y-%m-%d')
244
+ actual_end = end_date.strftime('%Y-%m-%d')
245
+ date_note = f" (data: {actual_start} to {actual_end})" if actual_start != f"{start_year}-01-01" else ""
246
+
247
+ return f"- {clean_ticker}: ~{cagr:.2f}%{date_note}"
248
+
249
+ except Exception as e:
250
+ logger.error(f"Unexpected error calculating CAGR for {ticker}: {str(e)}")
251
+ return f"- {ticker}: Calculation error - {str(e)}"
252
+
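For reference, the arithmetic calculate_cagr performs, worked through with made-up prices rather than real market data:

# Hypothetical adjusted closes: $50.00 at the start date, $180.00 at the end date.
initial, final = 50.0, 180.0
days = 2919                      # roughly eight years of calendar days (illustrative)
years = days / 365.25            # ~7.99
cagr = ((final / initial) ** (1 / years) - 1) * 100
print(f"~{cagr:.2f}%")           # about 17.4% per year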
253
  def generate(
254
  message: str,
255
  chat_history: list[dict],
 
260
  top_k: int = 50,
261
  repetition_penalty: float = 1.2,
262
  ) -> Iterator[str]:
263
+ start_time = time.time()
264
+ logger.info(f"Generating response for message: {message[:100]}...")
265
+
266
  lower_message = message.lower().strip()
267
+
268
+ # Quick responses for common queries
269
+ if lower_message in ["hi", "hello", "hey"]:
270
+ response = "Hello! I'm FinChat, your financial advisor. Ask me anything about investing, stocks, CAGR, compound interest, or portfolio analysis!"
271
+ logger.info(f"Quick response generated in {time.time() - start_time:.2f}s")
272
  yield response
273
  return
274
+
275
  if "what is cagr" in lower_message:
276
  response = """- CAGR stands for Compound Annual Growth Rate.
277
  - It measures the mean annual growth rate of an investment over a specified period longer than one year, accounting for compounding.
278
  - Formula: CAGR = (Ending Value / Beginning Value)^(1 / Number of Years) - 1
279
  - Useful for comparing investments over time.
280
+ - Past performance is not indicative of future results. Consult a financial advisor."""
281
+ logger.info(f"Quick response generated in {time.time() - start_time:.2f}s")
282
  yield response
283
  return
284
+
285
+ # Compound interest calculation
286
+ compound_match = COMPOUND_INTEREST_PATTERN.search(lower_message)
287
  if compound_match:
288
  try:
289
  principal_str = compound_match.group(1).replace(',', '')
290
  principal = float(principal_str)
291
  rate = float(compound_match.group(2)) / 100
292
  years = int(compound_match.group(3))
293
+
294
+ if principal <= 0:
295
+ yield "Error: Principal amount must be positive."
296
+ return
297
+ if rate < 0 or rate > 1:
298
+ yield "Error: Interest rate must be between 0% and 100%."
299
+ return
300
+ if years <= 0 or years > 100:
301
+ yield "Error: Years must be between 1 and 100."
302
  return
303
+
304
  balance = principal * (1 + rate) ** years
305
+ total_interest = balance - principal
306
+
307
  response = (
308
+ f"**Compound Interest Calculation**\n\n"
309
+ f"- Starting Principal: ${principal:,.2f}\n"
310
+ f"- Annual Interest Rate: {rate*100:.2f}%\n"
311
+ f"- Time Period: {years} years\n"
312
+ f"- Compounding: Annually\n\n"
313
+ f"**Results:**\n"
314
+ f"- Final Balance (Year {years}): ${balance:,.2f}\n"
315
+ f"- Total Interest Earned: ${total_interest:,.2f}\n"
316
+ f"- Total Growth: {((balance/principal - 1) * 100):.2f}%\n\n"
317
+ f"*Note: This assumes annual compounding with no additional deposits or withdrawals. Actual results may vary. Consult a financial advisor.*"
318
  )
319
+ logger.info(f"Compound interest calculated in {time.time() - start_time:.2f}s")
320
  yield response
321
  return
322
  except ValueError as ve:
323
  logger.error(f"Error parsing compound interest query: {str(ve)}")
324
+ yield "Error: Please ensure amount, rate, and years are valid numbers. Example: 'If I save $10000 at 5% interest over 10 years'"
325
  return
326
+ except Exception as e:
327
+ logger.error(f"Unexpected error in compound interest: {str(e)}")
328
+ yield f"Error calculating compound interest: {str(e)}"
329
+ return
330
+
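To sanity-check the formula used above, the example query from the Examples list ("If I save $10000 at 5% interest over 10 years") works out as plain arithmetic:

principal, rate, years = 10_000.0, 0.05, 10
balance = principal * (1 + rate) ** years
print(f"${balance:,.2f}")                                   # $16,288.95
print(f"interest ${balance - principal:,.2f}")              # $6,288.95
print(f"growth {((balance / principal - 1) * 100):.2f}%")   # 62.89%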
331
+ # CAGR calculation with improved pattern matching
332
+ cagr_match = CAGR_PATTERN.search(lower_message)
333
+ if cagr_match:
334
+ tickers_str, start_year, end_year = cagr_match.groups()
335
+ tickers = [t.strip() for t in re.split(r',|\band\b', tickers_str) if t.strip()]
336
+
337
  # Apply company-to-ticker mapping
338
+ mapped_tickers = []
339
+ for ticker in tickers:
340
+ lower_ticker = ticker.lower()
341
  if lower_ticker in COMPANY_TO_TICKER:
342
+ mapped_tickers.append(COMPANY_TO_TICKER[lower_ticker])
343
+ else:
344
+ mapped_tickers.append(ticker.upper())
345
+
346
+ # Validate year range first
347
+ valid, msg = validate_year_range(start_year, end_year)
348
+ if not valid:
349
+ yield f"Error: {msg}"
350
  return
351
+
352
+ if len(mapped_tickers) > 10:
353
+ yield "Error: Too many tickers requested. Please limit to 10 tickers per query."
354
+ return
355
+
356
+ responses = []
357
+ for ticker in mapped_tickers:
358
+ result = calculate_cagr(ticker, start_year, end_year)
359
+ responses.append(result)
360
+
361
+ full_response = (
362
+ f"**CAGR Analysis ({start_year} - {end_year})**\n\n"
363
+ + "\n".join(responses) +
364
+ "\n\n*Notes:*\n"
365
+ "- CAGR represents average annual returns with compounding\n"
366
+ "- Based on adjusted closing prices\n"
367
+ "- Past performance is not indicative of future results\n"
368
+ "- Please consult a financial advisor for investment decisions"
369
+ )
370
+
371
+ logger.info(f"CAGR response generated in {time.time() - start_time:.2f}s")
372
  yield full_response
373
  return
374
+
375
+ # Build conversation for LLM
376
  conversation = [{"role": "system", "content": system_prompt}]
377
+ for msg in chat_history[-3:]:
378
+ if msg["role"] in ["user", "assistant"]:
379
+ conversation.append({"role": msg["role"], "content": msg["content"]})
 
 
380
  conversation.append({"role": "user", "content": message})
381
+
382
+ # Token length check with truncation
383
  prompt_text = "\n".join(d["content"] for d in conversation)
384
  input_tokens = llm.tokenize(prompt_text.encode("utf-8"), add_bos=False)
385
+
386
  while len(input_tokens) > MAX_INPUT_TOKEN_LENGTH:
387
+ logger.warning(f"Input tokens ({len(input_tokens)}) exceed limit. Truncating history.")
388
+ if len(conversation) > 2:
389
+ conversation.pop(1)
390
  prompt_text = "\n".join(d["content"] for d in conversation)
391
  input_tokens = llm.tokenize(prompt_text.encode("utf-8"), add_bos=False)
392
  else:
393
+ yield "Error: Input is too long. Please shorten your query or start a new conversation."
394
  return
395
+
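The truncation loop above drops the oldest history message (index 1, right after the system prompt) one at a time until the prompt fits the budget. A self-contained sketch of the same idea, with a toy word-count "tokenizer" standing in for llm.tokenize:

MAX_INPUT_TOKEN_LENGTH = 8          # tiny budget for the demo

def count_tokens(text: str) -> int:
    return len(text.split())        # toy stand-in for llm.tokenize

conversation = [
    {"role": "system", "content": "You are FinChat."},
    {"role": "user", "content": "first question about stocks"},
    {"role": "assistant", "content": "first answer about stocks"},
    {"role": "user", "content": "what is cagr"},
]

while count_tokens("\n".join(m["content"] for m in conversation)) > MAX_INPUT_TOKEN_LENGTH:
    if len(conversation) > 2:       # keep the system prompt and the latest user message
        conversation.pop(1)         # drop the oldest history entry
    else:
        raise ValueError("query too long even after truncation")

print([m["role"] for m in conversation])   # ['system', 'user'] once it fits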
396
+ # Generate response
397
  try:
398
  response = ""
399
  sentence_buffer = ""
 
400
  stream = llm.create_chat_completion(
401
  messages=conversation,
402
  max_tokens=max_new_tokens,
 
406
  repeat_penalty=repetition_penalty,
407
  stream=True
408
  )
409
+
410
  sentence_endings = ['.', '!', '?']
411
+
412
  for chunk in stream:
413
  delta = chunk["choices"][0]["delta"]
414
  if "content" in delta and delta["content"] is not None:
415
+ cleaned_chunk = re.sub(
416
+ r'<\|(?:im_start|im_end|system|user|assistant)\|>|</s>|\[END\]',
417
+ '',
418
+ delta["content"]
419
+ )
420
  if not cleaned_chunk:
421
  continue
422
+
423
  sentence_buffer += cleaned_chunk
424
  response += cleaned_chunk
425
+
 
  if any(sentence_buffer.strip().endswith(ending) for ending in sentence_endings):
427
  yield response
428
+ sentence_buffer = ""
429
+
430
  if chunk["choices"][0]["finish_reason"] is not None:
 
431
  if sentence_buffer.strip():
432
  yield response
433
  break
434
+
435
+ duration = time.time() - start_time
436
+ logger.info(f"LLM response generated in {duration:.2f}s")
437
+
438
  except ValueError as ve:
439
  if "exceed context window" in str(ve):
440
+ yield "Error: Input exceeds context window. Please try a shorter query."
441
  else:
442
+ logger.error(f"ValueError during generation: {str(ve)}")
443
+ yield f"Error: {str(ve)}"
444
  except Exception as e:
445
  logger.error(f"Error during response generation: {str(e)}")
446
+ yield f"Error generating response. Please try again or rephrase your question."
447
+
448
  def process_portfolio(df, growth_rate):
449
+ """Process portfolio with enhanced error handling and validation"""
450
  if df is None or len(df) == 0:
451
  return "", None
452
+
453
+ try:
454
+ if not isinstance(df, pd.DataFrame):
455
+ df = pd.DataFrame(df, columns=["Ticker", "Shares", "Avg Cost", "Current Price"])
456
+
457
+ # Validate and convert numeric columns
458
+ for col in ["Shares", "Avg Cost", "Current Price"]:
459
+ df[col] = pd.to_numeric(df[col], errors='coerce')
460
+
461
+ df = df.dropna(subset=["Ticker"])
462
+
463
+ portfolio = {}
464
+ errors = []
465
+
466
+ for idx, row in df.iterrows():
467
+ ticker = sanitize_ticker(row["Ticker"]) if pd.notna(row["Ticker"]) else None
468
+ if not ticker:
469
+ continue
470
+
471
+ shares = float(row["Shares"]) if pd.notna(row["Shares"]) else 0
472
+ cost = float(row["Avg Cost"]) if pd.notna(row["Avg Cost"]) else 0
473
+ price = float(row["Current Price"]) if pd.notna(row["Current Price"]) else 0
474
+
475
+ if shares <= 0:
476
+ errors.append(f"{ticker}: Invalid shares count")
477
+ continue
478
+ if price < 0 or cost < 0:
479
+ errors.append(f"{ticker}: Negative prices not allowed")
480
+ continue
481
+
482
+ value = shares * price
483
+ cost_basis = shares * cost
484
+ gain_loss = value - cost_basis
485
+ gain_loss_pct = (gain_loss / cost_basis * 100) if cost_basis > 0 else 0
486
+
487
+ portfolio[ticker] = {
488
+ 'shares': shares,
489
+ 'cost': cost,
490
+ 'price': price,
491
+ 'value': value,
492
+ 'cost_basis': cost_basis,
493
+ 'gain_loss': gain_loss,
494
+ 'gain_loss_pct': gain_loss_pct
495
+ }
496
+
497
+ if not portfolio:
498
+ return "No valid portfolio entries found. Please check your data.", None
499
+
500
+ total_value_now = sum(v['value'] for v in portfolio.values())
501
+ total_cost_basis = sum(v['cost_basis'] for v in portfolio.values())
502
+ total_gain_loss = total_value_now - total_cost_basis
503
+ total_gain_loss_pct = (total_gain_loss / total_cost_basis * 100) if total_cost_basis > 0 else 0
504
+
505
+ allocations = {k: v['value'] / total_value_now for k, v in portfolio.items()} if total_value_now > 0 else {}
506
+
507
+ # Create allocation pie chart
508
+ fig_alloc, ax_alloc = plt.subplots(figsize=(8, 6))
509
+ colors = plt.cm.Set3(range(len(allocations)))
510
+ ax_alloc.pie(
511
+ allocations.values(),
512
+ labels=allocations.keys(),
513
+ autopct='%1.1f%%',
514
+ colors=colors,
515
+ startangle=90
516
+ )
517
+ ax_alloc.set_title('Portfolio Allocation by Value', fontsize=14, fontweight='bold')
518
+
519
+ buf_alloc = io.BytesIO()
520
+ fig_alloc.savefig(buf_alloc, format='png', bbox_inches='tight', dpi=100)
521
+ buf_alloc.seek(0)
522
+ chart_alloc = Image.open(buf_alloc)
523
+ plt.close(fig_alloc)
524
+
525
+ # Project future values
526
+ def project_value(value, years, rate):
527
+ return value * (1 + rate / 100) ** years
528
+
529
+ projections = {
530
+ '1 year': sum(project_value(v['value'], 1, growth_rate) for v in portfolio.values()),
531
+ '2 years': sum(project_value(v['value'], 2, growth_rate) for v in portfolio.values()),
532
+ '5 years': sum(project_value(v['value'], 5, growth_rate) for v in portfolio.values()),
533
+ '10 years': sum(project_value(v['value'], 10, growth_rate) for v in portfolio.values())
534
+ }
535
+
536
+ # Build detailed report
537
+ data_str = "**๐Ÿ“Š Portfolio Analysis**\n\n"
538
+ data_str += "**Current Holdings:**\n"
539
+ for ticker, data in portfolio.items():
540
+ data_str += (
541
+ f"- {ticker}: {data['shares']:.2f} shares @ ${data['price']:.2f} "
542
+ f"(Cost: ${data['cost']:.2f}) = ${data['value']:,.2f} "
543
+ f"[{data['gain_loss_pct']:+.2f}%]\n"
544
+ )
545
+
546
+ data_str += f"\n**Portfolio Summary:**\n"
547
+ data_str += f"- Total Value: ${total_value_now:,.2f}\n"
548
+ data_str += f"- Total Cost Basis: ${total_cost_basis:,.2f}\n"
549
+ data_str += f"- Total Gain/Loss: ${total_gain_loss:+,.2f} ({total_gain_loss_pct:+.2f}%)\n"
550
+
551
+ data_str += f"\n**Projected Values (at {growth_rate}% annual growth):**\n"
552
+ for period, value in projections.items():
553
+ gain = value - total_value_now
554
+ data_str += f"- {period}: ${value:,.2f} (+${gain:,.2f})\n"
555
+
556
+ if errors:
557
+ data_str += f"\n**โš ๏ธ Warnings:**\n"
558
+ for error in errors:
559
+ data_str += f"- {error}\n"
560
+
561
+ data_str += "\n*Note: Projections assume constant growth rate and no additional contributions. Actual results will vary. Consult a financial advisor.*"
562
+
563
+ return data_str, chart_alloc
564
+
565
+ except Exception as e:
566
+ logger.error(f"Error processing portfolio: {str(e)}")
567
+ return f"Error processing portfolio: {str(e)}", None
568
+
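The projections above are simple constant-rate compounding of today's value; for a hypothetical $25,000 portfolio at the default 10% slider setting they come out as:

def project_value(value: float, years: int, rate: float) -> float:
    # Same formula as in process_portfolio: compound today's value at a flat annual rate.
    return value * (1 + rate / 100) ** years

total_value_now = 25_000.0            # made-up portfolio value
for years in (1, 2, 5, 10):
    print(years, f"${project_value(total_value_now, years, 10):,.2f}")
# 1 -> $27,500.00, 2 -> $30,250.00, 5 -> $40,262.75, 10 -> $64,843.56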
569
  def fetch_current_prices(df):
570
+ """Fetch current prices with timeout and error handling"""
571
  if df is None or len(df) == 0:
572
  return df
573
+
574
+ try:
575
+ if not isinstance(df, pd.DataFrame):
576
+ df = pd.DataFrame(df, columns=["Ticker", "Shares", "Avg Cost", "Current Price"])
577
+
578
+ updated_count = 0
579
+ failed_tickers = []
580
+
581
+ for i in df.index:
582
+ ticker = df.at[i, "Ticker"]
583
+ if pd.notna(ticker) and ticker.strip():
584
+ clean_ticker = sanitize_ticker(ticker)
585
+ if not clean_ticker:
586
+ failed_tickers.append(f"{ticker} (invalid format)")
587
+ continue
588
+
589
+ try:
590
+ ticker_obj = yf.Ticker(clean_ticker)
591
+ info = ticker_obj.info
592
+ price = info.get('currentPrice') or info.get('regularMarketPrice')
593
+
594
+ if price is not None and price > 0:
595
+ df.at[i, "Current Price"] = price
596
+ updated_count += 1
597
+ else:
598
+ failed_tickers.append(f"{clean_ticker} (no price data)")
599
+ except Exception as e:
600
+ logger.warning(f"Failed to fetch price for {clean_ticker}: {str(e)}")
601
+ failed_tickers.append(f"{clean_ticker} ({str(e)[:30]})")
602
+
603
+ if updated_count > 0:
604
+ logger.info(f"Successfully updated {updated_count} prices")
605
+ if failed_tickers:
606
+ logger.warning(f"Failed to fetch: {', '.join(failed_tickers)}")
607
+
608
+ return df
609
+
610
+ except Exception as e:
611
+ logger.error(f"Error in fetch_current_prices: {str(e)}")
612
+ return df
613
+
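The per-ticker lookup inside fetch_current_prices boils down to the snippet below; it needs network access, and whether 'currentPrice' or only 'regularMarketPrice' is present depends on the ticker and on yfinance's data source:

import yfinance as yf

info = yf.Ticker("AAPL").info          # network call; can be slow or rate-limited
price = info.get("currentPrice") or info.get("regularMarketPrice")
print(price)                            # a float when quote data is available, else None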
614
+ # Gradio interface
615
+ with gr.Blocks(theme=themes.Soft(), css="""
616
+ #chatbot {height: 800px; overflow: auto;}
617
+ .performance-note {color: #666; font-size: 0.9em; font-style: italic;}
618
+ """) as demo:
619
  gr.Markdown(DESCRIPTION)
620
+
621
  chatbot = gr.Chatbot(label="FinChat", type="messages")
622
+ msg = gr.Textbox(
623
+ label="Ask a finance question",
624
+ placeholder="e.g., 'What is CAGR?' or 'Average return for AAPL between 2010 and 2020'",
625
+ info="Enter your query. Responses are cached for better performance."
626
+ )
627
+
628
  with gr.Row():
629
  submit = gr.Button("Submit", variant="primary")
630
  clear = gr.Button("Clear")
631
+
632
  gr.Examples(
633
+ examples=[
634
+ "What is CAGR?",
635
+ "Average return for AAPL between 2015 and 2023",
636
+ "Average return for TSLA and NVDA between 2018 and 2023",
637
+ "If I save $10000 at 5% interest over 10 years",
638
+ "Explain compound interest"
639
+ ],
640
  inputs=msg,
641
  label="Example Queries"
642
  )
643
+
644
+ with gr.Accordion("๐Ÿ“ˆ Enter Portfolio for Projections", open=False):
645
  portfolio_df = gr.Dataframe(
646
  headers=["Ticker", "Shares", "Avg Cost", "Current Price"],
647
  datatype=["str", "number", "number", "number"],
648
+ row_count=5,
649
  col_count=(4, "fixed"),
650
  label="Portfolio Data",
651
  interactive=True
652
  )
653
+ gr.Markdown("""
654
+ **Instructions:**
655
+ - Enter stock tickers (e.g., AAPL, TSLA)
656
+ - Fill in number of shares and your average cost per share
657
+ - Click 'Fetch Current Prices' to auto-populate current prices
658
+ - Adjust growth rate for future projections
659
+ """)
660
+
661
+ fetch_button = gr.Button("๐Ÿ”„ Fetch Current Prices", variant="secondary")
662
  fetch_button.click(fetch_current_prices, inputs=portfolio_df, outputs=portfolio_df)
663
+
664
+ growth_rate = gr.Slider(
665
+ minimum=0,
666
+ maximum=50,
667
+ step=1,
668
+ value=10,
669
+ label="Annual Growth Rate (%)",
670
+ interactive=True,
671
+ info="Expected annual return for projections (0-50%)"
672
+ )
673
  growth_rate_label = gr.Markdown("**Selected Growth Rate: 10%**")
674
+
675
+ with gr.Accordion("โš™๏ธ Advanced Settings", open=False):
676
+ system_prompt = gr.Textbox(
677
+ label="System Prompt",
678
+ value=DEFAULT_SYSTEM_PROMPT,
679
+ lines=6,
680
+ info="Customize the AI's behavior"
681
+ )
682
+ temperature = gr.Slider(
683
+ label="Temperature",
684
+ value=0.6,
685
+ minimum=0.0,
686
+ maximum=1.0,
687
+ step=0.05,
688
+ info="Lower = more focused, Higher = more creative"
689
+ )
690
+ top_p = gr.Slider(
691
+ label="Top P",
692
+ value=0.9,
693
+ minimum=0.0,
694
+ maximum=1.0,
695
+ step=0.05,
696
+ info="Nucleus sampling threshold"
697
+ )
698
+ top_k = gr.Slider(
699
+ label="Top K",
700
+ value=50,
701
+ minimum=1,
702
+ maximum=100,
703
+ step=1,
704
+ info="Limit to top K tokens"
705
+ )
706
+ repetition_penalty = gr.Slider(
707
+ label="Repetition Penalty",
708
+ value=1.2,
709
+ minimum=1.0,
710
+ maximum=