Spaces:
Running
Running
import gradio as gr | |
import numpy as np | |
import re | |
from datetime import datetime, timedelta | |
import difflib | |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
import torch | |
import yfinance as yf # For dynamic data fetching | |
# Known tickers used as the candidate set when resolving a symbol mentioned in
# a user query (exact and fuzzy matching). Kept in the original order; wrapped
# across lines purely for readability.
available_symbols = [
    'TSLA', 'MSFT', 'NVDA', 'GOOG', 'AMZN', 'SPY', 'AAPL', 'META',
    'NFLX', 'INTC', 'AMD', 'IBM', 'ORCL', 'CSCO',
    'JPM', 'BAC', 'WFC', 'V', 'MA',
    'XOM', 'CVX',
    'PFE', 'JNJ', 'MRK',
]
# Financial calculation functions | |
def parse_period(query):
    """Turn the first "<N> <unit>" phrase in *query* into a timedelta.

    Recognised units are year, month, week and day (an optional plural "s"
    and any whitespace between number and unit are accepted; matching is
    case-insensitive). Years count as 365 days and months as 30 days.
    When no such phrase is present, one year (365 days) is returned.
    """
    days_per_unit = {'year': 365, 'month': 30, 'week': 7, 'day': 1}
    found = re.search(r'(\d+)\s*(year|month|week|day)s?', query.lower())
    if found is None:
        return timedelta(days=365)  # Default to 1 year
    count, unit = int(found.group(1)), found.group(2)
    return timedelta(days=days_per_unit[unit] * count)
def find_closest_symbol(input_symbol, *, candidates=None, cutoff=0.6):
    """Return the known ticker most similar to *input_symbol*, or None.

    The input is upper-cased and compared with ``difflib.get_close_matches``,
    which scores candidates by ``SequenceMatcher.ratio()``; only the single
    best candidate scoring at least *cutoff* is returned.

    Args:
        input_symbol: Ticker-like text typed by the user (e.g. "tsla").
            ``None`` or an empty string yields ``None``.
        candidates: Upper-case tickers to match against. Defaults to the
            module-level ``available_symbols`` list.
        cutoff: Minimum similarity ratio, in [0, 1] (difflib raises
            ValueError otherwise). Defaults to 0.6.

    Returns:
        The best-matching ticker string, or ``None`` if nothing qualifies.
    """
    if not input_symbol:
        return None
    pool = available_symbols if candidates is None else candidates
    closest = difflib.get_close_matches(input_symbol.upper(), pool, n=1, cutoff=cutoff)
    return closest[0] if closest else None
def calculate_growth_rate(start_date, end_date, symbol):
    """Return the compound annual growth rate (CAGR) of *symbol*, in percent.

    Prices come from ``yf.Ticker(symbol).history(..., auto_adjust=True)``.
    With auto-adjustment the Close column is already adjusted for dividends
    and splits, so dividends must NOT be added again. Doing so would
    double-count them and mix adjusted prices with raw per-share amounts.

    The rate is annualised over the time between the first and last closes
    actually returned, not the requested window. A symbol with less history
    than requested (e.g. a recent listing) is therefore not understated.

    Args:
        start_date: Start of the requested window (passed to yfinance).
        end_date: End of the requested window (passed to yfinance).
        symbol: Ticker symbol to fetch.

    Returns:
        The CAGR as a percentage. Returns 0 when the returned closes span
        less than one day (e.g. a single row). Returns None when no usable
        data is returned, the starting price is not positive, or
        fetching/processing fails (the error is printed).
    """
    try:
        hist = yf.Ticker(symbol).history(start=start_date, end=end_date, auto_adjust=True)
        if hist.empty or 'Close' not in hist.columns:
            return None
        closes = hist['Close'].dropna()
        if closes.empty:
            return None
        beginning_value = closes.iloc[0]
        ending_value = closes.iloc[-1]
        if beginning_value <= 0:
            return None  # growth from a non-positive price is undefined
        years = (closes.index[-1] - closes.index[0]).days / 365.25
        if years <= 0:
            return 0
        cagr = (ending_value / beginning_value) ** (1 / years) - 1
        return cagr * 100  # As percentage
    except Exception as e:
        # Network/data-source boundary: report and degrade to "no data".
        print(f"Error fetching data for {symbol}: {e}")
        return None
def calculate_investment(principal, years, annual_return=0.07):
    """Project *principal* forward under annual compounding.

    Computes ``principal * (1 + annual_return) ** years``. *years* may be
    fractional, and the default rate is 7% per year.
    """
    growth_factor = (1 + annual_return) ** years
    return principal * growth_factor
# Load SmolLM-135M-Instruct. 4-bit NF4 quantization (bitsandbytes) is only
# requested when a CUDA device is available. Depending on the installed
# bitsandbytes/transformers versions, 4-bit loading may refuse to run without
# a GPU. At 135M parameters, unquantized weights on CPU are small anyway.
model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
if torch.cuda.is_available():
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
else:
    quantization_config = None  # plain (unquantized) weights on CPU
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
)
def _extract_symbol(user_query):
    """Return the ticker referenced in *user_query*, or None.

    Every standalone 1-5 letter word of the upper-cased query is a candidate.
    An exact hit in ``available_symbols`` anywhere in the query wins, so a
    leading word such as "What" no longer hides "MSFT" later in the sentence.
    Only when no word matches exactly is the FIRST word fuzzy-matched via
    ``find_closest_symbol``. That keeps typo tolerance without fuzzy-matching
    every ordinary word (e.g. "GOOD" would fuzzy-match "GOOG").
    """
    words = re.findall(r'\b([A-Z]{1,5})\b', user_query.upper())
    exact = next((w for w in words if w in available_symbols), None)
    if exact is not None:
        return exact
    return find_closest_symbol(words[0]) if words else None


def _build_data_summary(user_query):
    """Compute the factual context (CAGR, projection) to prepend to the prompt.

    Returns an empty string when the query mentions neither a known ticker
    nor a dollar amount. The investment projection does not depend on the
    ticker (it uses ``calculate_investment``'s flat default rate). It is
    therefore produced whenever a "$<amount>" is present. Amounts may contain
    thousands separators and decimals, e.g. "$10,000" or "$1,250.50".
    """
    parts = []
    period = parse_period(user_query)
    symbol = _extract_symbol(user_query)
    if symbol:
        end_date = datetime.now()
        start_date = end_date - period
        growth_rate = calculate_growth_rate(start_date, end_date, symbol)
        if growth_rate is not None:
            parts.append(f"The CAGR for {symbol} over the period is {growth_rate:.2f}%.")
        else:
            parts.append(f"No data available for {symbol} in the specified period.")
    investment_match = re.search(r'\$\s*(\d[\d,]*(?:\.\d+)?)', user_query)
    if investment_match:
        principal = float(investment_match.group(1).replace(",", ""))
        years = period.days / 365.25
        projected = calculate_investment(principal, years)
        parts.append(f"Projecting ${principal} at 7% return over {years:.1f} years: ${projected:.2f}.")
    return " ".join(parts)


def generate_response(user_query, enable_thinking=False):
    """Answer *user_query* with the chat model, grounded on fetched market data.

    Any computed data summary (see ``_build_data_summary``) is prepended to
    the user's message. A fixed FinChat system prompt is added, and the
    reply is generated greedily.

    Args:
        user_query: The user's message text.
        enable_thinking: Forwarded to ``tokenizer.apply_chat_template``.

    Returns:
        The decoded model reply (only the newly generated tokens), stripped.
    """
    summary = _build_data_summary(user_query)
    system_prompt = (
        "You are FinChat, a knowledgeable financial advisor. Always respond in a friendly, professional manner like a helpful chatbot. "
        "For greetings such as 'Hi' or 'Hello', reply warmly, e.g., 'Hi! I'm FinChat, your financial advisor. What can I help you with today regarding stocks, investments, or advice?' "
        "Provide accurate, concise advice based on any provided data. If no specific data is available, offer general financial insights or ask for clarification. "
        "Disclaimer: This is not professional financial advice; consult experts for decisions."
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{summary} {user_query}" if summary else user_query},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # Greedy decoding. Sampling knobs (temperature/top_p) are deliberately not
    # passed: with do_sample=False transformers ignores them and warns.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=50,  # kept short for latency
        repetition_penalty=1.0,  # 1.0 == no penalty
        do_sample=False,
    )
    # Drop the prompt tokens; decode only what the model generated.
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
    response = tokenizer.decode(output_ids, skip_special_tokens=True)
    return response.strip()
# Gradio interface | |
def chat(user_input, history):
    """Gradio submit handler: append one (user, bot) exchange to *history*.

    Args:
        user_input: Text from the input box. Blank or whitespace-only input is
            ignored (no model call is made).
        history: List of (user, bot) tuples shown by the Chatbot. May be
            ``None`` after the Clear button resets the component.

    Returns:
        ``(history, "")``: the updated transcript and an empty string that
        clears the input box.
    """
    history = [] if history is None else history
    if not user_input or not user_input.strip():
        return history, ""
    response = generate_response(user_input)
    history.append((user_input, response))
    return history, ""
def _reset_chatbot():
    """Return None for the Chatbot output (wired to the Clear button)."""
    return None


# Page layout: title, chat transcript, input box and a Clear button.
with gr.Blocks() as demo:
    gr.Markdown("# FinChat: AI-Powered Financial Advisor")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Ask about stocks, investments, etc.")
    clear = gr.Button("Clear")
    # Submitting the textbox runs chat(); its two return values update the
    # transcript and the textbox respectively.
    msg.submit(fn=chat, inputs=[msg, chatbot], outputs=[chatbot, msg])
    clear.click(fn=_reset_chatbot, inputs=None, outputs=chatbot, queue=False)
demo.launch()