AnilNiraula committed on
Commit
ff3b207
·
verified ·
1 Parent(s): 9c66ae0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -69
app.py CHANGED
@@ -12,6 +12,36 @@ import pandas as pd
12
  # Define the list of tickers
13
  tickers = ['TSLA', 'PLTR', 'SOUN', 'MSFT']
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # Prefetch stock data for all tickers at startup using yfinance
16
  all_data = {}
17
  try:
@@ -22,14 +52,14 @@ except Exception as e:
22
  print(f"Error prefetching data: {e}")
23
  all_data = {ticker: pd.DataFrame() for ticker in tickers} # Initialize empty DataFrames on failure
24
 
25
- # Create a DataFrame with 'Adj Close' columns for each ticker using pd.concat to handle varying lengths
26
  series_list = []
27
  for ticker, data in all_data.items():
28
  if not data.empty:
29
  s = data['Close']
30
  s.name = ticker
31
  series_list.append(s)
32
- adj_close_data = pd.concat(series_list, axis=1)
33
 
34
  # Display the first few rows to verify (for debugging; remove in production)
35
  print(adj_close_data.head())
@@ -54,13 +84,21 @@ def fetch_stock_data(symbol, start_date, end_date):
54
  return None
55
 
56
  def parse_dates(query):
57
- # Enhanced to handle year ranges like "between 2010 and 2020"
58
  range_match = re.search(r'between\s+(\d{4})\s+and\s+(\d{4})', query.lower())
59
  if range_match:
60
  start_year = int(range_match.group(1))
61
  end_year = int(range_match.group(2))
62
- return datetime(start_year, 1, 1), datetime(end_year, 12, 31)
63
- # Fallback to original period parsing for recent periods
 
 
 
 
 
 
 
 
64
  period_match = re.search(r'(\d+)\s*(year|month|week|day)s?', query.lower())
65
  if period_match:
66
  num = int(period_match.group(1))
@@ -85,72 +123,14 @@ def parse_dates(query):
85
 
86
  def find_closest_symbol(input_symbol):
87
  input_symbol = input_symbol.upper()
 
 
 
 
 
88
  closest = difflib.get_close_matches(input_symbol, available_symbols, n=1, cutoff=0.6)
89
  return closest[0] if closest else None
90
 
91
  def calculate_growth_rate(start_date, end_date, symbol):
92
  hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
93
- if hist is None or hist.empty or 'Close' not in hist.columns:
94
- return None
95
- beginning_value = hist.iloc[0]['Close']
96
- ending_value = hist.iloc[-1]['Close']
97
- years = (end_date - start_date).days / 365.25
98
- if years <= 0:
99
- return 0
100
- total_return = ending_value / beginning_value
101
- cagr = total_return ** (1 / years) - 1
102
- return cagr * 100
103
-
104
- def calculate_investment(principal, years, annual_return=0.07):
105
- return principal * (1 + annual_return) ** years
106
-
107
- model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
108
- tokenizer = AutoTokenizer.from_pretrained(model_name)
109
- model = AutoModelForCausalLM.from_pretrained(
110
- model_name,
111
- device_map="auto",
112
- )
113
-
114
- def generate_response(user_query, enable_thinking=False):
115
- stock_keywords = ['stock', 'growth', 'investment', 'price', 'return', 'cagr']
116
- is_stock_query = any(keyword in user_query.lower() for keyword in stock_keywords)
117
- summary = ""
118
- if is_stock_query:
119
- symbol_match = re.search(r'\b([A-Z]{1,5})\b', user_query.upper())
120
- symbol = find_closest_symbol(symbol_match.group(1)) if symbol_match else None
121
- if symbol:
122
- start_date, end_date = parse_dates(user_query)
123
- hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
124
- if "average price" in user_query.lower():
125
- if hist is not None and not hist.empty and 'Close' in hist.columns:
126
- avg_price = hist['Close'].mean()
127
- summary = f"The average adjusted closing price for {symbol} from {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')} is ${avg_price:.2f}."
128
- elif "cagr" in user_query.lower() or "return" in user_query.lower():
129
- growth_rate = calculate_growth_rate(start_date, end_date, symbol)
130
- if growth_rate is not None:
131
- summary = f"The CAGR for {symbol} over the period from {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')} is {growth_rate:.2f}%."
132
- else:
133
- summary = f"No data available for {symbol} in the specified period."
134
- investment_match = re.search(r'\$(\d+)', user_query)
135
- if investment_match:
136
- principal = float(investment_match.group(1))
137
- years = (end_date - start_date).days / 365.25
138
- projected = calculate_investment(principal, years)
139
- summary += f" Projecting ${principal} at 7% return over {years:.1f} years: ${projected:.2f}."
140
- system_prompt = (
141
- "You are FinChat, a knowledgeable financial advisor. Always respond in a friendly, professional manner. "
142
- "For greetings like 'Hi' or 'Hello', reply warmly, e.g., 'Hi! I'm FinChat, your financial advisor. What can I help you with today regarding stocks, investments, or advice?' "
143
- "Provide accurate, concise advice based on data."
144
- )
145
- # Placeholder for generation logic (tokenize, generate, decode)
146
- return summary or "Please provide a specific stock or investment query."
147
-
148
- # Gradio interface setup
149
- demo = gr.Interface(
150
- fn=generate_response,
151
- inputs=[gr.Textbox(lines=2, placeholder="Enter your query (e.g., 'TSLA CAGR between 2010 and 2020')"), gr.Checkbox(label="Enable Thinking")],
152
- outputs="text",
153
- title="FinChat",
154
- description="Ask about stock performance, CAGR, or investments."
155
- )
156
- demo.launch()
 
12
  # Define the list of tickers
13
  tickers = ['TSLA', 'PLTR', 'SOUN', 'MSFT']
14
 
15
+ # Company name to ticker mapping for better query handling
16
+ company_to_ticker = {
17
+ 'tesla': 'TSLA',
18
+ 'palantir': 'PLTR',
19
+ 'soundhound': 'SOUN',
20
+ 'microsoft': 'MSFT',
21
+ 'nvidia': 'NVDA',
22
+ 'google': 'GOOG',
23
+ 'amazon': 'AMZN',
24
+ 'apple': 'AAPL',
25
+ 'meta': 'META',
26
+ 'netflix': 'NFLX',
27
+ 'intel': 'INTC',
28
+ 'amd': 'AMD',
29
+ 'ibm': 'IBM',
30
+ 'oracle': 'ORCL',
31
+ 'cisco': 'CSCO',
32
+ 'jpmorgan': 'JPM',
33
+ 'bank of america': 'BAC',
34
+ 'wells fargo': 'WFC',
35
+ 'visa': 'V',
36
+ 'mastercard': 'MA',
37
+ 'exxon': 'XOM',
38
+ 'chevron': 'CVX',
39
+ 'pfizer': 'PFE',
40
+ 'johnson & johnson': 'JNJ',
41
+ 'merck': 'MRK',
42
+ 'spy': 'SPY'
43
+ }
44
+
45
  # Prefetch stock data for all tickers at startup using yfinance
46
  all_data = {}
47
  try:
 
52
  print(f"Error prefetching data: {e}")
53
  all_data = {ticker: pd.DataFrame() for ticker in tickers} # Initialize empty DataFrames on failure
54
 
55
+ # Create a DataFrame of closing prices: one column per ticker, taken from each ticker's 'Close' series
56
  series_list = []
57
  for ticker, data in all_data.items():
58
  if not data.empty:
59
  s = data['Close']
60
  s.name = ticker
61
  series_list.append(s)
62
+ adj_close_data = pd.concat(series_list, axis=1) if series_list else pd.DataFrame()
63
 
64
  # Display the first few rows to verify (for debugging; remove in production)
65
  print(adj_close_data.head())
 
84
  return None
85
 
86
  def parse_dates(query):
87
+ # Handle year ranges like "between 2010 and 2020"
88
  range_match = re.search(r'between\s+(\d{4})\s+and\s+(\d{4})', query.lower())
89
  if range_match:
90
  start_year = int(range_match.group(1))
91
  end_year = int(range_match.group(2))
92
+ try:
93
+ start_date = datetime(start_year, 1, 1)
94
+ end_date = datetime(end_year, 12, 31)
95
+ if start_date >= end_date:
96
+ raise ValueError("Start date must be before end date")
97
+ return start_date, end_date
98
+ except ValueError as e:
99
+ print(f"Date parsing error: {e}")
100
+ return None, None
101
+ # Fallback to period parsing for recent periods
102
  period_match = re.search(r'(\d+)\s*(year|month|week|day)s?', query.lower())
103
  if period_match:
104
  num = int(period_match.group(1))
 
123
 
124
  def find_closest_symbol(input_symbol):
125
  input_symbol = input_symbol.upper()
126
+ # Check if input matches a company name
127
+ for company, ticker in company_to_ticker.items():
128
+ if company in input_symbol.lower():
129
+ return ticker
130
+ # Fallback to ticker matching
131
  closest = difflib.get_close_matches(input_symbol, available_symbols, n=1, cutoff=0.6)
132
  return closest[0] if closest else None
133
 
134
  def calculate_growth_rate(start_date, end_date, symbol):
135
  hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
136
+ if hist is