Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,36 @@ import pandas as pd
|
|
12 |
# Define the list of tickers
|
13 |
tickers = ['TSLA', 'PLTR', 'SOUN', 'MSFT']
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
# Prefetch stock data for all tickers at startup using yfinance
|
16 |
all_data = {}
|
17 |
try:
|
@@ -22,14 +52,14 @@ except Exception as e:
|
|
22 |
print(f"Error prefetching data: {e}")
|
23 |
all_data = {ticker: pd.DataFrame() for ticker in tickers} # Initialize empty DataFrames on failure
|
24 |
|
25 |
-
# Create a DataFrame with 'Adj Close' columns for each ticker
|
26 |
series_list = []
|
27 |
for ticker, data in all_data.items():
|
28 |
if not data.empty:
|
29 |
s = data['Close']
|
30 |
s.name = ticker
|
31 |
series_list.append(s)
|
32 |
-
adj_close_data = pd.concat(series_list, axis=1)
|
33 |
|
34 |
# Display the first few rows to verify (for debugging; remove in production)
|
35 |
print(adj_close_data.head())
|
@@ -54,13 +84,21 @@ def fetch_stock_data(symbol, start_date, end_date):
|
|
54 |
return None
|
55 |
|
56 |
def parse_dates(query):
|
57 |
-
#
|
58 |
range_match = re.search(r'between\s+(\d{4})\s+and\s+(\d{4})', query.lower())
|
59 |
if range_match:
|
60 |
start_year = int(range_match.group(1))
|
61 |
end_year = int(range_match.group(2))
|
62 |
-
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
period_match = re.search(r'(\d+)\s*(year|month|week|day)s?', query.lower())
|
65 |
if period_match:
|
66 |
num = int(period_match.group(1))
|
@@ -85,72 +123,14 @@ def parse_dates(query):
|
|
85 |
|
86 |
def find_closest_symbol(input_symbol):
|
87 |
input_symbol = input_symbol.upper()
|
|
|
|
|
|
|
|
|
|
|
88 |
closest = difflib.get_close_matches(input_symbol, available_symbols, n=1, cutoff=0.6)
|
89 |
return closest[0] if closest else None
|
90 |
|
91 |
def calculate_growth_rate(start_date, end_date, symbol):
|
92 |
hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
|
93 |
-
if hist is
|
94 |
-
return None
|
95 |
-
beginning_value = hist.iloc[0]['Close']
|
96 |
-
ending_value = hist.iloc[-1]['Close']
|
97 |
-
years = (end_date - start_date).days / 365.25
|
98 |
-
if years <= 0:
|
99 |
-
return 0
|
100 |
-
total_return = ending_value / beginning_value
|
101 |
-
cagr = total_return ** (1 / years) - 1
|
102 |
-
return cagr * 100
|
103 |
-
|
104 |
-
def calculate_investment(principal, years, annual_return=0.07):
|
105 |
-
return principal * (1 + annual_return) ** years
|
106 |
-
|
107 |
-
model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
|
108 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
109 |
-
model = AutoModelForCausalLM.from_pretrained(
|
110 |
-
model_name,
|
111 |
-
device_map="auto",
|
112 |
-
)
|
113 |
-
|
114 |
-
def generate_response(user_query, enable_thinking=False):
|
115 |
-
stock_keywords = ['stock', 'growth', 'investment', 'price', 'return', 'cagr']
|
116 |
-
is_stock_query = any(keyword in user_query.lower() for keyword in stock_keywords)
|
117 |
-
summary = ""
|
118 |
-
if is_stock_query:
|
119 |
-
symbol_match = re.search(r'\b([A-Z]{1,5})\b', user_query.upper())
|
120 |
-
symbol = find_closest_symbol(symbol_match.group(1)) if symbol_match else None
|
121 |
-
if symbol:
|
122 |
-
start_date, end_date = parse_dates(user_query)
|
123 |
-
hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
|
124 |
-
if "average price" in user_query.lower():
|
125 |
-
if hist is not None and not hist.empty and 'Close' in hist.columns:
|
126 |
-
avg_price = hist['Close'].mean()
|
127 |
-
summary = f"The average adjusted closing price for {symbol} from {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')} is ${avg_price:.2f}."
|
128 |
-
elif "cagr" in user_query.lower() or "return" in user_query.lower():
|
129 |
-
growth_rate = calculate_growth_rate(start_date, end_date, symbol)
|
130 |
-
if growth_rate is not None:
|
131 |
-
summary = f"The CAGR for {symbol} over the period from {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')} is {growth_rate:.2f}%."
|
132 |
-
else:
|
133 |
-
summary = f"No data available for {symbol} in the specified period."
|
134 |
-
investment_match = re.search(r'\$(\d+)', user_query)
|
135 |
-
if investment_match:
|
136 |
-
principal = float(investment_match.group(1))
|
137 |
-
years = (end_date - start_date).days / 365.25
|
138 |
-
projected = calculate_investment(principal, years)
|
139 |
-
summary += f" Projecting ${principal} at 7% return over {years:.1f} years: ${projected:.2f}."
|
140 |
-
system_prompt = (
|
141 |
-
"You are FinChat, a knowledgeable financial advisor. Always respond in a friendly, professional manner. "
|
142 |
-
"For greetings like 'Hi' or 'Hello', reply warmly, e.g., 'Hi! I'm FinChat, your financial advisor. What can I help you with today regarding stocks, investments, or advice?' "
|
143 |
-
"Provide accurate, concise advice based on data."
|
144 |
-
)
|
145 |
-
# Placeholder for generation logic (tokenize, generate, decode)
|
146 |
-
return summary or "Please provide a specific stock or investment query."
|
147 |
-
|
148 |
-
# Gradio interface setup
|
149 |
-
demo = gr.Interface(
|
150 |
-
fn=generate_response,
|
151 |
-
inputs=[gr.Textbox(lines=2, placeholder="Enter your query (e.g., 'TSLA CAGR between 2010 and 2020')"), gr.Checkbox(label="Enable Thinking")],
|
152 |
-
outputs="text",
|
153 |
-
title="FinChat",
|
154 |
-
description="Ask about stock performance, CAGR, or investments."
|
155 |
-
)
|
156 |
-
demo.launch()
|
|
|
12 |
# Define the list of tickers
|
13 |
tickers = ['TSLA', 'PLTR', 'SOUN', 'MSFT']
|
14 |
|
15 |
+
# Company name to ticker mapping for better query handling
|
16 |
+
company_to_ticker = {
|
17 |
+
'tesla': 'TSLA',
|
18 |
+
'palantir': 'PLTR',
|
19 |
+
'soundhound': 'SOUN',
|
20 |
+
'microsoft': 'MSFT',
|
21 |
+
'nvidia': 'NVDA',
|
22 |
+
'google': 'GOOG',
|
23 |
+
'amazon': 'AMZN',
|
24 |
+
'apple': 'AAPL',
|
25 |
+
'meta': 'META',
|
26 |
+
'netflix': 'NFLX',
|
27 |
+
'intel': 'INTC',
|
28 |
+
'amd': 'AMD',
|
29 |
+
'ibm': 'IBM',
|
30 |
+
'oracle': 'ORCL',
|
31 |
+
'cisco': 'CSCO',
|
32 |
+
'jpmorgan': 'JPM',
|
33 |
+
'bank of america': 'BAC',
|
34 |
+
'wells fargo': 'WFC',
|
35 |
+
'visa': 'V',
|
36 |
+
'mastercard': 'MA',
|
37 |
+
'exxon': 'XOM',
|
38 |
+
'chevron': 'CVX',
|
39 |
+
'pfizer': 'PFE',
|
40 |
+
'johnson & johnson': 'JNJ',
|
41 |
+
'merck': 'MRK',
|
42 |
+
'spy': 'SPY'
|
43 |
+
}
|
44 |
+
|
45 |
# Prefetch stock data for all tickers at startup using yfinance
|
46 |
all_data = {}
|
47 |
try:
|
|
|
52 |
print(f"Error prefetching data: {e}")
|
53 |
all_data = {ticker: pd.DataFrame() for ticker in tickers} # Initialize empty DataFrames on failure
|
54 |
|
55 |
+
# Create a DataFrame with 'Adj Close' columns for each ticker
|
56 |
series_list = []
|
57 |
for ticker, data in all_data.items():
|
58 |
if not data.empty:
|
59 |
s = data['Close']
|
60 |
s.name = ticker
|
61 |
series_list.append(s)
|
62 |
+
adj_close_data = pd.concat(series_list, axis=1) if series_list else pd.DataFrame()
|
63 |
|
64 |
# Display the first few rows to verify (for debugging; remove in production)
|
65 |
print(adj_close_data.head())
|
|
|
84 |
return None
|
85 |
|
86 |
def parse_dates(query):
|
87 |
+
# Handle year ranges like "between 2010 and 2020"
|
88 |
range_match = re.search(r'between\s+(\d{4})\s+and\s+(\d{4})', query.lower())
|
89 |
if range_match:
|
90 |
start_year = int(range_match.group(1))
|
91 |
end_year = int(range_match.group(2))
|
92 |
+
try:
|
93 |
+
start_date = datetime(start_year, 1, 1)
|
94 |
+
end_date = datetime(end_year, 12, 31)
|
95 |
+
if start_date >= end_date:
|
96 |
+
raise ValueError("Start date must be before end date")
|
97 |
+
return start_date, end_date
|
98 |
+
except ValueError as e:
|
99 |
+
print(f"Date parsing error: {e}")
|
100 |
+
return None, None
|
101 |
+
# Fallback to period parsing for recent periods
|
102 |
period_match = re.search(r'(\d+)\s*(year|month|week|day)s?', query.lower())
|
103 |
if period_match:
|
104 |
num = int(period_match.group(1))
|
|
|
123 |
|
124 |
def find_closest_symbol(input_symbol):
|
125 |
input_symbol = input_symbol.upper()
|
126 |
+
# Check if input matches a company name
|
127 |
+
for company, ticker in company_to_ticker.items():
|
128 |
+
if company in input_symbol.lower():
|
129 |
+
return ticker
|
130 |
+
# Fallback to ticker matching
|
131 |
closest = difflib.get_close_matches(input_symbol, available_symbols, n=1, cutoff=0.6)
|
132 |
return closest[0] if closest else None
|
133 |
|
134 |
def calculate_growth_rate(start_date, end_date, symbol):
|
135 |
hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
|
136 |
+
if hist is
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|