Spaces:
Running
Running
import gradio as gr | |
import numpy as np | |
import re | |
from datetime import datetime, timedelta | |
import difflib | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
import yfinance as yf | |
from functools import lru_cache | |
import pandas as pd | |
# Tickers whose daily price history is downloaded once at startup.
tickers = ['TSLA', 'PLTR', 'SOUN', 'MSFT']

# Company name -> ticker symbol, used to resolve a company mentioned by name
# in a user query. Keys are lower-case because they are compared against
# lower-cased query text.
company_to_ticker = dict([
    # Also prefetched at startup (see `tickers`).
    ('tesla', 'TSLA'),
    ('palantir', 'PLTR'),
    ('soundhound', 'SOUN'),
    ('microsoft', 'MSFT'),
    # Resolved by name; price data is fetched on demand.
    ('nvidia', 'NVDA'),
    ('google', 'GOOG'),
    ('amazon', 'AMZN'),
    ('apple', 'AAPL'),
    ('meta', 'META'),
    ('netflix', 'NFLX'),
    ('intel', 'INTC'),
    ('amd', 'AMD'),
    ('ibm', 'IBM'),
    ('oracle', 'ORCL'),
    ('cisco', 'CSCO'),
    ('jpmorgan', 'JPM'),
    ('bank of america', 'BAC'),
    ('wells fargo', 'WFC'),
    ('visa', 'V'),
    ('mastercard', 'MA'),
    ('exxon', 'XOM'),
    ('chevron', 'CVX'),
    ('pfizer', 'PFE'),
    ('johnson & johnson', 'JNJ'),
    ('merck', 'MRK'),
    ('spy', 'SPY'),
])
# Prefetch daily price history for every ticker in `tickers` once at startup,
# so requests for these symbols can be served without a network round-trip.
# Each ticker is downloaded independently: a failure for one ticker leaves an
# empty DataFrame for that ticker only, instead of discarding the data already
# downloaded for the others.
all_data = {}
now = datetime.now().strftime('%Y-%m-%d')
for ticker in tickers:
    try:
        data = yf.download(ticker, start='2020-01-01', end=now, auto_adjust=True)
    except Exception as e:  # best effort: a failed download must not abort startup
        print(f"Error prefetching data for {ticker}: {e}")
        data = pd.DataFrame()
    # Some yfinance versions return (price field, ticker) MultiIndex columns
    # even for a single ticker. Flatten them so data['Close'] is a Series and
    # the prefetched frames have the same flat columns as the on-demand
    # yf.Ticker(...).history() results.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)
    all_data[ticker] = data

# One column per successfully prefetched ticker, holding its 'Close' price.
# With auto_adjust=True the 'Close' column is already adjusted; yfinance does
# not return a separate 'Adj Close' column in that mode.
series_list = [
    frame['Close'].rename(symbol)
    for symbol, frame in all_data.items()
    if not frame.empty
]
adj_close_data = pd.concat(series_list, axis=1) if series_list else pd.DataFrame()
# Display the first few rows to verify (for debugging; remove in production)
print(adj_close_data.head())
# Candidate tickers for the difflib fallback in find_closest_symbol.
available_symbols = ['TSLA', 'MSFT', 'NVDA', 'GOOG', 'AMZN', 'SPY', 'AAPL', 'META', 'NFLX', 'INTC', 'AMD', 'IBM', 'ORCL', 'CSCO', 'JPM', 'BAC', 'WFC', 'V', 'MA', 'XOM', 'CVX', 'PFE', 'JNJ', 'MRK', 'PLTR', 'SOUN']
def fetch_stock_data(symbol, start_date, end_date):
    """Return daily price history for *symbol* from start_date to end_date, inclusive.

    The startup prefetch in ``all_data`` is used only when it covers the whole
    requested window. Otherwise (symbol not prefetched, prefetch failed, window
    starts before the prefetched range or ends after it) the range is
    downloaded on demand with ``yf.Ticker(symbol).history``.

    Args:
        symbol: Ticker symbol, e.g. ``'TSLA'``.
        start_date: First day of the window, as a ``'YYYY-MM-DD'`` string (or
            anything ``pd.Timestamp`` accepts).
        end_date: Last day of the window (inclusive), same formats.

    Returns:
        A DataFrame indexed by date (possibly empty if no rows fall in the
        window), or ``None`` if the on-demand download raised.
    """
    hist = all_data.get(symbol)
    if hist is not None and not hist.empty:
        # Calendar-day slack for weekends and market holidays: the first and
        # last prefetched rows are trading days, so they can sit a few days
        # inside a window the prefetch nevertheless covers.
        slack = timedelta(days=5)
        wanted_start = pd.Timestamp(start_date).date()
        # No data exists past today, so a window ending in the future only
        # needs coverage up to today.
        wanted_end = min(pd.Timestamp(end_date).date(), datetime.now().date())
        have_start = hist.index.min().date()
        have_end = hist.index.max().date()
        if have_start <= wanted_start + slack and have_end + slack >= wanted_end:
            return hist[(hist.index >= start_date) & (hist.index <= end_date)]

    # Fetch on demand with yfinance.
    try:
        ticker = yf.Ticker(symbol)
        # yfinance treats `end` as exclusive; request one extra day so the
        # window is inclusive, matching the prefetched slice above.
        fetch_end = (pd.Timestamp(end_date) + timedelta(days=1)).strftime('%Y-%m-%d')
        return ticker.history(start=start_date, end=fetch_end, auto_adjust=True)
    except Exception as e:
        print(f"Error fetching data for {symbol}: {e}")
        return None
def parse_dates(query):
    """Extract a (start_date, end_date) window from a natural-language query.

    Recognised forms, checked in this order:

    * An explicit year range, "between 2010 and 2020" or "from 2010 to 2020":
      January 1 of the first year through December 31 of the second.
    * A relative period such as "5 years", "6 months", "2 weeks" or "30 days",
      ending now. Months count as 30 days and years as 365 days.
    * Anything else: the last 365 days, ending now.

    Args:
        query: Free-text user query; ``None`` is treated as an empty query.

    Returns:
        A ``(start_date, end_date)`` tuple of naive ``datetime`` objects, or
        ``(None, None)`` when the query names an invalid window. That covers a
        year range whose start year is after its end year, a year ``datetime``
        cannot represent, or a period so long it overflows.
    """
    text = (query or "").lower()

    # Explicit calendar-year range.
    range_match = re.search(r'(?:between|from)\s+(\d{4})\s+(?:and|to)\s+(\d{4})', text)
    if range_match:
        start_year = int(range_match.group(1))
        end_year = int(range_match.group(2))
        try:
            # datetime() raises ValueError for years it cannot represent (e.g. 0000).
            start_date = datetime(start_year, 1, 1)
            end_date = datetime(end_year, 12, 31)
            if start_date >= end_date:
                raise ValueError("Start date must be before end date")
        except ValueError as e:
            print(f"Date parsing error: {e}")
            return None, None
        return start_date, end_date

    # Relative period ending now (default: one year).
    days_per_unit = {'year': 365, 'month': 30, 'week': 7, 'day': 1}
    end_date = datetime.now()
    period = timedelta(days=365)
    period_match = re.search(r'(\d+)\s*(year|month|week|day)s?', text)
    try:
        if period_match:
            num = int(period_match.group(1))
            period = timedelta(days=days_per_unit[period_match.group(2)] * num)
        start_date = end_date - period
    except OverflowError as e:
        # A huge period (e.g. "5000 years") overflows timedelta or reaches
        # before datetime.min; report it like an invalid range, don't crash.
        print(f"Date parsing error: {e}")
        return None, None
    return start_date, end_date
def find_closest_symbol(input_symbol):
    """Resolve a company name or (possibly misspelled) ticker to a known ticker.

    Names from ``company_to_ticker`` are matched first, as whole words. An
    optional trailing "s" or "'s" is allowed, so "Tesla's" and "Teslas" still
    match, while "intelligence" no longer matches "intel". When several
    companies are mentioned, the one that appears first in the text wins. If
    no company name is found, the upper-cased input is fuzzy-matched against
    ``available_symbols`` with ``difflib``.

    Args:
        input_symbol: A ticker, or free text mentioning a company.

    Returns:
        The ticker string, or ``None`` when nothing matches (including empty
        or ``None`` input).
    """
    if not input_symbol:
        return None
    text = input_symbol.lower()

    best_ticker, best_pos = None, None
    for company, ticker in company_to_ticker.items():
        match = re.search(rf"\b{re.escape(company)}(?:'?s)?\b", text)
        if match and (best_pos is None or match.start() < best_pos):
            best_ticker, best_pos = ticker, match.start()
    if best_ticker is not None:
        return best_ticker

    # Fallback to ticker matching
    closest = difflib.get_close_matches(input_symbol.upper(), available_symbols, n=1, cutoff=0.6)
    return closest[0] if closest else None
def calculate_growth_rate(start_date, end_date, symbol): | |
hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d')) | |
if hist is |