AnilNiraula committed on
Commit
e2b2a4b
·
verified ·
1 Parent(s): dbf541c

Update app.py: load SmolLM-135M-Instruct without 8-bit BitsAndBytes quantization

Browse files
Files changed (1) hide show
  1. app.py +8 -24
app.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
3
  import re
4
  from datetime import datetime, timedelta
5
  import difflib
6
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
7
  import torch
8
  import yfinance as yf
9
  from functools import lru_cache
@@ -36,7 +36,7 @@ def parse_period(query):
36
  return timedelta(weeks=num)
37
  elif unit == 'day':
38
  return timedelta(days=num)
39
- return timedelta(days=365) # Default to 1 year
40
 
41
  def find_closest_symbol(input_symbol):
42
  input_symbol = input_symbol.upper()
@@ -55,21 +55,16 @@ def calculate_growth_rate(start_date, end_date, symbol):
55
  if years == 0:
56
  return 0
57
  cagr = (1 + total_return) ** (1 / years) - 1
58
- return cagr * 100 # As percentage
59
 
60
  def calculate_investment(principal, years, annual_return=0.07):
61
  return principal * (1 + annual_return) ** years
62
 
63
- # Load SmolLM-135M-Instruct with 8-bit quantization
64
  model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
65
- quantization_config = BitsAndBytesConfig(
66
- load_in_8bit=True, # Switch to 8-bit for faster inference
67
- bnb_8bit_compute_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
68
- )
69
  tokenizer = AutoTokenizer.from_pretrained(model_name)
70
  model = AutoModelForCausalLM.from_pretrained(
71
  model_name,
72
- quantization_config=quantization_config,
73
  device_map="auto",
74
  )
75
 
@@ -78,7 +73,6 @@ def generate_response(user_query, enable_thinking=False):
78
  stock_keywords = ['stock', 'growth', 'investment', 'price', 'return', 'cagr']
79
  is_stock_query = any(keyword in user_query.lower() for keyword in stock_keywords)
80
  summary = ""
81
-
82
  if is_stock_query:
83
  # Parse query for symbol and period
84
  symbol_match = re.search(r'\b([A-Z]{1,5})\b', user_query.upper())
@@ -87,22 +81,19 @@ def generate_response(user_query, enable_thinking=False):
87
  period = parse_period(user_query)
88
  end_date = datetime.now()
89
  start_date = end_date - period
90
-
91
  # Calculate growth rate
92
  growth_rate = calculate_growth_rate(start_date, end_date, symbol)
93
  if growth_rate is not None:
94
  summary = f"The CAGR for {symbol} over the period is {growth_rate:.2f}%."
95
  else:
96
  summary = f"No data available for {symbol} in the specified period."
97
-
98
  # Handle investment projection
99
  investment_match = re.search(r'\$(\d+)', user_query)
100
  if investment_match:
101
  principal = float(investment_match.group(1))
102
  years = period.days / 365.25
103
  projected = calculate_investment(principal, years)
104
- summary += f" Projecting ${principal} at 7% return over {years:.1f} years: ${projected:.2f}."
105
-
106
  # Prepare prompt
107
  system_prompt = (
108
  "You are FinChat, a knowledgeable financial advisor. Always respond in a friendly, professional manner. "
@@ -112,44 +103,37 @@ def generate_response(user_query, enable_thinking=False):
112
  )
113
  messages = [
114
  {"role": "system", "content": system_prompt},
115
- {"role": "user", "content": f"{summary} {user_query}" if summary else user_query}
116
  ]
117
-
118
  text = tokenizer.apply_chat_template(
119
  messages,
120
  tokenize=False,
121
  add_generation_prompt=True,
122
  enable_thinking=enable_thinking
123
  )
124
-
125
  model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
126
  generated_ids = model.generate(
127
  **model_inputs,
128
- max_new_tokens=30, # Reduced for speed
129
  temperature=0.6,
130
  top_p=0.95,
131
  repetition_penalty=1.0,
132
  do_sample=False,
133
- early_stopping=True # Stop early for efficiency
134
  )
135
-
136
  output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
137
  response = tokenizer.decode(output_ids, skip_special_tokens=True)
138
  return response.strip()
139
-
140
  # Gradio interface
141
  def chat(user_input, history):
142
  response = generate_response(user_input)
143
  history.append((user_input, response))
144
  return history, ""
145
-
146
  with gr.Blocks() as demo:
147
  gr.Markdown("# FinChat: AI-Powered Financial Advisor")
148
  chatbot = gr.Chatbot()
149
  msg = gr.Textbox(placeholder="Ask about stocks, investments, etc.")
150
  clear = gr.Button("Clear")
151
-
152
  msg.submit(chat, [msg, chatbot], [chatbot, msg])
153
  clear.click(lambda: None, None, chatbot, queue=False)
154
-
155
  demo.launch()
 
3
  import re
4
  from datetime import datetime, timedelta
5
  import difflib
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
  import torch
8
  import yfinance as yf
9
  from functools import lru_cache
 
36
  return timedelta(weeks=num)
37
  elif unit == 'day':
38
  return timedelta(days=num)
39
+ return timedelta(days=365) # Default to 1 year
40
 
41
  def find_closest_symbol(input_symbol):
42
  input_symbol = input_symbol.upper()
 
55
  if years == 0:
56
  return 0
57
  cagr = (1 + total_return) ** (1 / years) - 1
58
+ return cagr * 100 # As percentage
59
 
60
  def calculate_investment(principal, years, annual_return=0.07):
61
  return principal * (1 + annual_return) ** years
62
 
63
+ # Load SmolLM-135M-Instruct without quantization
64
  model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
 
 
 
 
65
  tokenizer = AutoTokenizer.from_pretrained(model_name)
66
  model = AutoModelForCausalLM.from_pretrained(
67
  model_name,
 
68
  device_map="auto",
69
  )
70
 
 
73
  stock_keywords = ['stock', 'growth', 'investment', 'price', 'return', 'cagr']
74
  is_stock_query = any(keyword in user_query.lower() for keyword in stock_keywords)
75
  summary = ""
 
76
  if is_stock_query:
77
  # Parse query for symbol and period
78
  symbol_match = re.search(r'\b([A-Z]{1,5})\b', user_query.upper())
 
81
  period = parse_period(user_query)
82
  end_date = datetime.now()
83
  start_date = end_date - period
 
84
  # Calculate growth rate
85
  growth_rate = calculate_growth_rate(start_date, end_date, symbol)
86
  if growth_rate is not None:
87
  summary = f"The CAGR for {symbol} over the period is {growth_rate:.2f}%."
88
  else:
89
  summary = f"No data available for {symbol} in the specified period."
 
90
  # Handle investment projection
91
  investment_match = re.search(r'\$(\d+)', user_query)
92
  if investment_match:
93
  principal = float(investment_match.group(1))
94
  years = period.days / 365.25
95
  projected = calculate_investment(principal, years)
96
+ summary += f" Projecting $ {principal} at 7% return over {years: .1f} years: $ {projected: .2f}."
 
97
  # Prepare prompt
98
  system_prompt = (
99
  "You are FinChat, a knowledgeable financial advisor. Always respond in a friendly, professional manner. "
 
103
  )
104
  messages = [
105
  {"role": "system", "content": system_prompt},
106
+ {"role": "user", "content": f" {summary} {user_query}" if summary else user_query}
107
  ]
 
108
  text = tokenizer.apply_chat_template(
109
  messages,
110
  tokenize=False,
111
  add_generation_prompt=True,
112
  enable_thinking=enable_thinking
113
  )
 
114
  model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
115
  generated_ids = model.generate(
116
  **model_inputs,
117
+ max_new_tokens=30, # Reduced for speed
118
  temperature=0.6,
119
  top_p=0.95,
120
  repetition_penalty=1.0,
121
  do_sample=False,
122
+ early_stopping=True # Stop early for efficiency
123
  )
 
124
  output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
125
  response = tokenizer.decode(output_ids, skip_special_tokens=True)
126
  return response.strip()
 
127
  # Gradio interface
128
  def chat(user_input, history):
129
  response = generate_response(user_input)
130
  history.append((user_input, response))
131
  return history, ""
 
132
  with gr.Blocks() as demo:
133
  gr.Markdown("# FinChat: AI-Powered Financial Advisor")
134
  chatbot = gr.Chatbot()
135
  msg = gr.Textbox(placeholder="Ask about stocks, investments, etc.")
136
  clear = gr.Button("Clear")
 
137
  msg.submit(chat, [msg, chatbot], [chatbot, msg])
138
  clear.click(lambda: None, None, chatbot, queue=False)
 
139
  demo.launch()