Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import re | |
| from datetime import datetime, timedelta | |
| import difflib | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| import yfinance as yf # Added for dynamic data fetching | |
# Tickers the app recognises. Symbols typed by the user are fuzzy-matched
# against this list, so extend it to support more instruments.
available_symbols = [
    'TSLA',
    'MSFT',
    'NVDA',
    'GOOG',
    'AMZN',
    'SPY',
]
| # Financial calculation functions (updated for dynamic data) | |
def parse_period(query):
    """Extract a look-back period such as "5 years" from free text.

    Recognises ``<number> <unit>`` where the unit is year, month, week or
    day (optionally plural) and the number may carry a decimal part.
    Years and months are approximated as 365 and 30 days respectively.

    Args:
        query: The user's text; matching is case-insensitive.

    Returns:
        A ``timedelta`` for the first period found, or one year
        (365 days) when no period is present or the requested span is too
        large for ``timedelta`` to represent.
    """
    default = timedelta(days=365)

    # The trailing \b stops words that merely start with a unit name
    # ("$1000 monthly") from being read as a period, and the optional
    # fraction keeps "1.5 years" from being parsed as "5 years".
    match = re.search(
        r'(\d+(?:\.\d+)?)\s*(year|month|week|day)s?\b',
        query.lower(),
    )
    if match is None:
        return default

    days_per_unit = {'year': 365, 'month': 30, 'week': 7, 'day': 1}
    amount = float(match.group(1))
    try:
        return timedelta(days=amount * days_per_unit[match.group(2)])
    except OverflowError:
        # timedelta is limited to just under a billion days; an absurdly
        # large request falls back to the default instead of crashing.
        return default
def find_closest_symbol(input_symbol):
    """Map a user-supplied ticker onto the nearest entry of ``available_symbols``.

    The input is upper-cased and compared with ``difflib.get_close_matches``
    using a similarity cutoff of 0.6; only the single best candidate is
    requested.

    Returns:
        The best-matching known symbol, or ``None`` if nothing is close enough.
    """
    candidates = difflib.get_close_matches(
        input_symbol.upper(),
        available_symbols,
        n=1,
        cutoff=0.6,
    )
    if not candidates:
        return None
    return candidates[0]
def calculate_growth_rate(start_date, end_date, symbol):
    """Return the compound annual growth rate of ``symbol`` as a percentage.

    Args:
        start_date: Start of the measurement window (a ``datetime``).
        end_date: End of the measurement window (a ``datetime``).
        symbol: Ticker symbol handed to ``yf.Ticker``.

    Returns:
        The CAGR in percent, or ``None`` when the window is empty or
        inverted, no price history is returned, the first close is not a
        positive number, or fetching/processing the data raises.
    """
    try:
        years = (end_date - start_date).days / 365.25
        # A zero-length or inverted window has no annualised return; bail
        # out before making a network request for it.
        if years <= 0:
            return None

        ticker = yf.Ticker(symbol)
        # With auto_adjust=True the 'Close' column is already adjusted for
        # dividends and splits, so it represents total return on its own.
        # Adding the 'Dividends' column on top would count payouts twice.
        hist = ticker.history(start=start_date, end=end_date, auto_adjust=True)
        if hist.empty:
            return None

        beginning_value = hist.iloc[0]['Close']
        ending_value = hist.iloc[-1]['Close']
        # Written as "not > 0" so that NaN is rejected along with zero.
        if not beginning_value > 0:
            return None

        total_return = (ending_value - beginning_value) / beginning_value
        # NOTE(review): annualisation uses the requested window, not the
        # span of rows actually returned; for tickers with a shorter
        # history than requested the two can differ.
        cagr = (1 + total_return) ** (1 / years) - 1
        return cagr * 100  # As percentage
    except Exception as e:
        # Best-effort boundary: any data-provider failure is reported as
        # "no data" so the chat flow can continue.
        print(f"Error fetching data for {symbol}: {e}")
        return None
def calculate_investment(principal, years, annual_return=0.07):
    """Project the future value of a lump sum under annual compounding.

    Args:
        principal: Amount invested today.
        years: Holding period in years (may be fractional).
        annual_return: Yearly rate of return as a fraction; defaults to 0.07.

    Returns:
        ``principal`` grown by ``(1 + annual_return)`` raised to ``years``.
    """
    growth_factor = (1 + annual_return) ** years
    return principal * growth_factor
# Load the SmolLM3 tokenizer and model once at import time so every chat
# request reuses the same instances.
model_name = "HuggingFaceTB/SmolLM3-3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # Full float32 precision when CUDA is unavailable, bfloat16 otherwise.
    torch_dtype=(torch.float32 if not torch.cuda.is_available() else torch.bfloat16),
    # Let the loader choose device placement (GPU/CPU).
    device_map="auto",
    # Optional: load_in_8bit=True for quantization (requires bitsandbytes)
)
def _extract_symbol(user_query, default='SPY'):
    """Return the known ticker the query refers to, or ``default`` if none."""
    # Every short word is a candidate because the query is upper-cased;
    # ordinary words ("WHAT", "IS") therefore appear alongside real tickers.
    tokens = re.findall(r'\b[A-Z]{1,5}\b', user_query.upper())
    resolved = [(token, find_closest_symbol(token)) for token in tokens]

    # An exact ticker anywhere in the query beats a fuzzy hit on an earlier
    # ordinary word.
    for token, symbol in resolved:
        if symbol == token:
            return symbol
    for _, symbol in resolved:
        if symbol:
            return symbol
    return default


def _extract_principal(user_query):
    """Return the first dollar amount in the query as a float, or ``None``."""
    # Accept thousands separators and cents so "$10,000" is not read as 10.
    match = re.search(r'\$\s*(\d[\d,]*(?:\.\d+)?)', user_query)
    if match is None:
        return None
    return float(match.group(1).replace(',', ''))


def generate_response(user_query, enable_thinking=True):
    """Answer a financial question, grounding the model in computed figures.

    The query is scanned for a ticker symbol, a look-back period and an
    optional dollar amount. The resulting CAGR (and, if an amount was
    given, a 7% projection) is summarised and prepended to the user's
    message before it is sent to the language model.

    Args:
        user_query: The user's free-text question.
        enable_thinking: Forwarded to the tokenizer's chat template.

    Returns:
        The model's decoded reply with surrounding whitespace stripped.
    """
    # Parse query for symbol and period; fall back to SPY when no word in
    # the query resolves to a known ticker.
    symbol = _extract_symbol(user_query)
    period = parse_period(user_query)
    end_date = datetime.now()
    start_date = end_date - period

    # Calculate growth rate if applicable
    growth_rate = calculate_growth_rate(start_date, end_date, symbol)
    if growth_rate is not None:
        summary = f"The CAGR for {symbol} over the period is {growth_rate:.2f}%."
    else:
        summary = f"No data available for {symbol} in the specified period."

    # Handle investment projection
    principal = _extract_principal(user_query)
    if principal is not None:
        years = period.days / 365.25
        projected = calculate_investment(principal, years)
        summary += f" Projecting ${principal} at 7% return over {years:.1f} years: ${projected:.2f}."

    # Prepare prompt for model
    system_prompt = "You are a knowledgeable financial advisor. Provide accurate, concise advice based on the data provided. Disclaimer: This is not professional advice."
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{summary} {user_query}"},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # NOTE(review): with enable_thinking on, the 100-token budget presumably
    # also has to cover any reasoning text the model emits - confirm that
    # replies are not being cut short.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=100,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.2,
        do_sample=True,
    )

    # generate() returns prompt + completion; keep only the new tokens.
    prompt_length = len(model_inputs.input_ids[0])
    output_ids = generated_ids[0][prompt_length:]
    response = tokenizer.decode(output_ids, skip_special_tokens=True)
    return response.strip()
| # Gradio interface | |
def chat(user_input, history):
    """Gradio callback: answer ``user_input`` and extend the chat history.

    Args:
        user_input: Text submitted from the textbox.
        history: Existing list of ``(user, bot)`` message pairs; ``None`` is
            treated as an empty conversation.

    Returns:
        A ``(history, "")`` tuple: the updated history for the chatbot and
        an empty string that clears the textbox.
    """
    # Work on a copy so the caller's list is never mutated, and tolerate a
    # missing history.
    history = list(history) if history else []

    # Skip blank submissions rather than spending a model call on them.
    if not user_input or not user_input.strip():
        return history, ""

    response = generate_response(user_input)
    history.append((user_input, response))
    return history, ""
# Build the UI: a chat window, an input box and a button to reset the chat.
with gr.Blocks() as demo:
    gr.Markdown("# FinChat: AI-Powered Financial Advisor")
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Ask about stocks, investments, etc.")
    clear = gr.Button("Clear")

    # Submitting the textbox runs `chat`, which returns the new history for
    # the chatbot and an empty string for the textbox.
    msg.submit(fn=chat, inputs=[msg, chatbot], outputs=[chatbot, msg])
    # Clearing sends None to the chatbot component, bypassing the queue.
    clear.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)

demo.launch()