Spaces:
Sleeping
Sleeping
Update finetuned_model.py
Browse files- finetuned_model.py +43 -903
finetuned_model.py
CHANGED
|
@@ -1,920 +1,60 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
import
|
| 4 |
-
import
|
| 5 |
-
import time
|
| 6 |
import torch
|
| 7 |
-
import gradio as gr
|
| 8 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 9 |
-
import pandas as pd
|
| 10 |
-
import re
|
| 11 |
-
import numpy as np
|
| 12 |
-
import json
|
| 13 |
-
import difflib
|
| 14 |
-
from difflib import SequenceMatcher
|
| 15 |
-
from datetime import datetime
|
| 16 |
-
|
| 17 |
-
# Set up logging
|
| 18 |
-
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 19 |
-
logger = logging.getLogger(__name__)
|
| 20 |
-
|
| 21 |
-
# Define device
|
| 22 |
-
device = torch.device("cpu") # Explicitly set to CPU for HF free tier
|
| 23 |
-
logger.info(f"Using device: {device}")
|
| 24 |
-
|
| 25 |
-
# Load dataset
|
| 26 |
-
# Load the price dataset once at import time; `df` stays None when loading
# fails so downstream helpers can degrade gracefully.
csv_path = "stock_data.csv"
try:
    df = pd.read_csv(csv_path)
    # The CSV may carry timestamps with UTC offsets that differ between rows
    # (e.g. DST changes). Parsing those without `utc=True` can yield an
    # object column on which `.dt` fails, so normalise to UTC and then drop
    # the timezone to get a plain datetime64 column.
    # NOTE(review): for exchanges with a positive UTC offset a midnight
    # timestamp would shift to the previous day — confirm only US tickers
    # are stored here.
    df['Date'] = pd.to_datetime(df['Date'], utc=True).dt.tz_localize(None)
    df = df.sort_values('Date')
    logger.info("Loaded dataset successfully")
except Exception as e:
    logger.error(f"Error loading dataset: {e}")
    df = None
|
| 35 |
-
|
| 36 |
-
# Precompute yearly aggregates
|
| 37 |
-
STOCK_SYMBOLS = ["TSLA", "MSFT", "NVDA", "GOOG", "AMZN", "SPY"]
# Column-name prefixes that are averaged per calendar year for each symbol.
_YEARLY_METRICS = ("Price", "Return", "Real_Return", "Dividend", "Earnings", "PE10")

if df is not None:
    # Only aggregate columns that are actually present: asking `agg` for a
    # missing column raises KeyError and would abort the whole module import.
    yearly_spec = {
        f'{metric}_{symbol}': 'mean'
        for metric in _YEARLY_METRICS
        for symbol in STOCK_SYMBOLS
        if f'{metric}_{symbol}' in df.columns
    }
    if yearly_spec:
        df_yearly = df.groupby(df['Date'].dt.year).agg(yearly_spec).reset_index()
        # The group key keeps the name 'Date' although it now holds years.
        df_yearly = df_yearly.rename(columns={'Date': 'Year'})
    else:
        logger.warning("No expected metric columns found in dataset; yearly aggregates unavailable")
        df_yearly = None
else:
    df_yearly = None
|
| 50 |
-
|
| 51 |
-
# Symbol mapping for natural language queries
|
| 52 |
-
symbol_map = {
|
| 53 |
-
"tesla": "TSLA",
|
| 54 |
-
"microsoft": "MSFT",
|
| 55 |
-
"nvidia": "NVDA",
|
| 56 |
-
"google": "GOOG",
|
| 57 |
-
"alphabet": "GOOG",
|
| 58 |
-
"amazon": "AMZN",
|
| 59 |
-
"s&p 500": "SPY",
|
| 60 |
-
"spy": "SPY"
|
| 61 |
-
}
|
| 62 |
-
|
| 63 |
-
# Response cache
|
| 64 |
-
response_cache = {
|
| 65 |
-
"hi": "Hello! I'm FinChat, your financial advisor. How can I help with investing?",
|
| 66 |
-
"hello": "Hello! I'm FinChat, your financial advisor. How can I help with investing?",
|
| 67 |
-
"hey": "Hi there! Ready to discuss investment goals with FinChat?",
|
| 68 |
-
"what is better individual stocks or etfs?": (
|
| 69 |
-
"Here’s a comparison of individual stocks vs. ETFs:\n"
|
| 70 |
-
"1. **Individual Stocks**: High returns possible (e.g., TSLA up ~743% in 2020) but riskier due to lack of diversification. Require active research.\n"
|
| 71 |
-
"2. **ETFs**: Diversify risk by tracking indices (e.g., SPY, S&P 500, ~12% avg. return 2015–2024). Lower fees and less research needed.\n"
|
| 72 |
-
"3. **Recommendation**: Beginners should start with ETFs; experienced investors may add stocks like MSFT or AMZN.\n"
|
| 73 |
-
"Consult a financial planner."
|
| 74 |
-
),
|
| 75 |
-
"is $100 per month enough to invest?": (
|
| 76 |
-
"Yes, $100 per month is enough to start investing. Here’s why and how:\n"
|
| 77 |
-
"1. **Feasibility**: Brokerages like Fidelity have no minimums, and commission-free trading eliminates fees.\n"
|
| 78 |
-
"2. **Options**: Buy fractional shares of ETFs (e.g., SPY, ~$622/share in 2025) or stocks like AMZN with $100.\n"
|
| 79 |
-
"3. **Strategy**: Use dollar-cost averaging to invest monthly, reducing market timing risks.\n"
|
| 80 |
-
"4. **Growth**: At 10% annual return, $100 monthly could grow to ~$41,000 in 20 years.\n"
|
| 81 |
-
"5. **Tips**: Ensure an emergency fund; diversify.\n"
|
| 82 |
-
"Consult a financial planner."
|
| 83 |
-
),
|
| 84 |
-
"can i invest $100 a month?": (
|
| 85 |
-
"Yes, $100 a month is sufficient. Here’s how:\n"
|
| 86 |
-
"1. **Brokerage**: Open an account with Fidelity or Vanguard (no minimums).\n"
|
| 87 |
-
"2. **Investments**: Buy fractional shares of ETFs like SPY ($100 buys ~0.16 shares in 2025) or stocks like GOOG.\n"
|
| 88 |
-
"3. **Approach**: Use dollar-cost averaging for steady growth.\n"
|
| 89 |
-
"4. **Long-Term**: At 10% return, $100 monthly could reach ~$41,000 in 20 years.\n"
|
| 90 |
-
"5. **Tips**: Prioritize an emergency fund and diversify.\n"
|
| 91 |
-
"Consult a financial planner."
|
| 92 |
-
),
|
| 93 |
-
"hi, give me step-by-step investing advice": (
|
| 94 |
-
"Here’s a step-by-step guide to start investing:\n"
|
| 95 |
-
"1. Open a brokerage account (e.g., Fidelity, Vanguard) if 18 or older.\n"
|
| 96 |
-
"2. Deposit an affordable amount, like $100, after an emergency fund.\n"
|
| 97 |
-
"3. Research and buy an ETF (e.g., SPY) or stock (e.g., MSFT) using Yahoo Finance.\n"
|
| 98 |
-
"4. Monitor monthly and enable dividend reinvesting.\n"
|
| 99 |
-
"5. Use dollar-cost averaging ($100 monthly) to reduce risk.\n"
|
| 100 |
-
"6. Diversify across sectors.\n"
|
| 101 |
-
"Consult a financial planner."
|
| 102 |
-
),
|
| 103 |
-
"hi, pretend you are a financial advisor. now tell me how can i start investing in stock market?": (
|
| 104 |
-
"Here’s a guide to start investing:\n"
|
| 105 |
-
"1. Learn from Investopedia or 'The Intelligent Investor.'\n"
|
| 106 |
-
"2. Set goals (e.g., retirement) and assess risk.\n"
|
| 107 |
-
"3. Choose a brokerage (Fidelity, Vanguard).\n"
|
| 108 |
-
"4. Start with ETFs (e.g., SPY) or stocks (e.g., NVDA).\n"
|
| 109 |
-
"5. Use dollar-cost averaging ($100-$500 monthly).\n"
|
| 110 |
-
"6. Diversify and monitor.\n"
|
| 111 |
-
"Consult a financial planner."
|
| 112 |
-
),
|
| 113 |
-
"do you have a list of companies you recommend?": (
|
| 114 |
-
"I can’t recommend specific companies without real-time data. Try ETFs like SPY (S&P 500, ~12% avg. return 2015–2024) or QQQ (Nasdaq-100). "
|
| 115 |
-
"Research stocks like MSFT (~26% avg. return 2015–2024) or AMZN on Yahoo Finance.\n"
|
| 116 |
-
"Consult a financial planner."
|
| 117 |
-
),
|
| 118 |
-
"how do i start investing in stocks?": (
|
| 119 |
-
"Learn from Investopedia. Set goals and assess risk. Open a brokerage account (Fidelity, Vanguard) "
|
| 120 |
-
"and start with ETFs (e.g., SPY, ~12% avg. return 2015–2024) or stocks like GOOG. Consult a financial planner."
|
| 121 |
-
),
|
| 122 |
-
"what's the difference between stocks and bonds?": (
|
| 123 |
-
"Stocks are company ownership with high risk and growth potential (e.g., TSLA ~743% in 2020). Bonds are loans to companies/governments "
|
| 124 |
-
"with lower risk and steady interest. Diversify for balance."
|
| 125 |
-
),
|
| 126 |
-
"how much should i invest?": (
|
| 127 |
-
"Invest what you can afford after expenses and an emergency fund. Start with $100-$500 monthly "
|
| 128 |
-
"in ETFs like SPY (~12% avg. return 2015–2024) or stocks like NVDA. Consult a financial planner."
|
| 129 |
-
),
|
| 130 |
-
"what is dollar-cost averaging?": (
|
| 131 |
-
"Dollar-cost averaging is investing a fixed amount regularly (e.g., $100 monthly) in ETFs like SPY or stocks like AMZN, "
|
| 132 |
-
"reducing risk by spreading purchases over time."
|
| 133 |
-
),
|
| 134 |
-
"give me few investing idea": (
|
| 135 |
-
"Here are investing ideas:\n"
|
| 136 |
-
"1. Open a brokerage account (e.g., Fidelity) if 18 or older.\n"
|
| 137 |
-
"2. Deposit $100 or what you can afford.\n"
|
| 138 |
-
"3. Buy a researched ETF (e.g., SPY, ~12% avg. return 2015–2024) or stock (e.g., MSFT).\n"
|
| 139 |
-
"4. Check regularly and enable dividend reinvesting.\n"
|
| 140 |
-
"5. Use dollar-cost averaging (e.g., monthly buys).\n"
|
| 141 |
-
"Consult a financial planner."
|
| 142 |
-
),
|
| 143 |
-
"give me investing tips": (
|
| 144 |
-
"Here are investing tips:\n"
|
| 145 |
-
"1. Educate yourself with Investopedia or books.\n"
|
| 146 |
-
"2. Open a brokerage account (e.g., Vanguard).\n"
|
| 147 |
-
"3. Start small with ETFs like SPY (~12% avg. return 2015–2024) or stocks like GOOG.\n"
|
| 148 |
-
"4. Invest regularly using dollar-cost averaging.\n"
|
| 149 |
-
"5. Diversify to manage risk.\n"
|
| 150 |
-
"Consult a financial planner."
|
| 151 |
-
),
|
| 152 |
-
"how to start investing": (
|
| 153 |
-
"Here’s how to start investing:\n"
|
| 154 |
-
"1. Study basics on Investopedia.\n"
|
| 155 |
-
"2. Open a brokerage account (e.g., Fidelity).\n"
|
| 156 |
-
"3. Deposit $100 or more after securing savings.\n"
|
| 157 |
-
"4. Buy an ETF like SPY (~12% avg. return 2015–2024) or stock like AMZN after research.\n"
|
| 158 |
-
"5. Invest monthly with dollar-cost averaging.\n"
|
| 159 |
-
"Consult a financial planner."
|
| 160 |
-
),
|
| 161 |
-
"investing advice": (
|
| 162 |
-
"Here’s investing advice:\n"
|
| 163 |
-
"1. Learn basics from Investopedia.\n"
|
| 164 |
-
"2. Open a brokerage account (e.g., Vanguard).\n"
|
| 165 |
-
"3. Start with $100 in an ETF like SPY (~12% avg. return 2015–2024) or stock like NVDA.\n"
|
| 166 |
-
"4. Use dollar-cost averaging for regular investments.\n"
|
| 167 |
-
"5. Monitor and diversify your portfolio.\n"
|
| 168 |
-
"Consult a financial planner."
|
| 169 |
-
),
|
| 170 |
-
"steps to invest": (
|
| 171 |
-
"Here are steps to invest:\n"
|
| 172 |
-
"1. Educate yourself using Investopedia.\n"
|
| 173 |
-
"2. Open a brokerage account (e.g., Fidelity).\n"
|
| 174 |
-
"3. Deposit an initial $100 after savings.\n"
|
| 175 |
-
"4. Buy an ETF like SPY (~12% avg. return 2015–2024) or stock like MSFT after research.\n"
|
| 176 |
-
"5. Use dollar-cost averaging monthly.\n"
|
| 177 |
-
"Consult a financial planner."
|
| 178 |
-
),
|
| 179 |
-
"what is the average growth rate for stocks?": (
|
| 180 |
-
"The average annual return for individual stocks varies widely, but broad market indices like the S&P 500 average 10–12% over the long term (1927–2025), including dividends. "
|
| 181 |
-
"Specific stocks like TSLA or NVDA may have higher volatility and returns. Consult a financial planner."
|
| 182 |
-
)
|
| 183 |
-
}
|
| 184 |
-
|
| 185 |
-
# Load persistent cache
|
| 186 |
-
# Merge previously persisted answers (if any) into the in-memory cache.
cache_file = "cache.json"
if os.path.exists(cache_file):
    try:
        with open(cache_file, 'r') as cache_handle:
            persisted_entries = json.load(cache_handle)
        # Persisted entries win over the built-in ones on key collisions.
        response_cache.update(persisted_entries)
        logger.info("Loaded persistent cache from cache.json")
    except Exception as e:
        logger.warning(f"Failed to load cache.json: {e}")
|
| 194 |
|
| 195 |
# Load model and tokenizer
|
| 196 |
-
model_name = "
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 203 |
-
model_name,
|
| 204 |
-
torch_dtype=torch.float32, # Changed to float32 to avoid Half/Float mismatch
|
| 205 |
-
low_cpu_mem_usage=True
|
| 206 |
-
).to(device)
|
| 207 |
-
logger.info(f"Successfully loaded model: {model_name}")
|
| 208 |
-
except Exception as e:
|
| 209 |
-
logger.error(f"Error loading model/tokenizer: {e}")
|
| 210 |
-
raise RuntimeError(f"Failed to load model: {str(e)}")
|
| 211 |
-
|
| 212 |
-
# Shortened prompt prefix for faster processing
|
| 213 |
-
prompt_prefix = (
|
| 214 |
-
"You are FinChat, a financial advisor. Use provided data for historical returns and calculate accurately. Provide detailed advice with reasoning. Use compound calculations for returns. Avoid invented facts. Keep responses under 100 words.\n\n"
|
| 215 |
-
"Example:\n"
|
| 216 |
-
"Q: What was MSFT return 2010-2020?\n"
|
| 217 |
-
"A: MSFT CAGR ~16.8% from 2010-2020, including dividends. Tech growth drove this; dividends added to totals.\n\n"
|
| 218 |
-
"Q: "
|
| 219 |
)
|
| 220 |
-
prefix_tokens = tokenizer(prompt_prefix, return_tensors="pt", truncation=True, max_length=512).to(device)
|
| 221 |
-
|
| 222 |
-
# Substring matching for cache with enhanced fuzzy matching
|
| 223 |
-
def get_closest_cache_key(message, cache_keys):
    """Return the cache key most similar to ``message``, or ``None``.

    Both the message and every candidate key are lower-cased, stripped and
    have punctuation removed before being compared with
    ``difflib.SequenceMatcher``, so punctuation differences do not lower the
    similarity score.

    Args:
        message: The user text to look up.
        cache_keys: Iterable of candidate keys (may be empty).

    Returns:
        The original (un-normalised) key whose similarity ratio is highest,
        provided that ratio is at least 0.7; otherwise ``None``. On ties the
        first key in iteration order wins.
    """
    def _normalize(text):
        return re.sub(r'[^\w\s]', '', text.lower().strip())

    query = _normalize(message)
    best_key, best_ratio = None, 0.0
    for key in cache_keys:
        ratio = SequenceMatcher(None, query, _normalize(key)).ratio()
        # Strict '>' keeps the first key on ties.
        if ratio > best_ratio:
            best_key, best_ratio = key, ratio
    return best_key if best_ratio >= 0.7 else None
|
| 230 |
-
|
| 231 |
-
# Parse period from user input with expanded regex for better coverage
|
| 232 |
-
def parse_period(query):
    """Extract a ticker symbol and a year range from a free-text question.

    The query is lower-cased and matched against four patterns, in order:
    an explicit range ("return from 2010 to 2020"), a duration anchored at a
    start year ("5-year growth from 2020"), a trailing window ("last 5
    years", ending at the current calendar year) and a single year
    ("return in 2020").

    Args:
        query: The user's question.

    Returns:
        A ``(start_year, end_year, duration, symbol)`` tuple. ``duration`` is
        ``None`` for explicit ranges. When no period is recognised the first
        three items are ``None``. ``symbol`` defaults to ``"SPY"`` when no
        known ticker or company name is found.
    """
    query = query.lower()

    # Resolve the symbol through the shared name -> ticker mapping.
    symbol_match = re.search(
        r'(tsla|msft|nvda|goog|amzn|s&p\s*500|tesla|microsoft|nvidia|google|alphabet|amazon|spy)',
        query,
    )
    symbol_key = symbol_match.group(1) if symbol_match else "spy"
    if symbol_key.startswith("s&p"):
        # The pattern tolerates any spacing ("s&p500", "s&p  500") but the
        # mapping only contains the single-space spelling.
        symbol_key = "s&p 500"
    symbol = symbol_map.get(symbol_key, symbol_key.upper())
    if symbol == "S&P 500":
        symbol = "SPY"

    # Explicit ranges, e.g. "between 2015 and 2020", "from 2010 to 2020".
    match = re.search(
        r'(?:average|growth|performance|return).*?(?:between|from|over|for|through)\s*(\d{4})\s*(?:and|to|-|–|through)\s*(\d{4})',
        query,
    )
    if match:
        start_year, end_year = map(int, match.groups())
        # Inverted ranges fall through to the remaining patterns.
        if start_year <= end_year:
            return start_year, end_year, None, symbol

    # Duration anchored at a start year, e.g. "5-year growth from 2020".
    match = re.search(r'(\d+)-year.*from\s*(\d{4})', query)
    if match:
        duration, start_year = map(int, match.groups())
        # A zero-length duration would yield end_year < start_year.
        if duration > 0:
            return start_year, start_year + duration - 1, duration, symbol

    # Trailing window, e.g. "past 5 years", "over the last 7 years".
    match = re.search(r'(?:past|last)\s*(\d+)\s*years?', query)
    if match:
        duration = int(match.group(1))
        if duration > 0:
            end_year = datetime.now().year
            return end_year - duration + 1, end_year, duration, symbol

    # Single year, e.g. "return in 2020", "performance for 2019".
    match = re.search(r'(?:return|performance)\s*(?:in|for)\s*(\d{4})', query)
    if match:
        year = int(match.group(1))
        return year, year, 1, symbol

    return None, None, None, symbol
|
| 266 |
-
|
| 267 |
-
# Calculate average growth rate using CAGR
|
| 268 |
-
def calculate_growth_rate(start_year, end_year, duration=None, symbol="SPY"):
    """Compute the compound annual growth rate of ``symbol`` over a year range.

    The rate is derived from the first available price in ``start_year`` and
    the last available price in ``end_year``, compounded over
    ``end_year - start_year + 1`` years. Using the endpoints (rather than the
    yearly means) keeps the price span consistent with that exponent and
    gives a meaningful figure for single-year queries, where comparing a
    year's mean with itself would always produce 0%.

    Args:
        start_year: First calendar year of the period (inclusive).
        end_year: Last calendar year of the period (inclusive).
        duration: Optional number of years; only affects the wording of the
            response.
        symbol: Ticker whose ``Price_``/``Dividend_``/``Real_Return_``
            columns are read from the module-level ``df``.

    Returns:
        ``(cagr_percent, message)`` on success, or ``(None, reason)`` when
        the dataset, the period or the required columns are unavailable.
    """
    if df is None or start_year is None or end_year is None:
        return None, "Data not available or invalid period."

    price_col = f'Price_{symbol}'
    dividend_col = f'Dividend_{symbol}'
    real_return_col = f"Real_Return_{symbol}"
    # A symbol absent from the dataset would otherwise raise KeyError below.
    if any(col not in df.columns for col in (price_col, dividend_col, real_return_col)):
        return None, f"No data available for {symbol} from {start_year} to {end_year}."

    years = df['Date'].dt.year
    df_period = df[(years >= start_year) & (years <= end_year)].sort_values('Date')
    if df_period.empty:
        return None, f"No data available for {symbol} from {start_year} to {end_year}."

    period_years = df_period['Date'].dt.year
    start_prices = df_period.loc[period_years == start_year, price_col].dropna()
    end_prices = df_period.loc[period_years == end_year, price_col].dropna()
    if start_prices.empty or end_prices.empty or start_prices.iloc[0] == 0:
        return None, f"Insufficient data for {symbol} from {start_year} to {end_year}."

    initial_price = start_prices.iloc[0]
    final_price = end_prices.iloc[-1]
    # NOTE(review): avg_dividend is the mean of the raw dividend column and
    # is reported below as a percentage — confirm the column's units.
    avg_dividend = df_period[dividend_col].mean()
    avg_real_return = df_period[real_return_col].mean()

    num_years = end_year - start_year + 1
    cagr = ((final_price / initial_price) ** (1 / num_years) - 1) * 100

    symbol_name = "S&P 500" if symbol == "SPY" else symbol
    tail = (
        f", including dividends. Inflation-adjusted real return averaged {avg_real_return:.1f}%. "
        f"Dividends contributed {avg_dividend:.1f}% to total returns."
    )
    if duration == 1 and start_year == end_year:
        head = f"The {symbol_name} returned approximately {cagr:.1f}% in {start_year}"
    elif duration:
        head = (
            f"The {symbol_name} {duration}-year compounded annual growth rate (CAGR) "
            f"from {start_year} to {end_year} was approximately {cagr:.1f}%"
        )
    else:
        head = (
            f"The {symbol_name} compounded annual growth rate (CAGR) "
            f"from {start_year} to {end_year} was approximately {cagr:.1f}%"
        )
    return cagr, head + tail
|
| 290 |
-
|
| 291 |
-
# Parse investment return query
|
| 292 |
-
def parse_investment_query(query):
    """Parse "invest $X for N years in SYMBOL" style questions.

    The amount may contain thousands separators and a decimal part
    ("$1,000", "$2500.50"). The symbol must be one of the supported tickers
    or a spelling of "S&P 500" (any spacing); matching is case-insensitive.

    Args:
        query: The user's question.

    Returns:
        ``(amount, years, symbol)`` with ``amount`` as ``float``, ``years``
        as ``int`` and ``symbol`` as an upper-case ticker ("SPY" for the
        S&P 500), or ``(None, None, None)`` when the query does not match.
    """
    match = re.search(
        r'\$\s*(\d[\d,]*(?:\.\d+)?).*\s(\d+)\s*years?.*\b(tsla|msft|nvda|goog|amzn|spy|s&p\s*500)\b',
        query,
        re.IGNORECASE,
    )
    if not match:
        return None, None, None

    # Drop thousands separators before converting ("1,000" -> 1000.0).
    amount = float(match.group(1).replace(',', ''))
    years = int(match.group(2))
    symbol = match.group(3).upper()
    # Collapse every spacing variant of "S&P 500" to the SPY ticker.
    if re.sub(r'\s+', '', symbol) == "S&P500":
        symbol = "SPY"
    return amount, years, symbol
|
| 302 |
-
|
| 303 |
-
# Calculate future value
|
| 304 |
-
def calculate_future_value(amount, years, symbol, avg_annual_return=10.0):
    """Project a lump-sum investment forward with annual compounding.

    The projection uses a fixed assumed rate and does not read any market
    data, so it no longer refuses to answer when the dataset failed to load.

    Args:
        amount: Initial investment in dollars.
        years: Number of years the investment is held.
        symbol: Ticker, used only for wording ("SPY" is shown as "S&P 500").
        avg_annual_return: Assumed yearly return in percent (default 10).

    Returns:
        ``(future_value, message)``, or
        ``(None, "Data not available or invalid input.")`` when ``amount``
        or ``years`` is ``None``.
    """
    if amount is None or years is None:
        return None, "Data not available or invalid input."

    future_value = amount * (1 + avg_annual_return / 100) ** years
    symbol_name = "S&P 500" if symbol == "SPY" else symbol
    # ':g' renders the default 10.0 as "10", matching the original wording.
    return future_value, (
        f"Assuming a {avg_annual_return:g}% average annual return, a ${amount:,.0f} investment in {symbol_name} would grow to approximately ${future_value:,.0f} "
        f"in {years} years with annual compounding. This is based on the historical average return of 10–12% for stocks. "
        "Future returns vary and are not guaranteed. Consult a financial planner."
    )
|
| 315 |
-
|
| 316 |
-
# Chat function
|
| 317 |
-
def chat_with_model(user_input, history=None, is_processing=False):
    """Answer one chat message and append the exchange to ``history``.

    Resolution order:
      1. exact hit in ``response_cache`` for the normalised message;
      2. lump-sum projection ("$X for N years in SYMBOL");
      3. historical growth-rate question (year range / duration);
      4. fuzzy hit in ``response_cache``;
      5. canned greeting for very short messages (<= 5 characters);
      6. text generated by the language model (then cached).

    The numeric parsers run before the fuzzy cache lookup and their answers
    are not cached: two questions that differ only in amounts, years or
    ticker are textually very similar, so fuzzy-matching them against a
    previously cached answer would return the wrong numbers.

    Args:
        user_input: The message typed by the user.
        history: List of ``{"role", "content"}`` dicts; a non-empty list is
            extended in place, otherwise a new list is created.
        is_processing: Unused; kept for interface compatibility.

    Returns:
        ``(response, history, False, "")`` — the answer text, the updated
        history, the processing flag (always ``False``) and the empty string
        used to clear the input box.
    """
    start_time = time.time()
    history = history or []

    def _finish(response):
        # Log the reply, record the exchange and build the common result.
        logger.info(f"Chatbot response: {response}")
        history.append({"role": "user", "content": user_input})
        history.append({"role": "assistant", "content": response})
        logger.info(f"Response time: {time.time() - start_time:.2f} seconds")
        return response, history, False, ""

    try:
        logger.info(f"Processing user input: {user_input}")

        # Normalise (lower-case, no punctuation) for cache lookups.
        cache_key = re.sub(r'[^\w\s]', '', user_input.lower().strip())
        if cache_key in response_cache:
            logger.info(f"Cache hit for: {cache_key}")
            return _finish(response_cache[cache_key])

        # Lump-sum projection.
        amount, years, symbol = parse_investment_query(user_input)
        if amount and years:
            future_value, response = calculate_future_value(amount, years, symbol)
            if future_value is not None:
                logger.info(f"Investment query: ${amount} for {years} years in {symbol}")
                return _finish(response)

        # Historical period question.
        start_year, end_year, duration, symbol = parse_period(user_input)
        if start_year and end_year:
            avg_return, response = calculate_growth_rate(start_year, end_year, duration, symbol)
            if avg_return is not None:
                logger.info(f"Dynamic period query for {symbol}: {start_year}–{end_year}")
                return _finish(response)

        # Fuzzy match against cached answers.
        closest_key = get_closest_cache_key(cache_key, list(response_cache.keys()))
        if closest_key:
            logger.info(f"Cache hit for: {closest_key}")
            return _finish(response_cache[closest_key])

        # Very short prompts get a canned greeting instead of model output.
        if len(user_input.strip()) <= 5:
            logger.info("Short prompt, returning default response")
            return _finish(
                "Hello! I'm FinChat, your financial advisor. Ask about investing in TSLA, MSFT, NVDA, GOOG, AMZN, or S&P 500!"
            )

        # Fall back to the language model.
        full_prompt = prompt_prefix + user_input + "\nA:"
        try:
            inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
        except Exception as e:
            logger.error(f"Error tokenizing input: {e}")
            return _finish(f"Error: Failed to process input: {str(e)}")

        with torch.inference_mode():
            logger.info("Generating response with model")
            gen_start_time = time.time()
            outputs = model.generate(
                **inputs,
                max_new_tokens=15,  # kept small for speed
                do_sample=False,
                repetition_penalty=2.0,
                pad_token_id=tokenizer.eos_token_id,
            )
            logger.info(f"Generation time: {time.time() - gen_start_time:.2f} seconds")

        # Decode only the newly generated tokens; stripping the prompt by
        # string prefix breaks whenever decoding does not reproduce the
        # prompt text exactly (e.g. after truncation).
        prompt_length = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

        # Cut overly long answers at the last full stop, if there is one;
        # without this guard a text with no '.' would be cut to nothing.
        if len(response.split()) > 100:
            last_period = response.rfind('.')
            if last_period != -1:
                response = response[:last_period + 1]

        response_cache[cache_key] = response
        logger.info("Cache miss, added to in-memory cache")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return _finish(response)

    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return _finish(f"Error: {str(e)}")
|
| 437 |
-
|
| 438 |
-
# Save cache
|
| 439 |
-
def save_cache():
    """Persist ``response_cache`` to ``cache_file`` as indented JSON.

    The cache is serialised first and written to a temporary sibling file
    that then replaces the target, so a serialisation error or an
    interrupted write cannot leave a truncated cache behind. Failures are
    logged as warnings and never raised.
    """
    import os  # local import keeps this block self-contained

    tmp_path = f"{cache_file}.tmp"
    try:
        payload = json.dumps(response_cache, indent=2)
        with open(tmp_path, 'w') as f:
            f.write(payload)
        os.replace(tmp_path, cache_file)
        logger.info("Saved cache to cache.json")
    except Exception as e:
        logger.warning(f"Failed to save cache.json: {e}")
|
| 446 |
-
|
| 447 |
-
# Gradio interface
|
| 448 |
-
logger.info("Initializing Gradio interface")
try:
    with gr.Blocks(
        title="FinChat: An LLM based on distilgpt2 model",
        css="""
        .loader {
            border: 5px solid #f3f3f3;
            border-top: 5px solid #3498db;
            border-radius: 50%;
            width: 30px;
            height: 30px;
            animation: spin 1s linear infinite;
            margin: 10px auto;
            display: block;
        }
        @keyframes spin {
            0% { transform: rotate(0deg); }
            100% { transform: rotate(360deg); }
        }
        .hidden { display: none; }
        """
    ) as interface:
        gr.Markdown(
            """
            # FinChat: An LLM based on distilgpt2 model
            FinChat provides financial advice using the lightweight distilgpt2 model, optimized for fast, detailed responses.
            Ask about investing strategies, ETFs, or stocks like TSLA, MSFT, NVDA, GOOG, AMZN, or S&P 500 to get started!
            """
        )
        chatbot = gr.Chatbot(type="messages")
        msg = gr.Textbox(label="Your message")
        submit = gr.Button("Send")
        clear = gr.Button("Clear")
        loading = gr.HTML('<div class="loader hidden"></div>', label="Loading")
        is_processing = gr.State(value=False)

        def submit_message(user_input, history, is_processing):
            """Run one chat turn and map its result onto the UI outputs."""
            response, updated_history, new_processing, clear_input = chat_with_model(user_input, history, is_processing)
            # The spinner is shown only while the returned flag is truthy.
            loader_html = '<div class="loader"></div>' if new_processing else '<div class="loader hidden"></div>'
            return clear_input, updated_history, loader_html, new_processing

        # The Send button and pressing Enter in the textbox share one handler.
        send_wiring = dict(
            fn=submit_message,
            inputs=[msg, chatbot, is_processing],
            outputs=[msg, chatbot, loading, is_processing],
        )
        submit.click(**send_wiring)
        msg.submit(**send_wiring)

        clear.click(
            fn=lambda: ("", [], '<div class="loader hidden"></div>', False),
            outputs=[msg, chatbot, loading, is_processing]
        )
    logger.info("Gradio interface initialized successfully")
except Exception as e:
    logger.error(f"Error initializing Gradio interface: {e}")
    raise
|
| 502 |
|
| 503 |
-
#
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
finally:
|
| 512 |
-
save_cache()
|
| 513 |
-
else:
|
| 514 |
-
logger.info("Running in Hugging Face Spaces, interface defined but not launched")
|
| 515 |
-
import atexit
|
| 516 |
-
atexit.register(save_cache)
|
| 517 |
-
```
|
| 518 |
-
|
| 519 |
-
```python
|
| 520 |
-
# finetuned_model.py
|
| 521 |
-
import pandas as pd
|
| 522 |
-
import yfinance as yf
|
| 523 |
-
import requests
|
| 524 |
-
from fredapi import Fred
|
| 525 |
-
from datetime import datetime, timedelta
|
| 526 |
-
import numpy as np
|
| 527 |
-
import json
|
| 528 |
-
from datasets import Dataset
|
| 529 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
|
| 530 |
-
import torch
|
| 531 |
-
import logging
|
| 532 |
-
import itertools
|
| 533 |
-
|
| 534 |
-
# Set up logging
|
| 535 |
-
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 536 |
-
logger = logging.getLogger(__name__)
|
| 537 |
-
|
| 538 |
-
# Configuration
|
| 539 |
-
STOCK_SYMBOLS = ["TSLA", "MSFT", "NVDA", "GOOG", "AMZN", "SPY"]
|
| 540 |
-
START_DATE = "2010-01-01" # Expanded range
|
| 541 |
-
END_DATE = datetime.now().strftime("%Y-%m-%d") # Use current date
|
| 542 |
-
FRED_API_KEY = "your_fred_api_key" # Replace with your FRED API key
|
| 543 |
-
OUTPUT_CSV = "stock_data.csv"
|
| 544 |
-
MODEL_NAME = "distilgpt2"
|
| 545 |
-
OUTPUT_DIR = "./finetuned_model"
|
| 546 |
-
CACHE_FILE = "cache.json" # Save QA pairs to cache for faster responses
|
| 547 |
-
|
| 548 |
-
# Initialize FRED API
|
| 549 |
-
try:
|
| 550 |
-
fred = Fred(api_key=FRED_API_KEY)
|
| 551 |
-
logger.info("Initialized FRED API")
|
| 552 |
-
except Exception as e:
|
| 553 |
-
logger.error(f"Error initializing FRED API: {e}")
|
| 554 |
-
fred = None
|
| 555 |
-
|
| 556 |
-
def fetch_cpi_data():
    """Fetch the CPIAUCSL series from FRED for inflation adjustment.

    Returns:
        A DataFrame with a single ``CPI`` column indexed by ``Date``
        (resampled to monthly frequency, last value per month, gaps
        forward-filled), or ``None`` when the FRED client is unavailable or
        the request fails.
    """
    if fred is None:
        logger.warning("FRED API not available; skipping CPI data")
        return None
    try:
        # fredapi names the date bounds observation_start/observation_end.
        cpi = fred.get_series("CPIAUCSL", observation_start=START_DATE, observation_end=END_DATE)
        cpi = cpi.resample("M").last().ffill()
        # to_frame names the column explicitly, independent of Series.name.
        cpi_df = cpi.to_frame(name="CPI")
        cpi_df.index.name = "Date"
        return cpi_df
    except Exception as e:
        logger.error(f"Error fetching CPI data: {e}")
        return None
|
| 570 |
-
|
| 571 |
-
def fetch_stock_data(symbol):
    """Fetch monthly price, dividend and earnings data for ``symbol``.

    Args:
        symbol: Ticker passed to ``yf.Ticker``.

    Returns:
        A DataFrame indexed by timezone-naive dates with the columns
        ``Price_<symbol>``, ``Dividend_<symbol>`` and ``Earnings_<symbol>``
        (a single per-share figure repeated on every row, 0.0 when it cannot
        be derived), or ``None`` when the download fails or is empty.
    """
    try:
        ticker = yf.Ticker(symbol)
        df = ticker.history(start=START_DATE, end=END_DATE, interval="1mo")
        if df.empty:
            raise ValueError(f"No data returned for {symbol}")

        df = df[["Close", "Dividends"]].copy()
        df.rename(columns={"Close": f"Price_{symbol}", "Dividends": f"Dividend_{symbol}"}, inplace=True)
        df.index = pd.to_datetime(df.index)
        # Drop any timezone so the index lines up with the timezone-naive
        # CPI index and serialises to CSV without UTC offsets.
        if df.index.tz is not None:
            df.index = df.index.tz_localize(None)

        try:
            shares = ticker.info.get("sharesOutstanding")
            # Defaulting the share count to 1 would report total net income
            # as a per-share figure, so treat a missing count as "no data".
            if not shares:
                raise ValueError("sharesOutstanding unavailable")
            earnings = ticker.financials.loc["Net Income"].mean() / shares
            df[f"Earnings_{symbol}"] = earnings
        except Exception:
            logger.warning(f"Earnings data unavailable for {symbol}; setting to 0")
            df[f"Earnings_{symbol}"] = 0.0

        return df
    except Exception as e:
        logger.error(f"Error fetching stock data for {symbol}: {e}")
        return None
|
| 594 |
-
|
| 595 |
-
def calculate_pe10(price, earnings):
    """Return ``price / earnings``, or 0.0 when the ratio is not meaningful.

    Despite the name, this is a plain ratio of the two arguments; no
    ten-year averaging happens here.

    Args:
        price: Share price.
        earnings: Per-share earnings figure.

    Returns:
        The ratio as a float. 0.0 is returned when ``earnings`` is zero,
        negative or NaN, when ``price`` is NaN, or when the inputs are not
        numeric.
    """
    try:
        # 'not (x > 0)' is also true for NaN, which compares false to all.
        if not earnings > 0:
            return 0.0
        ratio = price / earnings
    except Exception as e:
        logger.warning(f"Error calculating PE10: {e}")
        return 0.0
    # NaN is the only value not equal to itself.
    return ratio if ratio == ratio else 0.0
|
| 605 |
-
|
| 606 |
-
def adjust_for_inflation(df, cpi_df, symbol):
    """Add an inflation-adjusted ``Real_Price_<symbol>`` column to ``df``.

    Each price is scaled by ``latest_cpi / cpi_at_that_date``, where the CPI
    for a date is the most recent CPI observation at or before it and
    ``latest_cpi`` is the CPI aligned to the last row of ``df``.

    Timezone information on either index is dropped for the alignment, so a
    timezone-aware price index can be matched against a timezone-naive CPI
    index instead of failing and silently falling back to nominal prices.

    Args:
        df: Price frame containing ``Price_<symbol>``; modified in place.
        cpi_df: Frame with a ``CPI`` column on a sorted date index, or
            ``None``.
        symbol: Ticker used to build the column names.

    Returns:
        ``df`` with the new column. When ``cpi_df`` is ``None`` or the
        adjustment fails, the real price equals the nominal price.
    """
    price_col = f"Price_{symbol}"
    real_col = f"Real_Price_{symbol}"

    if cpi_df is None:
        logger.warning(f"CPI data unavailable for {symbol}; Real Price set to Price")
        df[real_col] = df[price_col]
        return df

    def _naive(index):
        # Strip the timezone, if any, keeping the local wall-clock time.
        return index.tz_localize(None) if getattr(index, "tz", None) is not None else index

    try:
        cpi = cpi_df["CPI"].copy()
        cpi.index = _naive(cpi.index)
        aligned = cpi.reindex(_naive(df.index), method="ffill")
        latest_cpi = aligned.iloc[-1]
        # Positional arithmetic: `aligned` has exactly one entry per row.
        df[real_col] = df[price_col].to_numpy() * (latest_cpi / aligned.to_numpy())
        return df
    except Exception as e:
        logger.error(f"Error adjusting for inflation for {symbol}: {e}")
        df[real_col] = df[price_col]
        return df
|
| 622 |
-
|
| 623 |
-
def create_dataset(symbols):
    """Assemble one wide frame with the derived columns of every fetchable symbol.

    For each symbol the fetched frame gains ``Real_Price_``, ``Return_``,
    ``Real_Return_`` and ``PE10_`` columns; symbols whose fetch returns
    ``None`` or an empty frame are skipped.

    Args:
        symbols: Iterable of ticker symbols.

    Returns:
        The outer join of all per-symbol frames with the index turned back
        into a column, or ``None`` when no symbol produced data.
    """
    cpi_df = fetch_cpi_data()
    frames = []

    for symbol in symbols:
        logger.info(f"Fetching data for {symbol}")
        frame = fetch_stock_data(symbol)
        if frame is None or frame.empty:
            logger.error(f"Skipping {symbol} due to data fetch failure")
            continue

        frame = adjust_for_inflation(frame, cpi_df, symbol)

        price_col = f"Price_{symbol}"
        earnings_col = f"Earnings_{symbol}"
        return_col = f"Return_{symbol}"
        real_return_col = f"Real_Return_{symbol}"
        pe10_col = f"PE10_{symbol}"

        # Percentage change over 12 rows (presumably one year of monthly rows; verify against fetch_stock_data).
        frame[return_col] = frame[price_col].pct_change(12) * 100
        frame[real_return_col] = frame[f"Real_Price_{symbol}"].pct_change(12) * 100
        frame[pe10_col] = frame.apply(
            lambda row: calculate_pe10(row[price_col], row[earnings_col]),
            axis=1,
        )

        # Missing values in the derived/auxiliary columns become 0.0; prices are left untouched.
        fill_cols = [return_col, real_return_col, f"Dividend_{symbol}", earnings_col, pe10_col]
        frame[fill_cols] = frame[fill_cols].fillna(0.0)

        frames.append(frame)

    if not frames:
        logger.error("No data fetched for any symbol")
        return None

    combined_df = frames[0]
    for extra in frames[1:]:
        combined_df = combined_df.join(extra, how="outer")

    combined_df.reset_index(inplace=True)
    return combined_df
|
| 655 |
-
|
| 656 |
-
def save_dataset(df, output_path):
    """Write ``df`` to ``output_path`` as CSV without the index.

    Does nothing when ``df`` is ``None``. A failure while writing is logged
    and not re-raised.
    """
    if df is None:
        return

    try:
        df.to_csv(output_path, index=False)
        logger.info(f"Dataset saved to {output_path}")
    except Exception as exc:
        logger.error(f"Error saving dataset: {exc}")
|
| 664 |
-
|
| 665 |
-
# Step 1: Create and Save Dataset
logger.info(f"Creating dataset for {STOCK_SYMBOLS}")
df = create_dataset(STOCK_SYMBOLS)
if df is None:
    logger.error("Dataset creation failed")
    # Abort with a non-zero status. `exit()` is an interactive helper injected
    # by the `site` module and would report success (status 0) to the caller.
    raise SystemExit(1)
save_dataset(df, OUTPUT_CSV)

# Step 2: Preprocess Dataset for Training
df['Date'] = pd.to_datetime(df['Date'])
# Average every numeric column per calendar year. `numeric_only=True` keeps the
# datetime 'Date' column out of the aggregate; otherwise the result carries both
# an index and a column named 'Date' and `reset_index()` cannot insert the key.
df_yearly = df.groupby(df['Date'].dt.year).mean(numeric_only=True).reset_index()
# After reset_index the group key (the year) is the column called 'Date'.
df_yearly = df_yearly.rename(columns={'Date': 'Year'})
|
| 677 |
-
|
| 678 |
-
# Step 3: Create Question-Answer Pairs with enhancements
#
# Builds `qa_pairs`, a list of {"question", "answer"} records (plus one
# {"summary"} record per symbol/year) from the yearly averages in `df_yearly`.
# NOTE(review): every CAGR below is derived from the Price_<symbol> columns
# only; confirm those prices are dividend-adjusted before relying on the
# "including dividends" wording in the answers.


def _cagr_percent(initial_price, final_price, periods):
    """Return the compound annual growth rate in percent, or 0.0 when the start price is not positive."""
    if initial_price > 0:
        return ((final_price / initial_price) ** (1 / periods) - 1) * 100
    return 0.0


def _yearly_price(frame, year, symbol):
    """Return the mean Price_<symbol> of the rows of `frame` whose Year equals `year` (NaN when none match)."""
    return frame[frame['Year'] == year][f'Price_{symbol}'].mean()


qa_pairs = []
years = df_yearly['Year'].unique()
min_year = int(years.min())
max_year = int(years.max())

for symbol in STOCK_SYMBOLS:
    # Bound once per symbol so the period loops below never depend on the row
    # loop having run at least once.
    symbol_name = "S&P 500" if symbol == "SPY" else symbol

    # Per-year questions
    for _, row in df_yearly.iterrows():
        year = int(row['Year'])
        price = row.get(f"Price_{symbol}", 0.0)
        dividend = row.get(f"Dividend_{symbol}", 0.0)
        earnings = row.get(f"Earnings_{symbol}", 0.0)
        return_val = row.get(f"Return_{symbol}", 0.0)
        real_return = row.get(f"Real_Return_{symbol}", 0.0)
        pe10 = row.get(f"PE10_{symbol}", 0.0)

        qa_pairs.append({
            "question": f"What was the {symbol_name} return in {year}?",
            "answer": f"The {symbol_name} returned approximately {return_val:.1f}% in {year}, including dividends."
        })
        qa_pairs.append({
            "question": f"What was the {symbol_name} price in {year}?",
            "answer": f"The {symbol_name} averaged approximately {price:.2f} in {year}."
        })
        qa_pairs.append({
            "question": f"What was the {symbol_name} real return in {year}?",
            "answer": f"The {symbol_name} inflation-adjusted return was approximately {real_return:.1f}% in {year}."
        })
        # Dividend / earnings / PE10 questions are only emitted for positive values.
        if dividend > 0:
            qa_pairs.append({
                "question": f"What was the {symbol_name} dividend in {year}?",
                "answer": f"The {symbol_name} dividend was approximately {dividend:.2f} in {year}."
            })
        if earnings > 0:
            qa_pairs.append({
                "question": f"What were the {symbol_name} earnings in {year}?",
                "answer": f"The {symbol_name} earnings were approximately {earnings:.2f} in {year}."
            })
        if pe10 > 0:
            qa_pairs.append({
                "question": f"What was the {symbol_name} PE10 ratio in {year}?",
                "answer": f"The {symbol_name} PE10 ratio was approximately {pe10:.2f} in {year}."
            })
        qa_pairs.append({
            "summary": f"In {year}, the {symbol_name} averaged {price:.2f} with a {return_val:.1f}% annual return and a {real_return:.1f}% real return."
        })

    # Period-specific questions with CAGR
    for start_year, end_year in itertools.combinations(years, 2):
        if not start_year < end_year:
            continue
        df_period = df_yearly[(df_yearly['Year'] >= start_year) & (df_yearly['Year'] <= end_year)]
        if df_period.empty:
            continue
        initial_price = _yearly_price(df_period, start_year, symbol)
        final_price = _yearly_price(df_period, end_year, symbol)
        # num_years counts both endpoint years (end - start + 1).
        num_years = end_year - start_year + 1
        cagr = _cagr_percent(initial_price, final_price, num_years)
        avg_real_return = df_period[f"Real_Return_{symbol}"].mean()
        qa_pairs.append({
            "question": f"What was the average annual growth rate of {symbol_name} between {start_year} and {end_year}?",
            "answer": f"The {symbol_name} compounded annual growth rate (CAGR) from {start_year} to {end_year} was approximately {cagr:.1f}%, including dividends. This accounts for compounding, unlike simple averages, and includes market volatility risks."
        })
        qa_pairs.append({
            "question": f"What was the average annual return of {symbol_name} between {start_year} and {end_year}?",
            "answer": f"The {symbol_name} compounded annual growth rate (CAGR) from {start_year} to {end_year} was approximately {cagr:.1f}%, including dividends."
        })
        qa_pairs.append({
            "question": f"What was the {symbol_name} real return between {start_year} and {end_year}?",
            "answer": f"The {symbol_name} average annual inflation-adjusted return from {start_year} to {end_year} was approximately {avg_real_return:.1f}%."
        })
        qa_pairs.append({
            "question": f"What was the {num_years}-year average annual growth rate of {symbol_name} from {start_year}?",
            "answer": f"The {symbol_name} {num_years}-year compounded annual growth rate (CAGR) from {start_year} to {end_year} was approximately {cagr:.1f}%, including dividends."
        })
        qa_pairs.append({
            "question": f"What was the inflation-adjusted return for {symbol_name} from {start_year} to {end_year}?",
            "answer": f"The {symbol_name} average annual inflation-adjusted return from {start_year} to {end_year} was approximately {avg_real_return:.1f}%. This matters in high-inflation periods to reflect true purchasing power."
        })
        qa_pairs.append({
            "question": f"Explain the return for {symbol_name} between {start_year} and {end_year}",
            "answer": f"The {symbol_name} compounded annual growth rate (CAGR) from {start_year} to {end_year} was approximately {cagr:.1f}%, including dividends. Compared to S&P 500's 10–12% average, this shows relative performance but with greater volatility."
        })

    # Past X years questions with more variations to reduce hallucinations
    for duration in range(1, max_year - min_year + 2):
        for end_year in years:
            start_year = end_year - duration + 1
            if start_year < min_year:
                continue
            df_period = df_yearly[(df_yearly['Year'] >= start_year) & (df_yearly['Year'] <= end_year)]
            if df_period.empty:
                continue
            initial_price = _yearly_price(df_period, start_year, symbol)
            final_price = _yearly_price(df_period, end_year, symbol)
            num_years = end_year - start_year + 1
            cagr = _cagr_percent(initial_price, final_price, num_years)
            avg_real_return = df_period[f"Real_Return_{symbol}"].mean()
            qa_pairs.append({
                "question": f"What was the average annual growth rate of {symbol_name} in the past {duration} years from {end_year}?",
                "answer": f"The {symbol_name} compounded annual growth rate (CAGR) from {start_year} to {end_year} was approximately {cagr:.1f}%, including dividends."
            })
            qa_pairs.append({
                "question": f"What was the {duration}-year average annual growth rate of {symbol_name} ending in {end_year}?",
                "answer": f"The {symbol_name} {duration}-year compounded annual growth rate (CAGR) from {start_year} to {end_year} was approximately {cagr:.1f}%, including dividends."
            })
            qa_pairs.append({
                "question": f"What is the average return of {symbol_name} over the last {duration} years?",
                "answer": f"The average annual return of {symbol_name} from {start_year} to {end_year} was approximately {cagr:.1f}%, including dividends."
            })
            qa_pairs.append({
                "question": f"What was {symbol_name}'s performance in the past {duration} years?",
                "answer": f"{symbol_name} had a compounded annual growth rate (CAGR) of approximately {cagr:.1f}% from {start_year} to {end_year}, including dividends."
            })
            qa_pairs.append({
                "question": f"Calculate the average annual return for {symbol_name} in the last {duration} years.",
                "answer": f"The calculated compounded annual growth rate (CAGR) for {symbol_name} from {start_year} to {end_year} is approximately {cagr:.1f}%, including dividends."
            })
            qa_pairs.append({
                "question": f"What was {symbol_name}'s volatility in the past {duration} years?",
                "answer": f"{symbol_name}'s returns from {start_year} to {end_year} show high volatility typical of tech stocks; CAGR was {cagr:.1f}%, but diversify to mitigate risks."
            })

# Investment return questions
# These projections use one fixed assumed rate for every symbol, not the fetched data.
amounts = [1000, 5000, 10000]
durations = [1, 3, 5, 10, 20]
avg_annual_return = 10.0
for symbol in STOCK_SYMBOLS:
    symbol_name = "S&P 500" if symbol == "SPY" else symbol
    for amount in amounts:
        for n in durations:
            future_value = amount * (1 + avg_annual_return / 100) ** n
            qa_pairs.append({
                "question": f"What will ${amount} be worth in {n} years if invested in {symbol_name}?",
                "answer": f"Assuming a 10% average annual return, ${amount:,.0f} invested in {symbol_name} would grow to approximately ${future_value:,.0f} in {n} years with annual compounding."
            })

# General questions with nuances
# Look-back windows measured back from the latest year; `wording` preserves the
# original "past"/"last" phrasing of each question.
for symbol in STOCK_SYMBOLS:
    symbol_name = "S&P 500" if symbol == "SPY" else symbol
    for span, wording in ((10, "past"), (5, "last"), (7, "past")):
        first_year = max_year - span
        df_window = df_yearly[(df_yearly['Year'] >= first_year) & (df_yearly['Year'] <= max_year)]
        initial_price = _yearly_price(df_window, first_year, symbol)
        final_price = _yearly_price(df_window, max_year, symbol)
        cagr_span = _cagr_percent(initial_price, final_price, span)
        qa_pairs.append({
            "question": f"What is the average return rate of {symbol_name} in the {wording} {span} years?",
            "answer": f"The {symbol_name} compounded annual growth rate (CAGR) from {first_year} to {max_year} was approximately {cagr_span:.1f}%, including dividends."
        })

# This pair does not depend on the symbol, so it is added once rather than once
# per symbol (which put identical rows into the training data).
qa_pairs.append({
    "question": "What is the average growth rate for stocks?",
    "answer": "The average annual return for individual stocks varies widely, but broad market indices like the S&P 500 average 10–12% over the long term (1927–2025), including dividends. Specific stocks like TSLA or NVDA may have higher volatility and returns."
})
|
| 844 |
-
|
| 845 |
-
# Save QA pairs to cache.json for pre-populated cache
# Only records carrying both a question and an answer are cached (summary-only
# records are skipped); keys are the lower-cased question text, so a later
# duplicate question overwrites an earlier one.
cache_dict = {}
for pair in qa_pairs:
    if "question" in pair and "answer" in pair:
        cache_dict[pair["question"].lower()] = pair["answer"]

# A cache write failure is logged and does not stop the run.
try:
    with open(CACHE_FILE, 'w') as cache_fh:
        json.dump(cache_dict, cache_fh, indent=2)
    logger.info(f"Saved {len(cache_dict)} QA pairs to {CACHE_FILE} for caching")
except Exception as exc:
    logger.warning(f"Failed to save {CACHE_FILE}: {exc}")

# Save to JSON for dataset
# The full record list (including summaries) is what the training step loads;
# a failure here is allowed to propagate.
with open("financial_data.json", "w") as dataset_fh:
    json.dump(qa_pairs, dataset_fh, indent=2)
|
| 857 |
-
|
| 858 |
-
# Step 4: Load and Tokenize Dataset
dataset = Dataset.from_json("financial_data.json")
# 80% train; the remaining 20% is halved into validation and test.
dataset = dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = dataset["train"]
# The seeded split of the holdout is deterministic, so it is computed once and
# both halves are taken from the same result.
holdout_split = dataset["test"].train_test_split(test_size=0.5, seed=42)
val_dataset = holdout_split["train"]
test_dataset = holdout_split["test"]

# Step 5: Load Model and Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Fall back to the end-of-sequence token when the tokenizer defines no padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
|
| 868 |
-
|
| 869 |
-
def tokenize_function(examples):
    """Tokenize one batch of QA/summary records into fixed-length inputs.

    Each row becomes ``"<question> A: <answer>"`` when both fields are
    present, otherwise its ``summary``. A row with neither is tokenized as an
    empty string so the output keeps exactly one entry per input row, which a
    batched ``map`` that retains the original columns requires.

    Args:
        examples: Batch mapping with optional ``question``, ``answer`` and
            ``summary`` columns (lists of equal length).

    Returns:
        The tokenizer output for the batch, padded/truncated to 512 tokens.
    """
    columns = [examples.get(key) for key in ("question", "answer", "summary")]
    # A missing column must not shorten the batch: zip() would otherwise stop
    # at the empty default and silently produce no inputs at all.
    batch_size = max((len(col) for col in columns if col is not None), default=0)
    questions, answers, summaries = (
        col if col is not None else [None] * batch_size for col in columns
    )

    inputs = []
    for question, answer, summary in zip(questions, answers, summaries):
        if question and answer:
            inputs.append(question + " A: " + answer)
        elif summary:
            inputs.append(summary)
        else:
            # Keep row alignment with the input batch instead of dropping the row.
            inputs.append("")
    return tokenizer(inputs, padding="max_length", truncation=True, max_length=512)
|
| 877 |
-
|
| 878 |
-
# Apply the batch tokenizer to each split in the same order as before.
tokenized_train, tokenized_val, tokenized_test = (
    split.map(tokenize_function, batched=True)
    for split in (train_dataset, val_dataset, test_dataset)
)

# Step 6: Load and Fine-Tune Model
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,  # Changed to float32 to avoid Half/Float issues during training
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 888 |
|
|
|
|
| 889 |
training_args = TrainingArguments(
|
| 890 |
-
output_dir=
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
|
|
|
|
|
|
|
|
|
| 896 |
weight_decay=0.01,
|
| 897 |
-
|
| 898 |
-
save_strategy="epoch",
|
| 899 |
-
load_best_model_at_end=True,
|
| 900 |
-
metric_for_best_model="eval_loss",
|
| 901 |
-
fp16=False # Disabled fp16 for CPU compatibility
|
| 902 |
)
|
| 903 |
|
| 904 |
-
|
|
|
|
| 905 |
model=model,
|
| 906 |
args=training_args,
|
| 907 |
-
train_dataset=
|
| 908 |
-
|
|
|
|
|
|
|
|
|
|
| 909 |
)
|
| 910 |
|
| 911 |
-
# Step 7: Train and Evaluate
|
| 912 |
trainer.train()
|
| 913 |
-
eval_results = trainer.evaluate(tokenized_test)
|
| 914 |
-
logger.info(f"Evaluation results: {eval_results}")
|
| 915 |
|
| 916 |
-
#
|
| 917 |
-
trainer.
|
| 918 |
-
tokenizer.save_pretrained(
|
| 919 |
-
logger.info(f"Model and tokenizer saved to {OUTPUT_DIR}")
|
| 920 |
-
```
|
|
|
|
| 1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
|
| 2 |
+
from peft import LoraConfig, get_peft_model
|
| 3 |
+
from trl import SFTTrainer
|
| 4 |
+
from datasets import load_dataset
|
|
|
|
| 5 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# Load model and tokenizer
model_name = "HuggingFaceTB/SmolLM3-3B"

# Decide precision once so the weight dtype, the mixed-precision flags and the
# optimizer choice stay consistent with each other.
use_cuda = torch.cuda.is_available()
use_bf16 = use_cuda and torch.cuda.is_bf16_supported()

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Padding during training needs a pad token; fall back to EOS if none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    # bfloat16 only where the GPU supports it; float32 otherwise (CPU, or GPUs
    # that will train with fp16 autocast below).
    torch_dtype=torch.bfloat16 if use_bf16 else torch.float32,
)

# Prepare PEFT config for efficient fine-tuning
# The config is handed to SFTTrainer below, which applies the LoRA adapters
# itself; the model is not wrapped here as well, so it is adapted only once.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Load dataset (example: assume 'financial_data.jsonl' with {'text': 'query ||| response'} format)
dataset = load_dataset("json", data_files="financial_data.jsonl", split="train")

# Training arguments
training_args = TrainingArguments(
    output_dir="./finetuned_smollm3",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    # Match mixed precision to the loaded weights: bf16 autocast for bf16
    # weights, fp16 autocast only for float32 weights on a GPU without bf16.
    bf16=use_bf16,
    fp16=use_cuda and not use_bf16,
    save_steps=500,
    logging_steps=100,
    # The paged 8-bit optimizer comes from bitsandbytes and needs a GPU; use
    # the standard PyTorch AdamW when running on CPU.
    optim="paged_adamw_8bit" if use_cuda else "adamw_torch",
    weight_decay=0.01,
    warmup_steps=100,
)

# Trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",  # Adjust based on your dataset
    tokenizer=tokenizer,
    max_seq_length=512,
)

trainer.train()

# Save fine-tuned model
# trainer.model is the model as prepared by the trainer (with the LoRA adapters).
trainer.model.save_pretrained("./finetuned_smollm3")
tokenizer.save_pretrained("./finetuned_smollm3")
|
|
|
|
|
|