Spaces:
Runtime error
Runtime error
# utils.py
"""Helpers for the finance chatbot: knowledge-base search, news lookup,
stock quotes and LLM-based query routing."""
import os
import tempfile
from functools import lru_cache

# Point CrewAI and the Hugging Face libraries at writable directories BEFORE any
# third-party import. Some libraries read these variables when they are first
# imported, so assigning them after the imports may have no effect.
os.environ["CREWAI_STORAGE_DIR"] = "/tmp/crewai"
os.environ["CREWAI_TELEMETRY_ENABLED"] = "false"

# Writable cache directory for downloaded Hugging Face models.
cache_dir = "/tmp/huggingface_cache"
os.makedirs(cache_dir, exist_ok=True)
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir  # legacy name; newer transformers prefer HF_HOME
os.environ["HF_HUB_CACHE"] = cache_dir

# Third-party imports deliberately follow the environment setup above.
from dotenv import load_dotenv  # noqa: E402
from langchain_qdrant import QdrantVectorStore  # noqa: E402
from langchain_huggingface import HuggingFaceEmbeddings  # noqa: E402
from crewai import Agent, Task, Crew, Process, LLM  # noqa: E402
import requests  # noqa: E402
# NOTE: this shadows the builtin ConnectionError within this module on purpose;
# the module's request handlers catch the requests exception of that name.
from requests.exceptions import ConnectionError, Timeout, HTTPError  # noqa: E402

# Load environment variables from a .env file. By default python-dotenv does not
# override variables already present, including the ones assigned above.
load_dotenv()

# Settings (None when the corresponding variable is unset).
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_URL = os.getenv("QDRANT_URL")
COLLECTION_NAME = "finance-chatbot"
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
ALPHA_VANTAGE_API_KEY = os.getenv("ALPHA_VANTAGE_API_KEY")
SERPER_API_KEY = os.getenv("SERPER_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Embedding model, stored in the writable cache configured above.
embeddings = HuggingFaceEmbeddings(
    model_name="all-MiniLM-L6-v2",
    cache_folder=cache_dir,
    model_kwargs={"device": "cpu"},  # Force CPU usage in container
)
# Handle on the pre-built Qdrant collection. Queries are embedded with the
# module's `embeddings` model, which presumably must match the model the
# collection was indexed with — verify before changing either.
qdrant = QdrantVectorStore.from_existing_collection(
    collection_name=COLLECTION_NAME,
    embedding=embeddings,
    url=QDRANT_URL,
    api_key=QDRANT_API_KEY,
)

# LLM handles for CrewAI agents (Mistral drives the query classifier); both
# sample at the same temperature.
_LLM_TEMPERATURE = 0.7
mistral_llm = LLM(
    model="mistral/mistral-large-latest",
    api_key=MISTRAL_API_KEY,
    temperature=_LLM_TEMPERATURE,
)
gemini_llm = LLM(
    model="gemini/gemini-2.0-flash",
    api_key=GEMINI_API_KEY,
    temperature=_LLM_TEMPERATURE,
)
# Functions | |
def search_qdrant(query, top_k=3):
    """Return the ``top_k`` documents most similar to ``query`` from Qdrant.

    Args:
        query: Free-text search string.
        top_k: Maximum number of documents to retrieve.

    Returns:
        A list of ``{"text": ..., "source": ...}`` dicts, where ``source`` is the
        document's ``"source"`` metadata or ``"Unknown"``. Retrieval is
        best-effort: on any failure the error is logged and ``[]`` is returned,
        so callers can proceed without knowledge-base context.
    """
    import logging

    try:
        retriever = qdrant.as_retriever(search_type="similarity", search_kwargs={"k": top_k})
        results = retriever.invoke(query)
        return [
            {"text": doc.page_content, "source": doc.metadata.get("source", "Unknown")}
            for doc in results
        ]
    except Exception:
        # Keep the best-effort contract, but don't hide why retrieval failed
        # (e.g. an unreachable server or a missing collection).
        logging.getLogger(__name__).exception("Qdrant similarity search failed")
        return []
def search_news(query, max_results=5):
    """Look up recent finance news for ``query`` via the Serper search API.

    Args:
        query: Free-text topic; " finance news" is appended before searching.
        max_results: Number of results requested from Serper, and the cap on
            how many are returned.

    Returns:
        A list of ``{"title", "url", "snippet"}`` dicts. When nothing is found
        or the request fails, the list holds a single placeholder entry whose
        title names the problem; this function does not raise.
    """
    def notice(title, snippet):
        # Single placeholder entry (no URL) used for empty and error outcomes.
        return [{"title": title, "url": "", "snippet": snippet}]

    try:
        request_headers = {
            "X-API-KEY": SERPER_API_KEY,
            "Content-Type": "application/json",
        }
        request_body = {"q": f"{query} finance news", "num": max_results}
        response = requests.post(
            "https://google.serper.dev/search",
            json=request_body,
            headers=request_headers,
            timeout=10,
        )
        response.raise_for_status()
        organic_hits = response.json().get("organic", [])
        if not organic_hits:
            return notice(
                "No recent news available",
                "Could not fetch news. Please try again later.",
            )
        return [
            {
                "title": hit.get("title", ""),
                "url": hit.get("link", ""),
                "snippet": hit.get("snippet", ""),
            }
            for hit in organic_hits[:max_results]
        ]
    except ConnectionError:
        return notice(
            "Connection Error",
            "Failed to connect to the news API. Please check your internet connection.",
        )
    except Timeout:
        return notice(
            "Timeout Error",
            "News API request timed out. Please try again later.",
        )
    except HTTPError as http_err:
        # Only raise_for_status() raises HTTPError here, so `response` is bound.
        if response.status_code == 429:
            return notice(
                "Rate Limit Exceeded",
                "Too many requests to the news API. Please try again later.",
            )
        return notice("HTTP Error", f"Failed to fetch news due to HTTP error: {http_err}")
    except Exception:
        return notice(
            "Error",
            "An unexpected error occurred while fetching news. Please try again later.",
        )
def get_stock_data(symbol):
    """Fetch the latest quote for ``symbol`` from Alpha Vantage (GLOBAL_QUOTE).

    Args:
        symbol: Ticker symbol as supplied by the caller (e.g. "AAPL").

    Returns:
        On success ``{"symbol", "price", "change", "change_percent"}`` (missing
        fields become "N/A"); otherwise ``{"symbol", "error"}`` with a
        human-readable message. This function does not raise.
    """
    try:
        # Let requests build and URL-encode the query string: the symbol may come
        # from user/LLM text, and characters such as "&", "#" or spaces
        # interpolated raw would corrupt the request or override other params.
        response = requests.get(
            "https://www.alphavantage.co/query",
            params={
                "function": "GLOBAL_QUOTE",
                "symbol": symbol,
                "apikey": ALPHA_VANTAGE_API_KEY,
            },
            timeout=10,
        )
        response.raise_for_status()
        payload = response.json()
        quote = payload.get("Global Quote", {})
        if not quote:
            # Alpha Vantage reports throttling and invalid requests with HTTP 200
            # plus an explanatory message instead of an error status; surface it
            # rather than claiming the symbol has no data.
            for key in ("Note", "Information", "Error Message"):
                if payload.get(key):
                    return {"symbol": symbol, "error": str(payload[key])}
            return {"symbol": symbol, "error": "No data found for this symbol."}
        return {
            "symbol": symbol,
            "price": quote.get("05. price", "N/A"),
            "change": quote.get("09. change", "N/A"),
            "change_percent": quote.get("10. change percent", "N/A"),
        }
    except ConnectionError:
        return {"symbol": symbol, "error": "Failed to connect to the stock API. Please check your internet connection."}
    except Timeout:
        return {"symbol": symbol, "error": "Stock API request timed out. Please try again later."}
    except HTTPError as e:
        # raise_for_status() attaches the failing response to the exception.
        status = e.response.status_code if e.response is not None else None
        if status == 429:
            return {"symbol": symbol, "error": "Too many requests to the stock API. Please try again later."}
        # Report only the status code: str(e) embeds the request URL, which
        # carries the API key and must not reach the user.
        return {"symbol": symbol, "error": f"Failed to fetch stock data due to HTTP error: {status}"}
    except Exception:
        return {"symbol": symbol, "error": "An unexpected error occurred while fetching stock data. Please try again later."}
def _single_task_crew(agent, description, expected_output):
    """Wrap one prompt in a sequential, single-task Crew run by ``agent``."""
    task = Task(description=description, agent=agent, expected_output=expected_output)
    return Crew(agents=[agent], tasks=[task], process=Process.sequential, verbose=False)


def _crew_output_text(output):
    """Return the raw text of a crew result, falling back to ``str(output)``."""
    return output.raw if hasattr(output, "raw") else str(output)


def _labeled_value(text, label):
    """Return the value after ``<label>:`` on the first line carrying it, else None.

    Matching ignores case and tolerates Markdown decoration around the label
    (e.g. ``**Category:**``). Surrounding emphasis marks, backticks, quotes,
    angle brackets and a trailing period are stripped from the value.
    """
    import re

    match = re.search(
        rf"^[\s*_`#>-]*{re.escape(label)}[\s*_`]*:(.*)$",
        text,
        re.IGNORECASE | re.MULTILINE,
    )
    if match is None:
        return None
    return match.group(1).strip().strip("*_`'\".<>").strip()


def determine_question_type(query):
    """Classify ``query`` for routing, using the Mistral LLM via CrewAI tasks.

    Two LLM passes run in sequence. The first is a yes/no finance-relevance
    check. The second, run only for finance queries, picks a category.

    Args:
        query: The user's question.

    Returns:
        ``(category, extra)``. ``category`` is one of ``"out_of_scope"``,
        ``"finance_knowledge"``, ``"market_news"`` or ``"stock_analysis"``.
        For ``"out_of_scope"``, ``extra`` is a fixed explanation; otherwise it
        is the LLM's "Extra Data" value (e.g. a ticker for stock_analysis).
        If the relevance check fails, the query is treated as out of scope.
        If the category step fails, the result is ``("finance_knowledge", query)``.
    """
    import logging

    logger = logging.getLogger(__name__)
    classifier_agent = Agent(
        role="Query Classifier",
        goal="Classify user queries into appropriate categories, including detecting out-of-scope queries.",
        backstory="An expert in natural language understanding, capable of analyzing queries and categorizing them accurately.",
        llm=mistral_llm,
        verbose=True,
        allow_delegation=False,
    )

    # Step 1: is the query about finance at all?
    finance_check_prompt = f"""
    Analyze the following user query and determine if it is related to finance:
    - Return 'Yes' if the query is related to financial terms, concepts, strategies, market news, or stock analysis (e.g., banking, stocks, revenue, P/E ratio).
    - Return 'No' if the query is unrelated to finance (e.g., cooking recipes, weather, or unrelated topics).
    Query: "{query}"
    Provide your response in this format:
    Is Finance Related: <Yes/No>
    """
    finance_crew = _single_task_crew(
        classifier_agent,
        finance_check_prompt,
        "A classification in the format: Is Finance Related: <Yes/No>",
    )
    try:
        answer = _labeled_value(_crew_output_text(finance_crew.kickoff()), "Is Finance Related")
        if answer is None:
            raise ValueError("Invalid response format from LLM for finance check")
        is_finance_related = answer.lower().startswith("yes")
    except Exception:
        # Deliberate fallback: an unclassifiable query is treated as out of scope.
        logger.warning("Finance-relevance check failed; treating query as out of scope", exc_info=True)
        is_finance_related = False

    if not is_finance_related:
        return "out_of_scope", "This query is out of scope for a finance assistant."

    # Step 2: pick the finance category plus any extra data (e.g. a ticker).
    classification_prompt = f"""
    Analyze the following user query and determine its category:
    - finance_knowledge: General questions about financial terms, concepts, or strategies (e.g., 'What is revenue?', 'Explain P/E ratio')
    - market_news: Questions about current market news, trends, or events (e.g., 'Latest news about cryptocurrency market')
    - stock_analysis: Questions about specific stock analysis (e.g., mentioning a stock ticker like AAPL, 'Analyze META stock performance')
    Query: "{query}"
    Provide your response in this format:
    Category: <category>
    Extra Data: <additional info, such as the stock ticker for stock_analysis, or the query itself>
    """
    category_crew = _single_task_crew(
        classifier_agent,
        classification_prompt,
        "A classification of the query in the format: Category: <category>\nExtra Data: <additional info>",
    )
    try:
        response_text = _crew_output_text(category_crew.kickoff())
        category = _labeled_value(response_text, "Category")
        extra_data = _labeled_value(response_text, "Extra Data")
        if category is None or extra_data is None:
            raise ValueError("Invalid response format from LLM for category classification")
        category = category.lower()
        if category not in ("finance_knowledge", "market_news", "stock_analysis"):
            raise ValueError(f"Invalid category: {category}")
        return category, extra_data
    except Exception:
        logger.warning("Query categorisation failed; defaulting to finance_knowledge", exc_info=True)
        return "finance_knowledge", query