File size: 4,748 Bytes
a0c1901
08d63f2
25a1d3e
 
5d661b6
e2b2a4b
25a1d3e
c3e2fa0
dbf541c
9d7f426
f7cc8c3
9d7f426
 
 
ff3b207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c3e2fa0
9d7f426
 
 
 
c3e2fa0
9d7f426
c3e2fa0
 
9d7f426
ff3b207
9c66ae0
 
 
 
 
 
ff3b207
9d7f426
 
 
 
 
 
f7cc8c3
dbf541c
 
c3e2fa0
9d7f426
 
 
 
c3e2fa0
9d7f426
 
c3e2fa0
9d7f426
 
 
 
dbf541c
d0ea0a8
ff3b207
c3e2fa0
 
 
 
ff3b207
 
 
 
 
 
 
 
 
 
d0ea0a8
 
 
 
25a1d3e
d0ea0a8
25a1d3e
d0ea0a8
25a1d3e
d0ea0a8
25a1d3e
d0ea0a8
 
 
 
 
 
 
 
 
 
25a1d3e
 
 
ff3b207
 
 
 
 
25a1d3e
 
 
c248884
dbf541c
ff3b207
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import gradio as gr
import numpy as np
import re
from datetime import datetime, timedelta
import difflib
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import yfinance as yf
from functools import lru_cache
import pandas as pd

# Tickers whose full daily history is downloaded once at startup.
tickers = 'TSLA PLTR SOUN MSFT'.split()

# Lower-case company name (as it may appear in a user query) -> ticker symbol.
# Entries are kept in a fixed order because name lookups scan this mapping
# front to back and stop at the first hit.
company_to_ticker = dict([
    # Prefetched companies
    ('tesla', 'TSLA'),
    ('palantir', 'PLTR'),
    ('soundhound', 'SOUN'),
    ('microsoft', 'MSFT'),
    # Large-cap technology
    ('nvidia', 'NVDA'),
    ('google', 'GOOG'),
    ('amazon', 'AMZN'),
    ('apple', 'AAPL'),
    ('meta', 'META'),
    ('netflix', 'NFLX'),
    ('intel', 'INTC'),
    ('amd', 'AMD'),
    ('ibm', 'IBM'),
    ('oracle', 'ORCL'),
    ('cisco', 'CSCO'),
    # Financials
    ('jpmorgan', 'JPM'),
    ('bank of america', 'BAC'),
    ('wells fargo', 'WFC'),
    ('visa', 'V'),
    ('mastercard', 'MA'),
    # Energy
    ('exxon', 'XOM'),
    ('chevron', 'CVX'),
    # Healthcare
    ('pfizer', 'PFE'),
    ('johnson & johnson', 'JNJ'),
    ('merck', 'MRK'),
    # Index ETF
    ('spy', 'SPY'),
])

# Prefetch daily price history for every ticker in `tickers` once at startup.
# Each download is isolated: a ticker whose download raises gets an empty
# DataFrame, while data already fetched for the other tickers is kept.
# Empty frames are skipped by the code that consumes `all_data`.
all_data = {}
now = datetime.now().strftime('%Y-%m-%d')
for ticker in tickers:
    try:
        all_data[ticker] = yf.download(ticker, start='2020-01-01', end=now, auto_adjust=True)
    except Exception as e:
        print(f"Error prefetching data for {ticker}: {e}")
        all_data[ticker] = pd.DataFrame()

# Build one DataFrame of closing prices (downloaded with auto_adjust=True),
# with one column per ticker whose prefetch produced data.
series_list = []
for ticker, data in all_data.items():
    if data.empty:
        continue
    close = data['Close']
    # Depending on the yfinance version, downloaded frames may use
    # (Price, Ticker) MultiIndex columns, making data['Close'] a one-column
    # DataFrame rather than a Series. Normalise to a Series, then rename a
    # copy instead of mutating the object held in all_data.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    series_list.append(close.rename(ticker))
adj_close_data = pd.concat(series_list, axis=1) if series_list else pd.DataFrame()

# Display the first few rows to verify (for debugging; remove in production)
print(adj_close_data.head())

# Candidate ticker symbols for fuzzy matching of user-typed symbols.
available_symbols = (
    'TSLA MSFT NVDA GOOG AMZN SPY AAPL META NFLX INTC AMD IBM ORCL CSCO '
    'JPM BAC WFC V MA XOM CVX PFE JNJ MRK PLTR SOUN'
).split()

def _flatten_price_columns(frame):
    """Return *frame* with single-level price column labels.

    Frames downloaded with ``yf.download`` may carry ``(Price, Ticker)``
    MultiIndex columns, while ``Ticker.history`` returns flat labels.
    Keeping only level 0 gives both fetch paths the same layout
    (``'Close'``, ``'High'``, ...). Flat frames are returned unchanged;
    the input frame is never modified.
    """
    if isinstance(frame.columns, pd.MultiIndex):
        return frame.set_axis(frame.columns.get_level_values(0), axis=1)
    return frame


@lru_cache(maxsize=100)
def fetch_stock_data(symbol, start_date, end_date):
    """Return price history for *symbol* between *start_date* and *end_date*.

    Symbols prefetched into ``all_data`` are sliced in memory. Any other
    symbol, or one whose prefetch is empty, is downloaded on demand with
    ``yf.Ticker(symbol).history``.

    Args:
        symbol: Ticker symbol, e.g. ``'TSLA'``.
        start_date: ``'YYYY-MM-DD'`` string; inclusive.
        end_date: ``'YYYY-MM-DD'`` string. Inclusive when slicing prefetched
            data; passed straight to ``Ticker.history`` otherwise.

    Returns:
        A DataFrame with single-level columns, or ``None`` if the on-demand
        download raised. Results, including ``None``, are memoised by
        ``lru_cache``, so identical arguments return the same object.
        Callers should not mutate it in place.
    """
    if symbol in all_data and not all_data[symbol].empty:
        # Use preloaded data and slice by date (both bounds inclusive).
        hist = all_data[symbol]
        window = hist[(hist.index >= start_date) & (hist.index <= end_date)]
        return _flatten_price_columns(window)

    # Fetch on-demand with yfinance
    try:
        ticker = yf.Ticker(symbol)
        hist = ticker.history(start=start_date, end=end_date, auto_adjust=True)
    except Exception as e:
        print(f"Error fetching data for {symbol}: {e}")
        return None
    return _flatten_price_columns(hist)

def parse_dates(query):
    """Extract a ``(start_date, end_date)`` pair of datetimes from *query*.

    Recognised forms, tried in order (case-insensitive):

    * ``"between YYYY and YYYY"``: from Jan 1 of the first year to Dec 31
      of the second year.
    * ``"<N> year(s)/month(s)/week(s)/day(s)"``: the window of that length
      ending now. A month is approximated as 30 days and a year as 365 days.
    * anything else: the last 365 days.

    Returns:
        ``(start_date, end_date)`` as naive datetimes, or ``(None, None)``
        when the request cannot be represented. That covers a start year
        after the end year, a year outside ``datetime``'s range, or a
        period long enough to overflow ``timedelta``/``datetime``.
    """
    text = query.lower()

    # Explicit year ranges such as "between 2010 and 2020".
    range_match = re.search(r'between\s+(\d{4})\s+and\s+(\d{4})', text)
    if range_match:
        try:
            start_date = datetime(int(range_match.group(1)), 1, 1)
            end_date = datetime(int(range_match.group(2)), 12, 31)
            if start_date >= end_date:
                raise ValueError("Start date must be before end date")
            return start_date, end_date
        except ValueError as e:
            print(f"Date parsing error: {e}")
            return None, None

    # Relative periods ending now, e.g. "3 months" or "2 weeks".
    period_match = re.search(r'(\d+)\s*(year|month|week|day)s?', text)
    if period_match:
        days_per_unit = {'year': 365, 'month': 30, 'week': 7, 'day': 1}
        end_date = datetime.now()
        try:
            count = int(period_match.group(1))
            span = timedelta(days=days_per_unit[period_match.group(2)] * count)
            start_date = end_date - span
        except (ValueError, OverflowError) as e:
            # Very large counts overflow timedelta, or push the start date
            # before datetime.min; report instead of crashing on user input.
            print(f"Date parsing error: {e}")
            return None, None
        return start_date, end_date

    # Default: the last 365 days.
    end_date = datetime.now()
    return end_date - timedelta(days=365), end_date

def find_closest_symbol(input_symbol, company_map=None, symbols=None):
    """Resolve a company name or a near-miss ticker to a known ticker symbol.

    Company names from *company_map* are searched for first, as whole
    words and case-insensitively. The first name found, in mapping order,
    wins. Otherwise the upper-cased input is fuzzy-matched against
    *symbols* with ``difflib`` (cutoff 0.6).

    Args:
        input_symbol: User text such as ``"tesla"``, ``"msft"`` or a whole
            query. ``None`` or ``""`` yields ``None``.
        company_map: Lower-case company name -> ticker mapping. Defaults to
            the module-level ``company_to_ticker``.
        symbols: Candidate tickers for fuzzy matching. Defaults to the
            module-level ``available_symbols``.

    Returns:
        The matched ticker string, or ``None`` if nothing matches.
    """
    if not input_symbol:
        return None
    if company_map is None:
        company_map = company_to_ticker
    if symbols is None:
        symbols = available_symbols

    text = input_symbol.lower()
    for company, ticker in company_map.items():
        # Whole-word match, so "intelligence" does not resolve to 'intel'
        # and "metals" does not resolve to 'meta'.
        if re.search(rf'\b{re.escape(company)}\b', text):
            return ticker

    # Fallback to ticker matching
    closest = difflib.get_close_matches(input_symbol.upper(), symbols, n=1, cutoff=0.6)
    return closest[0] if closest else None

def calculate_growth_rate(start_date, end_date, symbol):
    hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
    if hist is