AnilNiraula commited on
Commit
f3373ab
·
verified ·
1 Parent(s): ff3b207

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -13
app.py CHANGED
@@ -52,10 +52,10 @@ except Exception as e:
52
  print(f"Error prefetching data: {e}")
53
  all_data = {ticker: pd.DataFrame() for ticker in tickers} # Initialize empty DataFrames on failure
54
 
55
- # Create a DataFrame with 'Adj Close' columns for each ticker
56
  series_list = []
57
  for ticker, data in all_data.items():
58
- if not data.empty:
59
  s = data['Close']
60
  s.name = ticker
61
  series_list.append(s)
@@ -69,19 +69,25 @@ available_symbols = ['TSLA', 'MSFT', 'NVDA', 'GOOG', 'AMZN', 'SPY', 'AAPL', 'MET
69
 
70
  @lru_cache(maxsize=100)
71
  def fetch_stock_data(symbol, start_date, end_date):
72
- if symbol in all_data and not all_data[symbol].empty:
73
- # Use preloaded data and slice by date
74
- hist = all_data[symbol]
75
- return hist[(hist.index >= start_date) & (hist.index <= end_date)]
76
- else:
77
- # Fetch on-demand with yfinance
78
- try:
 
 
 
 
 
 
79
  ticker = yf.Ticker(symbol)
80
  hist = ticker.history(start=start_date, end=end_date, auto_adjust=True)
81
  return hist
82
- except Exception as e:
83
- print(f"Error fetching data for {symbol}: {e}")
84
- return None
85
 
86
  def parse_dates(query):
87
  # Handle year ranges like "between 2010 and 2020"
@@ -133,4 +139,86 @@ def find_closest_symbol(input_symbol):
133
 
134
  def calculate_growth_rate(start_date, end_date, symbol):
135
  hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
136
- if hist is
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  print(f"Error prefetching data: {e}")
53
  all_data = {ticker: pd.DataFrame() for ticker in tickers} # Initialize empty DataFrames on failure
54
 
55
+ # Create a DataFrame with 'Close' columns for each ticker
56
  series_list = []
57
  for ticker, data in all_data.items():
58
+ if not data.empty and 'Close' in data.columns:
59
  s = data['Close']
60
  s.name = ticker
61
  series_list.append(s)
 
69
 
70
  @lru_cache(maxsize=100)
71
  def fetch_stock_data(symbol, start_date, end_date):
72
+ try:
73
+ # Validate dates
74
+ start = datetime.strptime(start_date, '%Y-%m-%d')
75
+ end = datetime.strptime(end_date, '%Y-%m-%d')
76
+ if start >= end:
77
+ print(f"Invalid date range: {start_date} to {end_date}")
78
+ return None
79
+ if symbol in all_data and not all_data[symbol].empty:
80
+ # Use preloaded data and slice by date
81
+ hist = all_data[symbol]
82
+ return hist[(hist.index >= start_date) & (hist.index <= end_date)]
83
+ else:
84
+ # Fetch on-demand with yfinance
85
  ticker = yf.Ticker(symbol)
86
  hist = ticker.history(start=start_date, end=end_date, auto_adjust=True)
87
  return hist
88
+ except Exception as e:
89
+ print(f"Error fetching data for {symbol}: {e}")
90
+ return None
91
 
92
  def parse_dates(query):
93
  # Handle year ranges like "between 2010 and 2020"
 
139
 
140
  def calculate_growth_rate(start_date, end_date, symbol):
141
  hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
142
+ if hist is None or hist.empty or 'Close' not in hist.columns:
143
+ print(f"No valid data for {symbol} from {start_date} to {end_date}")
144
+ return None
145
+ beginning_value = hist.iloc[0]['Close']
146
+ ending_value = hist.iloc[-1]['Close']
147
+ years = (end_date - start_date).days / 365.25
148
+ if years <= 0:
149
+ return 0
150
+ total_return = ending_value / beginning_value
151
+ cagr = total_return ** (1 / years) - 1
152
+ return cagr * 100
153
+
154
+ def calculate_investment(principal, years, annual_return=0.07):
155
+ return principal * (1 + annual_return) ** years
156
+
157
+ model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
158
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
159
+ model = AutoModelForCausalLM.from_pretrained(
160
+ model_name,
161
+ device_map="auto",
162
+ )
163
+
164
+ def generate_response(user_query, enable_thinking=False):
165
+ print(f"Processing query: {user_query}") # Debugging
166
+ stock_keywords = ['stock', 'growth', 'investment', 'price', 'return', 'cagr']
167
+ is_stock_query = any(keyword in user_query.lower() for keyword in stock_keywords)
168
+ print(f"Is stock query: {is_stock_query}") # Debugging
169
+ summary = ""
170
+ if is_stock_query:
171
+ # Try to find symbol from company name or ticker
172
+ symbol = None
173
+ for company, ticker in company_to_ticker.items():
174
+ if company in user_query.lower():
175
+ symbol = ticker
176
+ break
177
+ if not symbol:
178
+ symbol_match = re.search(r'\b([A-Z]{1,5})\b', user_query.upper())
179
+ symbol = find_closest_symbol(symbol_match.group(1)) if symbol_match else None
180
+ print(f"Detected symbol: {symbol}") # Debugging
181
+ if symbol:
182
+ start_date, end_date = parse_dates(user_query)
183
+ print(f"Parsed dates: {start_date} to {end_date}") # Debugging
184
+ if start_date is None or end_date is None:
185
+ return "Invalid date range provided. Please specify a valid range, e.g., 'between 2010 and 2020'."
186
+ hist = fetch_stock_data(symbol, start_date.strftime('%Y-%m-%d'), end_date.strftime('%Y-%m-%d'))
187
+ print(f"Data fetched for {symbol}: {'Valid' if hist is not None and not hist.empty else 'Empty/None'}") # Debugging
188
+ if "average price" in user_query.lower():
189
+ if hist is not None and not hist.empty and 'Close' in hist.columns:
190
+ avg_price = hist['Close'].mean()
191
+ summary = f"The average adjusted closing price for {symbol} from {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')} is ${avg_price:.2f}."
192
+ else:
193
+ summary = f"No data available for {symbol} in the specified period."
194
+ elif "cagr" in user_query.lower() or "return" in user_query.lower():
195
+ growth_rate = calculate_growth_rate(start_date, end_date, symbol)
196
+ if growth_rate is not None:
197
+ summary = f"The CAGR for {symbol} over the period from {start_date.strftime('%Y-%m-%d')} to {end_date.strftime('%Y-%m-%d')} is {growth_rate:.2f}%."
198
+ else:
199
+ summary = f"No data available for {symbol} in the specified period."
200
+ investment_match = re.search(r'\$(\d+)', user_query)
201
+ if investment_match:
202
+ principal = float(investment_match.group(1))
203
+ years = (end_date - start_date).days / 365.25
204
+ projected = calculate_investment(principal, years)
205
+ summary += f" Projecting ${principal} at 7% return over {years:.1f} years: ${projected:.2f}."
206
+ else:
207
+ summary = "Could not identify a valid stock symbol in the query. Please specify a company or ticker (e.g., 'Tesla' or 'TSLA')."
208
+ system_prompt = (
209
+ "You are FinChat, a knowledgeable financial advisor. Always respond in a friendly, professional manner. "
210
+ "For greetings like 'Hi' or 'Hello', reply warmly, e.g., 'Hi! I'm FinChat, your financial advisor. What can I help you with today regarding stocks, investments, or advice?' "
211
+ "Provide accurate, concise advice based on data."
212
+ )
213
+ # Placeholder for generation logic (tokenize, generate, decode)
214
+ return summary or "Please provide a specific stock or investment query."
215
+
216
+ # Gradio interface setup
217
+ demo = gr.Interface(
218
+ fn=generate_response,
219
+ inputs=[gr.Textbox(lines=2, placeholder="Enter your query (e.g., 'Tesla average price between 2010 and 2020')"), gr.Checkbox(label="Enable Thinking")],
220
+ outputs="text",
221
+ title="FinChat",
222
+ description="Ask about stock performance, CAGR, or investments."
223
+ )
224
+ demo.launch()