# FinChat / app.py
# AnilNiraula's picture
# Update app.py
# ff3b207 verified
# raw
# history blame
# 4.75 kB
import gradio as gr
import numpy as np
import re
from datetime import datetime, timedelta
import difflib
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import yfinance as yf
from functools import lru_cache
import pandas as pd
# Tickers whose price history is downloaded once at startup by the prefetch
# below; any other symbol has to be fetched on demand.
tickers = [
    'TSLA',
    'PLTR',
    'SOUN',
    'MSFT',
]
# Lower-case company names mapped to their ticker symbols, used to resolve
# company names that appear in user queries. Lookups walk this mapping in
# insertion order and stop at the first hit, so entry order matters.
company_to_ticker = dict([
    ('tesla', 'TSLA'),
    ('palantir', 'PLTR'),
    ('soundhound', 'SOUN'),
    ('microsoft', 'MSFT'),
    ('nvidia', 'NVDA'),
    ('google', 'GOOG'),
    ('amazon', 'AMZN'),
    ('apple', 'AAPL'),
    ('meta', 'META'),
    ('netflix', 'NFLX'),
    ('intel', 'INTC'),
    ('amd', 'AMD'),
    ('ibm', 'IBM'),
    ('oracle', 'ORCL'),
    ('cisco', 'CSCO'),
    ('jpmorgan', 'JPM'),
    ('bank of america', 'BAC'),
    ('wells fargo', 'WFC'),
    ('visa', 'V'),
    ('mastercard', 'MA'),
    ('exxon', 'XOM'),
    ('chevron', 'CVX'),
    ('pfizer', 'PFE'),
    ('johnson & johnson', 'JNJ'),
    ('merck', 'MRK'),
    ('spy', 'SPY'),
])
# Prefetch stock data for all tickers at startup using yfinance.
# Each ticker is downloaded independently: a failure leaves only that ticker's
# entry empty instead of discarding every download that already succeeded.
# NOTE(review): yfinance treats `end` as exclusive, so the current day's
# session is not part of the prefetch.
all_data = {}
now = datetime.now().strftime('%Y-%m-%d')
for ticker in tickers:
    try:
        all_data[ticker] = yf.download(ticker, start='2020-01-01', end=now, auto_adjust=True)
    except Exception as e:
        print(f"Error prefetching data for {ticker}: {e}")
        # An empty frame marks "not preloaded"; consumers check `.empty`.
        all_data[ticker] = pd.DataFrame()
# Combine the preloaded closing prices into one DataFrame with one column per
# ticker. The prefetch uses auto_adjust=True, so 'Close' already holds the
# adjusted close.
series_list = []
for ticker, data in all_data.items():
    if data.empty:
        continue
    close = data['Close']
    # Newer yfinance versions return (field, ticker) MultiIndex columns even for a
    # single ticker, which makes data['Close'] a one-column DataFrame; flatten it.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    # rename() returns a new Series, so the frame stored in all_data is not mutated.
    series_list.append(close.rename(ticker))
adj_close_data = pd.concat(series_list, axis=1) if series_list else pd.DataFrame()
# Display the first few rows to verify (for debugging; remove in production)
print(adj_close_data.head())
# Ticker symbols that user input may be fuzzy-matched (difflib) against.
# At present this is exactly the set of values in company_to_ticker.
available_symbols = (
    'TSLA MSFT NVDA GOOG AMZN SPY AAPL META NFLX INTC AMD IBM ORCL CSCO '
    'JPM BAC WFC V MA XOM CVX PFE JNJ MRK PLTR SOUN'
).split()
@lru_cache(maxsize=100)
def fetch_stock_data(symbol, start_date, end_date):
    """Return daily price history for *symbol* from *start_date* to *end_date*.

    A frame preloaded in ``all_data`` is sliced when it reaches back to the
    requested start date. Otherwise the history is fetched on demand with
    ``yf.Ticker(symbol).history``.

    Args:
        symbol: Ticker symbol, e.g. ``'TSLA'``.
        start_date: First day to include, as a ``'YYYY-MM-DD'`` string.
        end_date: Last day to include, as a ``'YYYY-MM-DD'`` string.

    Returns:
        A DataFrame of price rows (possibly empty), or ``None`` if the
        on-demand fetch raised.

    Note:
        Results are memoised with ``lru_cache``. Every caller passing the same
        arguments receives the same DataFrame object, so callers must not
        mutate it. A ``None`` from a failed fetch also stays cached until it is
        evicted or ``fetch_stock_data.cache_clear()`` is called.
    """
    hist = all_data.get(symbol)
    if hist is not None and not hist.empty:
        # The preload only goes back to its own start date. Slicing it for an
        # earlier window would silently drop the missing years, so use it only
        # when its first row is on or before the requested start.
        if hist.index.min().date() <= pd.Timestamp(start_date).date():
            return hist[(hist.index >= start_date) & (hist.index <= end_date)]
    # Fetch on demand with yfinance.
    try:
        # yfinance treats `end` as exclusive; shift it one day so end_date is
        # included, matching the inclusive slice of the preloaded data above.
        end_exclusive = (pd.Timestamp(end_date) + pd.Timedelta(days=1)).strftime('%Y-%m-%d')
        ticker = yf.Ticker(symbol)
        return ticker.history(start=start_date, end=end_exclusive, auto_adjust=True)
    except Exception as e:
        print(f"Error fetching data for {symbol}: {e}")
        return None
def parse_dates(query):
    """Extract a ``(start_date, end_date)`` window from a natural-language query.

    Forms are checked in this order:

    * An explicit year range, "between 2010 and 2020" or "from 2010 to 2020".
      This yields 1 January of the first year through 31 December of the second.
    * A relative period, "<n> year(s)/month(s)/week(s)/day(s)", ending now.
      A year counts as 365 days and a month as 30 days.
    * Otherwise, the last 365 days.

    Returns:
        A tuple of naive ``datetime`` objects. Returns ``(None, None)`` when the
        year range is reversed or a date cannot be represented, e.g. year 0000
        or a period reaching back before year 1.
    """
    text = query.lower()

    # Explicit year range.
    range_match = re.search(r'(?:between|from)\s+(\d{4})\s+(?:and|to)\s+(\d{4})', text)
    if range_match:
        start_year, end_year = (int(year) for year in range_match.groups())
        try:
            start_date = datetime(start_year, 1, 1)
            end_date = datetime(end_year, 12, 31)
            if start_date >= end_date:
                raise ValueError("Start date must be before end date")
            return start_date, end_date
        except ValueError as e:
            print(f"Date parsing error: {e}")
            return None, None

    # Fall back to a relative period ending now.
    period_match = re.search(r'(\d+)\s*(year|month|week|day)s?', text)
    if period_match:
        num = int(period_match.group(1))
        days_per_unit = {'year': 365, 'month': 30, 'week': 7, 'day': 1}[period_match.group(2)]
        end_date = datetime.now()
        try:
            # Very large counts overflow timedelta or fall before datetime.min.
            start_date = end_date - timedelta(days=days_per_unit * num)
        except OverflowError as e:
            print(f"Date parsing error: {e}")
            return None, None
        return start_date, end_date

    # Default to 1 year.
    end_date = datetime.now()
    return end_date - timedelta(days=365), end_date
def find_closest_symbol(input_symbol):
    """Resolve free text (a company name or an approximate ticker) to a ticker.

    Company names from ``company_to_ticker`` are tried first, in the mapping's
    insertion order. A name must appear as a whole word, so that e.g.
    "intelligence" is not mistaken for Intel. If no name matches, the
    upper-cased input is fuzzy-matched against ``available_symbols`` with
    ``difflib`` (cutoff 0.6).

    Returns:
        The matched ticker string, or ``None`` if nothing matches or the input
        is empty or ``None``.
    """
    if not input_symbol:
        return None
    text = input_symbol.lower()
    # Check if input mentions a known company name.
    for company, ticker in company_to_ticker.items():
        # \b on both sides requires a whole-word match; re.escape keeps names
        # such as "johnson & johnson" literal.
        if re.search(rf'\b{re.escape(company)}\b', text):
            return ticker
    # Fall back to fuzzy ticker matching.
    closest = difflib.get_close_matches(input_symbol.upper(), available_symbols, n=1, cutoff=0.6)
    return closest[0] if closest else None
def calculate_growth_rate(start_date, end_date, symbol):
hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
if hist is