# finance-chatbot / utils.py
# Uploaded by Ardaarslan02 in commit 5ccc1f4 ("Upload 5 files", verified); 10.9 kB.
# utils.py
import os
# Point CrewAI's storage, database and credential paths at writable locations
# under /tmp. This runs BEFORE crewai is imported because these variables are
# presumably read at import/initialisation time. The *_ENABLED flags are set to
# "false" — presumably disabling those CrewAI features; confirm the variable
# names against the installed CrewAI version.
os.environ["CREWAI_STORAGE_DIR"] = "/tmp/crewai"
os.environ["CREWAI_TELEMETRY_ENABLED"] = "false"
os.environ["CREWAI_DB_PATH"] = "/tmp/crewai/crewai.db"
os.environ["CREWAI_MEMORY_ENABLED"] = "false"
os.environ["CREWAI_TRACING_ENABLED"] = "false"
os.environ["CREWAI_AUTH_ENABLED"] = "false"
os.environ["CREWAI_CREDENTIALS_PATH"] = "/tmp/crewai/credentials"
os.environ["CREWAI_SECURE_STORAGE_PATH"] = "/tmp/crewai/secure"
# Send all HuggingFace downloads/caches to a single writable directory.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface_cache"
# Create the directories referenced above so nothing fails on a missing parent:
# the HF cache, and /tmp/crewai, which contains the DB file and credential paths.
os.makedirs("/tmp/huggingface_cache", exist_ok=True)
os.makedirs("/tmp/crewai", exist_ok=True)
from dotenv import load_dotenv
from langchain_qdrant import QdrantVectorStore
from langchain_community.embeddings import HuggingFaceEmbeddings
from crewai import Agent, Task, Crew, Process, LLM
import requests
from requests.exceptions import ConnectionError, Timeout, HTTPError
from functools import lru_cache
# Populate os.environ from a .env file so the lookups below can see its values.
load_dotenv()

# Settings. Each lookup yields None when the variable is unset; nothing is
# validated here.
QDRANT_URL = os.environ.get("QDRANT_URL")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
COLLECTION_NAME = "finance-chatbot"  # Qdrant collection queried for RAG context

# LLM and data-provider credentials.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
ALPHA_VANTAGE_API_KEY = os.environ.get("ALPHA_VANTAGE_API_KEY")
SERPER_API_KEY = os.environ.get("SERPER_API_KEY")
# Initialize the embedding model used for Qdrant queries. all-MiniLM-L6-v2 is
# small, but it is still downloaded into cache_folder on first use if it is not
# already cached there.
try:
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        # NOTE(review): this caches under /tmp while HF_HOME points at
        # /tmp/huggingface_cache — confirm which location is intended.
        cache_folder="/tmp",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )
    print("✅ Successfully loaded HuggingFace embeddings")
except Exception as e:
    print(f"❌ HuggingFace embeddings failed: {e}")
    # Fallback: random 384-dim vectors (presumably matching the collection's
    # vector size, which is MiniLM-L6's output size) so Qdrant still accepts
    # queries; retrieved documents will be essentially arbitrary.
    try:
        # langchain_community is already a dependency of this module; the old
        # `langchain.embeddings` path is the legacy re-export.
        from langchain_community.embeddings import FakeEmbeddings

        embeddings = FakeEmbeddings(size=384)
        print("⚠️ Using FakeEmbeddings as fallback - RAG functionality will be limited")
    except Exception as e2:
        print(f"❌ All embedding strategies failed: {e2}")

        class SimpleEmbeddings:
            """Last-resort embedder returning uniform random 384-dim float vectors."""

            def embed_documents(self, texts):
                """Return one random 384-dim vector per input text."""
                import random

                return [[random.random() for _ in range(384)] for _ in texts]

            def embed_query(self, text):
                """Return a single random 384-dim vector (``text`` is ignored)."""
                import random

                return [random.random() for _ in range(384)]

        embeddings = SimpleEmbeddings()
        print("⚠️ Using SimpleEmbeddings as final fallback - RAG functionality will be limited")
# Shared clients, created once when the module is imported.
# Vector store bound to the already-populated Qdrant collection; queries are
# embedded with `embeddings`.
qdrant = QdrantVectorStore.from_existing_collection(
    collection_name=COLLECTION_NAME,
    embedding=embeddings,
    url=QDRANT_URL,
    api_key=QDRANT_API_KEY,
)
# CrewAI LLM handles (both at temperature 0.7).
mistral_llm = LLM(
    model="mistral/mistral-large-latest",
    temperature=0.7,
    api_key=MISTRAL_API_KEY,
)
gemini_llm = LLM(
    model="gemini/gemini-2.0-flash",
    temperature=0.7,
    api_key=GEMINI_API_KEY,
)
# Functions
@lru_cache(maxsize=100)
def _cached_qdrant_hits(query, top_k):
    """Run a similarity search and return an immutable tuple of (text, source) pairs.

    Cached so repeated questions skip the embedding + Qdrant round trip.
    Exceptions propagate; lru_cache does not store them, so failures are retried.
    """
    retriever = qdrant.as_retriever(search_type="similarity", search_kwargs={"k": top_k})
    return tuple(
        (doc.page_content, doc.metadata.get("source", "Unknown"))
        for doc in retriever.invoke(query)
    )


def search_qdrant(query, top_k=3):
    """Search Qdrant for documents relevant to ``query``.

    Args:
        query: Text to search for; also used as a cache key.
        top_k: Maximum number of documents to retrieve.

    Returns:
        A new list of ``{"text": ..., "source": ...}`` dicts ("source" defaults to
        "Unknown"), or an empty list if the search fails for any reason.
    """
    try:
        hits = _cached_qdrant_hits(query, top_k)
    except Exception as exc:
        # Best-effort: callers get no context rather than an error, but the
        # failure is not cached, so the next call with this query retries.
        print(f"⚠️ Qdrant search failed: {exc}")
        return []
    # Fresh dicts on every call so callers can never mutate cached results.
    return [{"text": text, "source": source} for text, source in hits]


# Keep the lru_cache introspection API that search_qdrant exposed previously.
search_qdrant.cache_info = _cached_qdrant_hits.cache_info
search_qdrant.cache_clear = _cached_qdrant_hits.cache_clear
def search_news(query, max_results=5):
    """Look up recent finance news for ``query`` through the Serper search API.

    Args:
        query: Search topic; " finance news" is appended before searching.
        max_results: Number of results requested from Serper, and the cap on
            how many are returned.

    Returns:
        A list of ``{"title", "url", "snippet"}`` dicts. This never raises:
        an empty result set or any failure is reported as a one-element list
        whose entry describes the problem (with an empty "url").
    """
    def _notice(title, snippet):
        # Placeholder/error entries share the result shape but carry no URL.
        return [{"title": title, "url": "", "snippet": snippet}]

    try:
        response = requests.post(
            "https://google.serper.dev/search",
            json={"q": f"{query} finance news", "num": max_results},
            headers={"X-API-KEY": SERPER_API_KEY, "Content-Type": "application/json"},
            timeout=10,
        )
        response.raise_for_status()
        hits = response.json().get("organic", [])
        if not hits:
            return _notice("No recent news available", "Could not fetch news. Please try again later.")
        return [
            {
                "title": hit.get("title", ""),
                "url": hit.get("link", ""),
                "snippet": hit.get("snippet", ""),
            }
            for hit in hits[:max_results]
        ]
    except ConnectionError:
        return _notice("Connection Error", "Failed to connect to the news API. Please check your internet connection.")
    except Timeout:
        return _notice("Timeout Error", "News API request timed out. Please try again later.")
    except HTTPError as err:
        # `response` is bound here: only raise_for_status() raises HTTPError.
        if response.status_code == 429:
            return _notice("Rate Limit Exceeded", "Too many requests to the news API. Please try again later.")
        return _notice("HTTP Error", f"Failed to fetch news due to HTTP error: {err}")
    except Exception:
        return _notice("Error", "An unexpected error occurred while fetching news. Please try again later.")
def get_stock_data(symbol):
    """Fetch the latest quote for ``symbol`` from Alpha Vantage's GLOBAL_QUOTE endpoint.

    Args:
        symbol: Ticker symbol such as "AAPL"; sent as a URL-encoded query parameter.

    Returns:
        On success ``{"symbol", "price", "change", "change_percent"}`` holding the
        values Alpha Vantage returns ("N/A" for any missing field). On failure
        ``{"symbol", "error"}`` with a user-facing message. This never raises.
    """
    try:
        # Let requests encode the query string: the symbol may come from LLM
        # output and must not be able to inject extra parameters.
        response = requests.get(
            "https://www.alphavantage.co/query",
            params={
                "function": "GLOBAL_QUOTE",
                "symbol": symbol,
                "apikey": ALPHA_VANTAGE_API_KEY,
            },
            timeout=10,
        )
        response.raise_for_status()
        payload = response.json()
        # Alpha Vantage reports throttling (and other plan/usage notices) as
        # HTTP 200 with a "Note" or "Information" message rather than a 429.
        # Detect it so it is not misreported as an unknown symbol; the message
        # text itself is not echoed because it can contain the API key.
        if "Note" in payload or "Information" in payload:
            return {"symbol": symbol, "error": "Too many requests to the stock API. Please try again later."}
        data = payload.get("Global Quote", {})
        if not data:
            return {"symbol": symbol, "error": "No data found for this symbol."}
        return {
            "symbol": symbol,
            "price": data.get("05. price", "N/A"),
            "change": data.get("09. change", "N/A"),
            "change_percent": data.get("10. change percent", "N/A"),
        }
    except ConnectionError:
        return {"symbol": symbol, "error": "Failed to connect to the stock API. Please check your internet connection."}
    except Timeout:
        return {"symbol": symbol, "error": "Stock API request timed out. Please try again later."}
    except HTTPError as e:
        status = e.response.status_code if e.response is not None else None
        if status == 429:
            return {"symbol": symbol, "error": "Too many requests to the stock API. Please try again later."}
        # Report only the status code: str(e) embeds the request URL, which
        # contains the API key, and this message is shown to end users.
        return {"symbol": symbol, "error": f"Failed to fetch stock data due to HTTP error: {status}"}
    except Exception:
        return {"symbol": symbol, "error": "An unexpected error occurred while fetching stock data. Please try again later."}
class _ClassifierFallback(Exception):
    """Signal that the LLM classifier failed and carry the fallback answer.

    Raised inside the cached core so lru_cache never memoizes a fallback that
    was caused by a transient LLM/API error or an unparseable reply.
    """

    def __init__(self, result):
        super().__init__(result)
        self.result = result


def _extract_labeled_value(text, label):
    """Return what follows ``label`` on the first line of ``text`` containing it.

    Matching is case-insensitive and ignores Markdown asterisks and backticks,
    so both "Category: stock_analysis" and "**Category:** stock_analysis" yield
    "stock_analysis". Returns None when no line contains ``label``.
    """
    import re

    pattern = re.compile(re.escape(label), re.IGNORECASE)
    for raw_line in text.splitlines():
        line = raw_line.replace("*", "").replace("`", "").strip()
        match = pattern.search(line)
        if match:
            return line[match.end():].strip()
    return None


def _build_classifier_crew(agent, prompt, expected_output):
    """Wrap one classification prompt in a single-task, sequential crew."""
    task = Task(description=prompt, agent=agent, expected_output=expected_output)
    return Crew(agents=[agent], tasks=[task], process=Process.sequential, verbose=False)


def _kickoff_text(crew):
    """Run ``crew`` and return its output as plain text."""
    output = crew.kickoff()
    # Use the .raw text when the result object has it, otherwise str() it.
    return output.raw if hasattr(output, "raw") else str(output)


@lru_cache(maxsize=100)
def _classify_query(query):
    """Two-step LLM classification of ``query``; raises _ClassifierFallback on failure."""
    out_of_scope = ("out_of_scope", "This query is out of scope for a finance assistant.")
    default = ("finance_knowledge", query)
    classifier_agent = Agent(
        role="Query Classifier",
        goal="Classify user queries into appropriate categories, including detecting out-of-scope queries.",
        backstory="An expert in natural language understanding, capable of analyzing queries and categorizing them accurately.",
        llm=mistral_llm,
        verbose=True,
        allow_delegation=False,
    )

    # Step 1: is the query about finance at all?
    finance_check_prompt = f"""
    Analyze the following user query and determine if it is related to finance:
    - Return 'Yes' if the query is related to financial terms, concepts, strategies, market news, or stock analysis (e.g., banking, stocks, revenue, P/E ratio).
    - Return 'No' if the query is unrelated to finance (e.g., cooking recipes, weather, or unrelated topics).
    Query: "{query}"
    Provide your response in this format:
    Is Finance Related: <Yes/No>
    """
    finance_crew = _build_classifier_crew(
        classifier_agent,
        finance_check_prompt,
        "A classification in the format: Is Finance Related: <Yes/No>",
    )
    try:
        answer = _extract_labeled_value(_kickoff_text(finance_crew), "Is Finance Related:")
    except Exception as exc:
        raise _ClassifierFallback(out_of_scope) from exc
    if not answer:
        # Unparseable reply: answer out_of_scope as before, but don't cache it.
        raise _ClassifierFallback(out_of_scope)
    if not answer.lower().startswith("yes"):
        return out_of_scope

    # Step 2: which finance category does it belong to?
    classification_prompt = f"""
    Analyze the following user query and determine its category:
    - finance_knowledge: General questions about financial terms, concepts, or strategies (e.g., 'What is revenue?', 'Explain P/E ratio')
    - market_news: Questions about current market news, trends, or events (e.g., 'Latest news about cryptocurrency market')
    - stock_analysis: Questions about specific stock analysis (e.g., mentioning a stock ticker like AAPL, 'Analyze META stock performance')
    Query: "{query}"
    Provide your response in this format:
    Category: <category>
    Extra Data: <additional info, such as the stock ticker for stock_analysis, or the query itself>
    """
    category_crew = _build_classifier_crew(
        classifier_agent,
        classification_prompt,
        "A classification of the query in the format: Category: <category>\nExtra Data: <additional info>",
    )
    try:
        reply = _kickoff_text(category_crew)
    except Exception as exc:
        raise _ClassifierFallback(default) from exc
    category = _extract_labeled_value(reply, "Category:")
    extra_data = _extract_labeled_value(reply, "Extra Data:")
    if category is not None:
        category = category.strip(" .").lower()
    if category not in ("finance_knowledge", "market_news", "stock_analysis"):
        raise _ClassifierFallback(default)
    if not extra_data:
        if category == "stock_analysis":
            # No ticker to analyse; fall back rather than pass an empty symbol on.
            raise _ClassifierFallback(default)
        # The prompt asks for "the query itself" for the other categories.
        extra_data = query
    return category, extra_data


def determine_question_type(query):
    """Classify ``query`` for routing, using the Mistral LLM via CrewAI tasks.

    Returns:
        ``(category, extra_data)``, where category is one of "out_of_scope",
        "finance_knowledge", "market_news" or "stock_analysis". For out_of_scope,
        extra_data is a user-facing message. For stock_analysis it is what the LLM
        put in "Extra Data" (the prompt asks for the ticker). Otherwise it is that
        field or the query itself.
        If the finance check fails, the result is out_of_scope. If the category
        step fails, the result is ("finance_knowledge", query). Successful
        classifications are cached (up to 100 queries); fallbacks are not.
    """
    try:
        return _classify_query(query)
    except _ClassifierFallback as fallback:
        return fallback.result


# Keep the lru_cache introspection API that determine_question_type exposed previously.
determine_question_type.cache_info = _classify_query.cache_info
determine_question_type.cache_clear = _classify_query.cache_clear