import re
import difflib
from datetime import datetime, timedelta

import gradio as gr
import torch
import yfinance as yf  # For dynamic data fetching
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Hardcoded common symbols for fuzzy matching (expand as needed)
available_symbols = ['TSLA', 'MSFT', 'NVDA', 'GOOG', 'AMZN', 'SPY']


# Financial calculation functions
def parse_period(query):
    """Extract a lookback window like '2 years' or '6 months' from the query."""
    match = re.search(r'(\d+)\s*(year|month|week|day)s?', query.lower())
    if match:
        num = int(match.group(1))
        unit = match.group(2)
        if unit == 'year':
            return timedelta(days=365 * num)
        elif unit == 'month':
            return timedelta(days=30 * num)
        elif unit == 'week':
            return timedelta(weeks=num)
        elif unit == 'day':
            return timedelta(days=num)
    return timedelta(days=365)  # Default to 1 year


def find_closest_symbol(input_symbol):
    """Fuzzy-match user input against the known ticker list."""
    input_symbol = input_symbol.upper()
    closest = difflib.get_close_matches(input_symbol, available_symbols, n=1, cutoff=0.6)
    return closest[0] if closest else None


def calculate_growth_rate(start_date, end_date, symbol):
    """Compute the CAGR (as a percentage) from historical closes plus dividends."""
    try:
        ticker = yf.Ticker(symbol)
        hist = ticker.history(start=start_date, end=end_date)
        if hist.empty:
            return None
        beginning_value = hist.iloc[0]['Close']
        ending_value = hist.iloc[-1]['Close']
        dividends = hist['Dividends'].sum() if 'Dividends' in hist.columns else 0
        total_return = (ending_value - beginning_value + dividends) / beginning_value
        years = (end_date - start_date).days / 365.25
        if years == 0:
            return 0
        cagr = (1 + total_return) ** (1 / years) - 1
        return cagr * 100  # As percentage
    except Exception as e:
        print(f"Error fetching data for {symbol}: {e}")
        return None


def calculate_investment(principal, years, annual_return=0.07):
    """Project a lump sum forward at a fixed annual return, compounded yearly."""
    return principal * (1 + annual_return) ** years


# Load SmolLM-135M-Instruct with 4-bit quantization for speed.
# bitsandbytes 4-bit loading requires a CUDA GPU, so fall back to full
# precision on CPU to keep the script runnable everywhere.
model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if torch.cuda.is_available():
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=quantization_config,
        device_map="auto",
    )
else:
    model = AutoModelForCausalLM.from_pretrained(model_name)


def generate_response(user_query, enable_thinking=False):
    # Parse the query for a ticker symbol and a time period. The regex grabs
    # the first 1-5 letter word, so ordinary English words can look like
    # tickers; fuzzy matching against available_symbols filters most of them.
    symbol_match = re.search(r'\b([A-Z]{1,5})\b', user_query.upper())
    symbol = find_closest_symbol(symbol_match.group(1)) if symbol_match else None
    if symbol is None:
        symbol = 'SPY'  # Default to the S&P 500 ETF
    period = parse_period(user_query)
    end_date = datetime.now()
    start_date = end_date - period

    # Calculate growth rate if applicable
    growth_rate = calculate_growth_rate(start_date, end_date, symbol)
    if growth_rate is not None:
        summary = f"The CAGR for {symbol} over the period is {growth_rate:.2f}%."
    else:
        summary = f"No data available for {symbol} in the specified period."

    # Handle investment projection
    investment_match = re.search(r'\$(\d+)', user_query)
    if investment_match:
        principal = float(investment_match.group(1))
        years = period.days / 365.25
        projected = calculate_investment(principal, years)
        summary += f" Projecting ${principal:.0f} at 7% return over {years:.1f} years: ${projected:.2f}."

    # Prepare prompt for model
    system_prompt = (
        "You are a knowledgeable financial advisor. Provide accurate, concise "
        "advice based on the data provided. Disclaimer: This is not professional advice."
    )
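    # The deterministic summary computed above is prepended to the user's
    # question so the model grounds its answer in real numbers instead of
    # guessing prices; a simple retrieval-augmented prompting pattern.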
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{summary} {user_query}"}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking,  # Extra template kwarg; SmolLM's chat template ignores it, so thinking stays off for speed
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=50,  # Reduced for faster generation
        do_sample=False,  # Greedy decoding for speed; sampling knobs (temperature, top_p) don't apply here
        repetition_penalty=1.0,  # Set to 1.0 to avoid overhead
    )
    # Slice off the prompt tokens so only the newly generated text is decoded.
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
    response = tokenizer.decode(output_ids, skip_special_tokens=True)
    return response.strip()


# Gradio interface
def chat(user_input, history):
    history = history or []  # Chatbot state may be None on first call or after Clear
    response = generate_response(user_input)
    history.append((user_input, response))
    return history, ""


with gr.Blocks() as demo:
    gr.Markdown("# FinChat: AI-Powered Financial Advisor")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Ask about stocks, investments, etc.")
    clear = gr.Button("Clear")

    msg.submit(chat, [msg, chatbot], [chatbot, msg])
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()
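
# Optional smoke test: a minimal sketch for exercising the pipeline without
# the UI (the query below is hypothetical; requires network access for
# yfinance and the model download). Comment out demo.launch() above, then:
#   print(generate_response("How has TSLA performed over 2 years if I invest $1000?"))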