AnilNiraula committed (verified)
Commit: f7cc8c3 · 1 Parent(s): 948f3a8

Update app.py

Files changed (1):
  1. app.py (+312 -1)
app.py CHANGED
@@ -192,4 +192,315 @@ try:
     logger.info(f"Successfully loaded model: {model_name}")
 except Exception as e:
     logger.error(f"Error loading model/tokenizer: {e}")
-    raise RuntimeError(f"Failed
+    raise RuntimeError(f"Failed to load model: {str(e)}")
+
+# Prompt prefix
+prompt_prefix = (
+    "You are FinChat, a financial advisor with expertise in stock market performance. Provide detailed, numbered list advice with clear reasoning for investing prompts, "
+    "including precise historical data when relevant (e.g., TSLA or S&P 500 returns for specific years or periods). For investment return queries, use compound interest calculations "
+    "based on historical averages. Avoid repetition and incomplete answers. Explain why each step or choice is beneficial.\n\n"
+    "Example 1:\n"
+    "Q: What is the S&P 500’s average annual return?\n"
+    "A: The S&P 500’s average annual return is ~10–12% over the long term (1927–2025), including dividends.\n"
+    "1. This reflects historical data adjusted for inflation and dividends.\n"
+    "2. Returns vary yearly (e.g., 16.3% in 2020) due to market conditions.\n"
+    "3. ETFs like SPY track this index for broad market exposure.\n\n"
+    "Example 2:\n"
+    "Q: What will $5,000 be worth in 10 years if invested in TSLA?\n"
+    "A: Assuming a 10% average annual return, a $5,000 investment in TSLA would grow to approximately $12,969 in 10 years with annual compounding.\n"
+    "1. This uses the historical average return of 10–12% for stocks.\n"
+    "2. Future returns vary and are not guaranteed.\n\n"
+    "Example 3:\n"
+    "Q: What was the average annual return of MSFT between 2010 and 2020?\n"
+    "A: The MSFT average annual return from 2010 to 2020 was approximately 16.8%, including dividends.\n"
+    "1. This period includes strong growth in tech stocks.\n"
+    "2. Dividends contribute significantly to total returns.\n\n"
+    "Q: "
+)
+prefix_tokens = tokenizer(prompt_prefix, return_tensors="pt", truncation=True, max_length=512).to(device)
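+# NOTE: prefix_tokens is precomputed here but does not appear to be reused in this hunk;
+# chat_with_model re-tokenizes the full prompt on each call.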
+
+# Substring matching for cache
+def get_closest_cache_key(message, cache_keys):
+    message = message.lower().strip()
+    matches = difflib.get_close_matches(message, cache_keys, n=1, cutoff=0.8)
+    return matches[0] if matches else None
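+# difflib.get_close_matches ranks candidates with SequenceMatcher ratios; cutoff=0.8 means
+# only near-identical phrasings reuse a cached answer.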
+
+# Parse period from user input
+def parse_period(query):
+    query = query.lower()
+    # Match symbol (TSLA, MSFT, NVDA, GOOG, AMZN, S&P 500)
+    symbol_match = re.search(r'(tsla|msft|nvda|goog|amzn|s&p\s*500)', query)
+    symbol = symbol_match.group(1).upper() if symbol_match else "SPY"
+    if symbol.startswith("S&P"):
+        symbol = "SPY"
+    # Match specific year ranges
+    match = re.search(r'(?:between|from)\s*(\d{4})\s*(?:and|to|-|–)\s*(\d{4})', query)
+    if match:
+        start_year, end_year = map(int, match.groups())
+        if start_year <= end_year:
+            return start_year, end_year, None, symbol
+    # Match duration-based queries
+    match = re.search(r'(\d+)-year.*from\s*(\d{4})', query)
+    if match:
+        duration, start_year = map(int, match.groups())
+        end_year = start_year + duration - 1
+        return start_year, end_year, duration, symbol
+    # Match past X years
+    match = re.search(r'(?:past\s*(\d+)-year|\b(\d+)-year.*(?:return|growth\s*rate))', query)
+    if match:
+        duration = int(match.group(1) or match.group(2))
+        max_year = df_yearly['Year'].max() if df_yearly is not None else 2025
+        start_year = max_year - duration + 1
+        end_year = max_year
+        return start_year, end_year, duration, symbol
+    # Match single year
+    match = re.search(r'return\s*(?:in|for)\s*(\d{4})', query)
+    if match:
+        year = int(match.group(1))
+        return year, year, 1, symbol
+    return None, None, None, symbol
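+# Example (illustrative): parse_period("average annual return of msft between 2010 and 2020")
+#   -> (2010, 2020, None, "MSFT")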
+
+# Calculate average growth rate
+def calculate_growth_rate(start_year, end_year, duration=None, symbol="SPY"):
+    if df_yearly is None or start_year is None or end_year is None:
+        return None, "Data not available or invalid period."
+    df_period = df_yearly[(df_yearly['Year'] >= start_year) & (df_yearly['Year'] <= end_year)]
+    if df_period.empty:
+        return None, f"No data available for {symbol} from {start_year} to {end_year}."
+    avg_return = df_period[f"Return_{symbol}"].mean()
+    if np.isnan(avg_return):
+        return None, f"Insufficient data for {symbol} from {start_year} to {end_year}."
+    symbol_name = "S&P 500" if symbol == "SPY" else symbol
+    if duration == 1 and start_year == end_year:
+        response = f"The {symbol_name} returned approximately {avg_return:.1f}% in {start_year}, including dividends."
+    elif duration:
+        response = f"The {symbol_name} {duration}-year average annual return from {start_year} to {end_year} was approximately {avg_return:.1f}%, including dividends."
+    else:
+        response = f"The {symbol_name} average annual return from {start_year} to {end_year} was approximately {avg_return:.1f}%, including dividends."
+    return avg_return, response
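+# Assumes df_yearly (loaded earlier in app.py) has a 'Year' column plus per-symbol
+# 'Return_<SYMBOL>' columns expressed in percent.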
+
+# Parse investment return query
+def parse_investment_query(query):
+    # Allow thousands separators such as "$5,000"
+    match = re.search(r'\$([\d,]+).*\s(\d+)\s*years?.*\b(tsla|msft|nvda|goog|amzn|s&p\s*500)\b', query, re.IGNORECASE)
+    if match:
+        amount = float(match.group(1).replace(",", ""))
+        years = int(match.group(2))
+        symbol = match.group(3).upper()
+        if symbol.startswith("S&P"):
+            symbol = "SPY"
+        return amount, years, symbol
+    return None, None, None
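+# Example (illustrative): parse_investment_query("What will $5,000 be worth in 10 years if invested in TSLA?")
+#   -> (5000.0, 10, "TSLA")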
+
+# Calculate future value
+def calculate_future_value(amount, years, symbol):
+    if df_yearly is None or amount is None or years is None:
+        return None, "Data not available or invalid input."
+    avg_annual_return = 10.0
+    future_value = amount * (1 + avg_annual_return / 100) ** years
+    symbol_name = "S&P 500" if symbol == "SPY" else symbol
+    return future_value, (
+        f"Assuming a 10% average annual return, a ${amount:,.0f} investment in {symbol_name} would grow to approximately ${future_value:,.0f} "
+        f"in {years} years with annual compounding. This is based on the historical average return of 10–12% for stocks. "
+        "Future returns vary and are not guaranteed. Consult a financial planner."
+    )
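+# Compound interest: FV = P * (1 + r) ** n with r fixed at 10%,
+# e.g., $5,000 over 10 years -> 5000 * 1.1 ** 10 ≈ $12,969.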
+
+# Chat function
+def chat_with_model(user_input, history=None, is_processing=False):
+    try:
+        start_time = time.time()
+        logger.info(f"Processing user input: {user_input}")
+        is_processing = True
+        logger.info("Showing loading animation")
+
+        # Normalize and check cache
+        cache_key = user_input.lower().strip()
+        cache_keys = list(response_cache.keys())
+        closest_key = cache_key if cache_key in response_cache else get_closest_cache_key(cache_key, cache_keys)
+        if closest_key:
+            logger.info(f"Cache hit for: {closest_key}")
+            response = response_cache[closest_key]
+            logger.info(f"Chatbot response: {response}")
+            history = history or []
+            history.append({"role": "user", "content": user_input})
+            history.append({"role": "assistant", "content": response})
+            end_time = time.time()
+            logger.info(f"Response time: {end_time - start_time:.2f} seconds")
+            return response, history, False, ""
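+        # The four return values map to (assistant response, chat history, is_processing flag, cleared textbox value).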
+
+        # Check for investment return query
+        amount, years, symbol = parse_investment_query(user_input)
+        if amount and years:
+            future_value, response = calculate_future_value(amount, years, symbol)
+            if future_value is not None:
+                response_cache[cache_key] = response
+                logger.info(f"Investment query: ${amount} for {years} years in {symbol}, added to cache")
+                logger.info(f"Chatbot response: {response}")
+                history = history or []
+                history.append({"role": "user", "content": user_input})
+                history.append({"role": "assistant", "content": response})
+                end_time = time.time()
+                logger.info(f"Response time: {end_time - start_time:.2f} seconds")
+                return response, history, False, ""
+
+        # Check for period-specific query
+        start_year, end_year, duration, symbol = parse_period(user_input)
+        if start_year and end_year:
+            avg_return, response = calculate_growth_rate(start_year, end_year, duration, symbol)
+            if avg_return is not None:
+                response_cache[cache_key] = response
+                logger.info(f"Dynamic period query for {symbol}: {start_year}–{end_year}, added to cache")
+                logger.info(f"Chatbot response: {response}")
+                history = history or []
+                history.append({"role": "user", "content": user_input})
+                history.append({"role": "assistant", "content": response})
+                end_time = time.time()
+                logger.info(f"Response time: {end_time - start_time:.2f} seconds")
+                return response, history, False, ""
+
+        # Handle short prompts
+        if len(user_input.strip()) <= 5:
+            logger.info("Short prompt, returning default response")
+            response = "Hello! I'm FinChat, your financial advisor. Ask about investing in TSLA, MSFT, NVDA, GOOG, AMZN, or S&P 500!"
+            logger.info(f"Chatbot response: {response}")
+            history = history or []
+            history.append({"role": "user", "content": user_input})
+            history.append({"role": "assistant", "content": response})
+            end_time = time.time()
+            logger.info(f"Response time: {end_time - start_time:.2f} seconds")
+            return response, history, False, ""
+
+        # Construct and generate response
+        full_prompt = prompt_prefix + user_input + "\nA:"
+        try:
+            inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
+        except Exception as e:
+            logger.error(f"Error tokenizing input: {e}")
+            response = f"Error: Failed to process input: {str(e)}"
+            logger.info(f"Chatbot response: {response}")
+            history = history or []
+            history.append({"role": "user", "content": user_input})
+            history.append({"role": "assistant", "content": response})
+            end_time = time.time()
+            logger.info(f"Response time: {end_time - start_time:.2f} seconds")
+            return response, history, False, ""
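+        # NOTE: truncation to 512 tokens can clip long user questions, since the few-shot prefix is part of full_prompt.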
+
+        with torch.inference_mode():
+            logger.info("Generating response with model")
+            gen_start_time = time.time()
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=40,
+                min_length=20,
+                do_sample=False,
+                repetition_penalty=2.0,
+                pad_token_id=tokenizer.eos_token_id
+            )
+            gen_end_time = time.time()
+            logger.info(f"Generation time: {gen_end_time - gen_start_time:.2f} seconds")
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        response = response[len(full_prompt):].strip() if response.startswith(full_prompt) else response
+        logger.info(f"Chatbot response: {response}")
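+        # Greedy decoding (do_sample=False) with repetition_penalty=2.0 keeps answers deterministic;
+        # decode() returns prompt + continuation, so the prefix is stripped when present.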
+
+        # Update cache
+        response_cache[cache_key] = response
+        logger.info("Cache miss, added to in-memory cache")
+
+        # Update history
+        history = history or []
+        history.append({"role": "user", "content": user_input})
+        history.append({"role": "assistant", "content": response})
+        torch.cuda.empty_cache()
+        end_time = time.time()
+        logger.info(f"Response time: {end_time - start_time:.2f} seconds")
+        return response, history, False, ""
+
+    except Exception as e:
+        logger.error(f"Error generating response: {e}")
+        response = f"Error: {str(e)}"
+        logger.info(f"Chatbot response: {response}")
+        history = history or []
+        history.append({"role": "user", "content": user_input})
+        history.append({"role": "assistant", "content": response})
+        end_time = time.time()
+        logger.info(f"Response time: {end_time - start_time:.2f} seconds")
+        return response, history, False, ""
+
+# Save cache
+def save_cache():
+    try:
+        with open(cache_file, 'w') as f:
+            json.dump(response_cache, f, indent=2)
+        logger.info("Saved cache to cache.json")
+    except Exception as e:
+        logger.warning(f"Failed to save cache.json: {e}")
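+# cache_file and response_cache are assumed to be defined earlier in app.py, outside this hunk.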
+
+# Gradio interface
+logger.info("Initializing Gradio interface")
+try:
+    with gr.Blocks(
+        title="FinChat: An LLM based on the distilgpt2 model",
+        css="""
+        .loader {
+            border: 5px solid #f3f3f3;
+            border-top: 5px solid #3498db;
+            border-radius: 50%;
+            width: 30px;
+            height: 30px;
+            animation: spin 1s linear infinite;
+            margin: 10px auto;
+            display: block;
+        }
+        @keyframes spin {
+            0% { transform: rotate(0deg); }
+            100% { transform: rotate(360deg); }
+        }
+        .hidden { display: none; }
+        """
+    ) as interface:
+        gr.Markdown(
+            """
+            # FinChat: An LLM based on the distilgpt2 model
+            FinChat provides financial advice using the lightweight distilgpt2 model, optimized for fast, detailed responses.
+            Ask about investing strategies, ETFs, or stocks like TSLA, MSFT, NVDA, GOOG, AMZN, or S&P 500 to get started!
+            """
+        )
+        chatbot = gr.Chatbot(type="messages")
+        msg = gr.Textbox(label="Your message")
+        submit = gr.Button("Send")
+        clear = gr.Button("Clear")
+        loading = gr.HTML('<div class="loader hidden"></div>', label="Loading")
+        is_processing = gr.State(value=False)
+
+        def submit_message(user_input, history, is_processing):
+            response, updated_history, new_processing, clear_input = chat_with_model(user_input, history, is_processing)
+            loader_html = '<div class="loader"></div>' if new_processing else '<div class="loader hidden"></div>'
+            return clear_input, updated_history, loader_html, new_processing
+
+        submit.click(
+            fn=submit_message,
+            inputs=[msg, chatbot, is_processing],
+            outputs=[msg, chatbot, loading, is_processing]
+        )
+        clear.click(
+            fn=lambda: ("", [], '<div class="loader hidden"></div>', False),
+            outputs=[msg, chatbot, loading, is_processing]
+        )
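+        # Output order matters: submit_message returns (textbox value, chat history, loader HTML, processing flag)
+        # to match outputs=[msg, chatbot, loading, is_processing].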
+    logger.info("Gradio interface initialized successfully")
+except Exception as e:
+    logger.error(f"Error initializing Gradio interface: {e}")
+    raise
+
+# Launch interface
+if __name__ == "__main__" and not os.getenv("HF_SPACE"):
+    logger.info("Launching Gradio interface locally")
+    try:
+        interface.launch(share=False, debug=True)
+    except Exception as e:
+        logger.error(f"Error launching interface: {e}")
+        raise
+    finally:
+        save_cache()
+else:
+    logger.info("Running in Hugging Face Spaces, interface defined but not launched")
+    import atexit
+    atexit.register(save_cache)
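+    # atexit.register ensures save_cache runs when the interpreter exits; the HF_SPACE environment
+    # variable is presumably set in the Spaces runtime so this branch is taken there.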