# Install the required libraries if they are not already present, e.g.:
#   pip install ddgs dateparser
# Combined imports
import os
from huggingface_hub import InferenceClient
import torch
import re
import warnings
import time
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig # Not directly used in this API code, but kept for potential future use
from sentence_transformers import SentenceTransformer, util, CrossEncoder
import gspread
from google.auth import default # Not directly used in this API code
from tqdm import tqdm # Not directly used in this API code
from ddgs import DDGS # DuckDuckGo search (ddgs is the successor package to duckduckgo_search)
import spacy
from datetime import date, timedelta, datetime
from dateutil.relativedelta import relativedelta
from typing import Optional
import traceback
import base64
import dateparser
from dateparser.search import search_dates
import pytz
from concurrent.futures import ThreadPoolExecutor, as_completed
# FastAPI Imports
from fastapi import FastAPI, Request, HTTPException, Depends, Security
from fastapi.security.api_key import APIKeyHeader
from dotenv import load_dotenv # For loading environment variables from a .env file
from fastapi.responses import JSONResponse # Import JSONResponse
# Load environment variables from .env file (if it exists)
load_dotenv()
# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)
# Define global variables and load secrets from environment variables
HF_TOKEN = os.getenv("HF_TOKEN")
SHEET_ID = os.getenv("SHEET_ID")
GOOGLE_BASE64_CREDENTIALS = os.getenv("GOOGLE_BASE64_CREDENTIALS")
API_KEY = os.getenv("API_KEY") # Load API key from environment variables
# Add print statements to check if secrets are loaded (for debugging in logs)
print(f"HF_TOKEN loaded: {'*' * len(HF_TOKEN) if HF_TOKEN else 'None'}")
print(f"SHEET_ID loaded: {'*' * len(SHEET_ID) if SHEET_ID else 'None'}")
print(f"GOOGLE_BASE64_CREDENTIALS loaded: {'*' * len(GOOGLE_BASE64_CREDENTIALS) if GOOGLE_BASE64_CREDENTIALS else 'None'}")
print(f"API_KEY loaded: {'*' * len(API_KEY) if API_KEY else 'None'}")
# Global variables for component initialization status
llm_client_initialized = False
spacy_loaded = False
embedder_loaded = False
reranker_loaded = False
business_info_loaded = False
# Initialize InferenceClient
client = None
def initialize_llm_client():
"""Initializes the Hugging Face InferenceClient."""
global client, llm_client_initialized
llm_client_initialized = False
print("Attempting to initialize InferenceClient...")
if not HF_TOKEN:
print("Error: HF_TOKEN not loaded. InferenceClient cannot be initialized.")
return
try:
client = InferenceClient("google/gemma-2-9b-it", token=HF_TOKEN)
# Optional: Make a small test call to ensure the client is working
try:
test_response = client.chat_completion(messages=[{"role": "user", "content": "hello"}], max_tokens=10)
if test_response:
print("InferenceClient test call successful.")
llm_client_initialized = True
else:
print("InferenceClient test call failed.")
except Exception as test_e:
print(f"InferenceClient test call failed: {test_e}")
print(traceback.format_exc())
client = None # Reset client if test fails
if llm_client_initialized:
print("InferenceClient initialized.")
else:
print("InferenceClient initialization failed.")
except Exception as e:
print(f"Error initializing InferenceClient: {e}")
print(traceback.format_exc())
client = None # Set client to None if initialization fails
llm_client_initialized = False
# Load spacy model for sentence splitting
nlp = None
def load_spacy_model():
"""Loads the SpaCy model."""
global nlp, spacy_loaded
spacy_loaded = False
print("Attempting to load SpaCy model 'en_core_web_sm'...")
try:
# Load the model directly, assuming it's installed during Docker build
nlp = spacy.load("en_core_web_sm")
print("SpaCy model 'en_core_web_sm' loaded.")
spacy_loaded = True
except OSError:
print("SpaCy model 'en_core_web_sm' not found. Please ensure it is installed.")
print(traceback.format_exc()) # Print traceback for debugging
nlp = None # Set nlp to None if loading fails
spacy_loaded = False
except Exception as e:
print(f"Error loading SpaCy model: {e}")
print(traceback.format_exc())
nlp = None
spacy_loaded = False
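# Note: en_core_web_sm is not bundled with spacy itself; it can be installed at
# image build time with the standard spaCy CLI:
#   python -m spacy download en_core_web_sm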
# Load SentenceTransformer for RAG/business info retrieval and semantic detection
embedder = None
def load_embedder_model():
"""Loads the Sentence Transformer model."""
global embedder, embedder_loaded
embedder_loaded = False
print("Attempting to load Sentence Transformer (sentence-transformers/paraphrase-MiniLM-L6-v2)...")
try:
embedder = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
print("Sentence Transformer loaded.")
embedder_loaded = True
except Exception as e:
print(f"Error loading Sentence Transformer: {e}")
print(traceback.format_exc()) # Print traceback for debugging
embedder = None
embedder_loaded = False
# Load a Cross-Encoder model for re-ranking retrieved documents
reranker = None
def load_reranker_model():
"""Loads the Cross-Encoder model."""
global reranker, reranker_loaded
reranker_loaded = False
print("Attempting to load Cross-Encoder Reranker (cross-encoder/ms-marco-MiniLM-L6-v2)...")
try:
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L6-v2')
print("Cross-Encoder Reranker loaded.")
reranker_loaded = True
except Exception as e:
print(f"Error loading Cross-Encoder Reranker: {e}")
print("Please ensure the model identifier 'cross-encoder/ms-marco-MiniLM-L6-v2' is correct and accessible on Hugging Face Hub.")
print(traceback.format_exc())
reranker = None
reranker_loaded = False
# Google Sheets Authentication
gc = None # Global variable for gspread client
def authenticate_google_sheets():
"""Authenticates with Google Sheets using base64 encoded credentials."""
global gc
print("Authenticating Google Account...")
if not GOOGLE_BASE64_CREDENTIALS:
print("Error: GOOGLE_BASE64_CREDENTIALS secret not found. Skipping Google Sheets authentication.")
return False
try:
credentials_json = base64.b64decode(GOOGLE_BASE64_CREDENTIALS).decode('utf-8')
credentials = json.loads(credentials_json)
gc = gspread.service_account_from_dict(credentials)
print("Google Sheets authentication successful via service account.")
return True
except Exception as e:
print(f"Google Sheets authentication failed: {e}")
print(traceback.format_exc())
return False
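# A minimal sketch of how GOOGLE_BASE64_CREDENTIALS can be produced from a
# service-account JSON key file (the file name below is an assumption for illustration):
#   import base64
#   with open("service_account.json", "rb") as f:
#       print(base64.b64encode(f.read()).decode("ascii"))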
# Google Sheets Data Loading and Embedding
data = [] # Global variable to store loaded data
descriptions_for_embedding = []
embeddings = torch.tensor([])
# business_info_available is now managed by the load_business_info function
def load_business_info():
"""Loads business information from Google Sheet and creates embeddings."""
global data, descriptions_for_embedding, embeddings, business_info_loaded
business_info_loaded = False # Reset flag
print("Attempting to load business information from Google Sheet...")
if gc is None:
print("Skipping Google Sheet loading: Google Sheets client not authenticated.")
return
if not SHEET_ID:
print("Error: SHEET_ID not set. Skipping Google Sheet loading.")
return
try:
sheet = gc.open_by_key(SHEET_ID).sheet1
print(f"Successfully opened Google Sheet with ID: {SHEET_ID}")
data_records = sheet.get_all_records()
if not data_records:
print(f"Warning: No data records found in Google Sheet with ID: {SHEET_ID}")
data = []
descriptions_for_embedding = []
else:
filtered_data = [row for row in data_records if row.get('Service') and row.get('Description')]
if not filtered_data:
print("Warning: Filtered data is empty after checking for 'Service' and 'Description'.")
data = []
descriptions_for_embedding = []
else:
data = filtered_data
descriptions_for_embedding = [f"Service: {row['Service']}. Description: {row['Description']}" for row in data]
if descriptions_for_embedding and embedder is not None:
print("Encoding descriptions...")
try:
embeddings = embedder.encode(descriptions_for_embedding, convert_to_tensor=True)
print("Encoding complete.")
business_info_loaded = True
except Exception as e:
print(f"Error during description encoding: {e}")
embeddings = torch.tensor([])
business_info_loaded = False
else:
print("Skipping encoding descriptions: No descriptions found or embedder not available.")
embeddings = torch.tensor([])
business_info_loaded = False
print(f"Loaded {len(descriptions_for_embedding)} entries from Google Sheet for embedding/RAG.")
if not business_info_loaded:
print("Business information retrieval (RAG) is NOT available.")
else:
print("Business information retrieval (RAG) is available.")
except gspread.exceptions.SpreadsheetNotFound:
print(f"Error: Google Sheet with ID '{SHEET_ID}' not found.")
print("Please check the SHEET_ID and ensure your authenticated Google Account has access to this sheet.")
business_info_loaded = False
except Exception as e:
print(f"An error occurred while accessing the Google Sheet: {e}")
print(traceback.format_exc())
business_info_loaded = False
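# The sheet's first worksheet is expected to contain at least 'Service' and
# 'Description' columns; rows missing either value are filtered out, e.g.:
#   Service    | Description
#   Web design | Responsive business websites built with ...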
# Business Info Retrieval (RAG)
def retrieve_business_info(query: str, top_n: int = 3) -> list:
"""
Retrieves relevant business information from loaded data based on a query.
"""
global data
if not business_info_loaded or embedder is None or not descriptions_for_embedding or not data:
print("Business information retrieval is not available or data is empty.")
return []
try:
query_embedding = embedder.encode(query, convert_to_tensor=True)
cosine_scores = util.cos_sim(query_embedding, embeddings)[0]
top_results_indices = torch.topk(cosine_scores, k=min(top_n, len(data)))[1].tolist()
top_results = [data[i] for i in top_results_indices]
if reranker is not None and top_results:
print("Re-ranking top results...")
rerank_pairs = [(query, descriptions_for_embedding[i]) for i in top_results_indices]
rerank_scores = reranker.predict(rerank_pairs)
reranked_indices = sorted(range(len(rerank_scores)), key=lambda i: rerank_scores[i], reverse=True)
reranked_results = [top_results[i] for i in reranked_indices]
print("Re-ranking complete.")
return reranked_results
else:
return top_results
except Exception as e:
print(f"Error during business information retrieval: {e}")
print(traceback.format_exc())
return []
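# Illustrative call (only meaningful after load_business_info() has succeeded;
# the query and results here are hypothetical):
#   retrieve_business_info("web design services", top_n=2)
#   -> [{'Service': 'Web design', 'Description': '...'}, ...]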
# Function to perform DuckDuckGo Search and return results with URLs
def perform_duckduckgo_search(query: str, max_results: int = 5):
"""
Performs a search using DuckDuckGo and returns a list of dictionaries.
Includes a delay to avoid rate limits.
Returns an empty list and prints an error if search fails.
"""
print(f"Executing Tool: perform_duckduckgo_search with query='{query}')")
search_results_list = []
try:
time.sleep(1)
with DDGS() as ddgs:
search_query = query.strip()
if not search_query or len(search_query.split()) < 2:
print(f"Skipping search for short query: '{search_query}'")
return []
print(f"Sending search query to DuckDuckGo: '{search_query}'")
results_generator = ddgs.text(search_query, max_results=max_results)
results_found = False
for r in results_generator:
search_results_list.append(r)
results_found = True
print(f"Raw results from DuckDuckGo: {search_results_list}")
if not results_found and max_results > 0:
print(f"DuckDuckGo search for '{search_query}' returned no results.")
elif results_found:
print(f"DuckDuckGo search for '{search_query}' completed. Found {len(search_results_list)} results.")
except Exception as e:
print(f"Error during Duckduckgo search for '{search_query if 'search_query' in locals() else query}': {e}")
print(traceback.format_exc())
return []
return search_results_list
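# In current ddgs releases each text result is a dict whose keys typically include
# 'title', 'href', and 'body', e.g.:
#   {'title': 'Example page', 'href': 'https://example.com/', 'body': 'Snippet ...'}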
# Semantic date/time detection and calculation using dateparser
def perform_date_calculation(query: str) -> Optional[str]:
"""
Analyzes query for date/time information using dateparser.
If dateparser finds a date, it returns a human-friendly response string.
Otherwise, it returns None.
It is designed to handle multiple languages and provide the time for East Africa (Tanzania).
"""
print(f"Executing Tool: perform_date_calculation with query='{query}') using dateparser.search_dates")
try:
eafrica_tz = pytz.timezone('Africa/Dar_es_Salaam')
now = datetime.now(eafrica_tz)
except pytz.UnknownTimeZoneError:
print("Error: Unknown timezone 'Africa/Dar_es_Salaam'. Using default system time.")
now = datetime.now()
try:
found = search_dates(
query,
settings={
"PREFER_DATES_FROM": "future",
"RELATIVE_BASE": now
},
languages=['sw', 'en'] # Prioritize Swahili
)
if not found:
print("dateparser.search_dates could not parse any date/time.")
return None
text_snippet, parsed = found[0]
print(f"dateparser.search_dates found: text='{text_snippet}', parsed='{parsed}')")
is_swahili = any(swahili_phrase in query.lower() for swahili_phrase in ['tarehe', 'siku', 'saa', 'muda', 'leo', 'kesho', 'jana', 'ngapi', 'gani', 'mwezi', 'mwaka'])
if now.tzinfo is not None and parsed.tzinfo is None:
parsed = now.tzinfo.localize(parsed)
elif now.tzinfo is None and parsed.tzinfo is not None:
parsed = parsed.replace(tzinfo=None)
if parsed.date() == now.date():
if abs((parsed - now).total_seconds()) < 60 or parsed.time() == datetime.min.time():
print("Query parsed to today's date and time is close to 'now' or midnight, returning current time/date.")
if is_swahili:
return f"Kwa saa za Afrika Mashariki (Tanzania), tarehe ya leo ni {now.strftime('%A, %d %B %Y')} na saa ni {now.strftime('%H:%M:%S')}."
else:
return f"In East Africa (Tanzania), the current date is {now.strftime('%A, %d %B %Y')} and the time is {now.strftime('%H:%M:%S')}."
else:
print(f"Query parsed to a specific time today: {parsed.strftime('%H:%M:%S')}")
if is_swahili:
return f"Hiyo inafanyika leo, {parsed.strftime('%A, %d %B %Y')}, saa {parsed.strftime('%H:%M:%S')} saa za Afrika Mashariki."
else:
return f"That falls on today, {parsed.strftime('%A, %d %B %Y')}, at {parsed.strftime('%H:%M:%S')} East Africa Time."
else:
print(f"Query parsed to a specific date: {parsed.strftime('%A, %d %B %Y')} at {parsed.strftime('%H:%M:%S')}")
time_str = parsed.strftime('%H:%M:%S')
date_str = parsed.strftime('%A, %d %B %Y')
if parsed.tzinfo:
tz_name = parsed.tzinfo.tzname(parsed) or 'UTC'
if is_swahili:
return f"Hiyo inafanyika tarehe {date_str} saa {time_str} {tz_name}."
else:
return f"That falls on {date_str} at {time_str} {tz_name}."
else:
if is_swahili:
return f"Hiyo inafanyika tarehe {date_str} saa {time_str}."
else:
return f"That falls on {date_str} at {time_str}."
    except Exception as e:
        print(f"Error during dateparser.search_dates execution: {e}")
        print(traceback.format_exc())
        # Return None instead of an error string so callers such as determine_tool_usage
        # do not mistake a parsing failure for a successfully detected date.
        return None
# Function to determine if a query requires a tool or can be answered directly
def determine_tool_usage(query: str) -> str:
"""
Analyzes the query to determine if a specific tool is needed.
Returns the name of the tool ('duckduckgo_search', 'business_info_retrieval',
'date_calculation') or 'none' if no specific tool is clearly indicated.
Prioritizes business information retrieval, then specific tools based on keywords
and LLM judgment.
"""
query_lower = query.lower()
    if business_info_loaded and client is not None: # Require business data and an initialized LLM client for this check
messages_business_check = [{"role": "user", "content": f"Does the following query ask about a specific person, service, offering, or description that is likely to be found *only* within a specific business's internal knowledge base, and not general knowledge? For example, questions about 'Salum' or 'Jackson Kisanga' are likely business-related, while questions about 'the current president of the USA' or 'who won the Ballon d'Or' are general knowledge. Answer only 'yes' or 'no'. Query: {query}"}]
try:
business_check_response = client.chat_completion(
messages=messages_business_check,
max_tokens=10,
temperature=0.1
).choices[0].message.content.strip().lower()
if business_check_response == "yes":
print(f"Detected as specific business info query based on LLM check: '{query}'")
return "business_info_retrieval"
else:
print(f"LLM check indicates not a specific business info query: '{query}'")
except Exception as e:
print(f"Error during LLM call for business info check for query '{query}': {e}")
print(traceback.format_exc())
print(f"Proceeding without business info check for query '{query}' due to error.")
else:
print("Skipping LLM business info check: Business information not loaded.")
date_time_check_result = perform_date_calculation(query)
if date_time_check_result is not None:
print(f"Detected as date/time calculation query based on dateparser result for: '{query}'")
return "date_calculation"
messages_tool_determination_search = [{"role": "user", "content": f"Does the following query require searching the web for current or general knowledge information (e.g., news, facts, definitions, current events)? Respond ONLY with 'duckduckgo_search' or 'none'. Query: {query}"}]
try:
search_determination_response = client.chat_completion(
messages=messages_tool_determination_search,
max_tokens=20,
temperature=0.1,
top_p=0.9
).choices[0].message.content or ""
response_lower = search_determination_response.strip().lower()
if "duckduckgo_search" in response_lower:
print(f"Model-determined tool for '{query}': 'duckduckgo_search'")
return "duckduckgo_search"
else:
print(f"Model-determined tool for '{query}': 'none' (for search)")
except Exception as e:
print(f"Error during LLM call for search tool determination for query '{query}': {e}")
print(traceback.format_exc())
print(f"Proceeding without search tool check for query '{query}' due to error.")
print(f"No specific tool determined for '{query}'. Defaulting to 'none'.")
return "none"
# Function to generate text using the LLM, incorporating tool results if available
def generate_text(prompt: str, tool_results: dict = None) -> str:
"""
Generates text using the configured LLM, optionally incorporating tool results.
"""
if not llm_client_initialized or client is None:
print("LLM client is not initialized. Cannot generate text.")
return "Error: The language model is not available at this time."
full_prompt_builder = [prompt]
if tool_results and any(tool_results.values()):
full_prompt_builder.append("\n\nTool Results:\n")
for question, results in tool_results.items():
if results:
full_prompt_builder.append(f"--- Results for: {question} ---\n")
if isinstance(results, list):
for i, result in enumerate(results):
if isinstance(result, dict) and 'Service' in result and 'Description' in result:
full_prompt_builder.append(f"Business Info {i+1}:\nService: {result.get('Service', 'N/A')}\nDescription: {result.get('Description', 'N/A')}\n\n")
                    elif isinstance(result, dict) and ('href' in result or 'url' in result):
                        # ddgs text results use the 'href' key for the result link
                        full_prompt_builder.append(f"Search Result {i+1}:\nTitle: {result.get('title', 'N/A')}\nURL: {result.get('href') or result.get('url', 'N/A')}\nSnippet: {result.get('body', 'N/A')}\n\n")
else:
full_prompt_builder.append(f"{result}\n\n")
elif isinstance(results, dict):
for key, value in results.items():
full_prompt_builder.append(f"{key}: {value}\n")
full_prompt_builder.append("\n")
else:
full_prompt_builder.append(f"{results}\n\n")
full_prompt_builder.append("Based on the provided tool results, answer the user's original query. If a question was answered by a tool, use the tool's result directly in your response.")
print("Added tool results and instruction to final prompt.")
else:
print("No tool results to add to final prompt.")
full_prompt = "".join(full_prompt_builder)
print(f"Sending prompt to LLM:\n---\n{full_prompt}\n---")
    # Generation settings; only the parameters supported by chat_completion are passed below.
    generation_config = {
        "temperature": 0.7,
        "max_new_tokens": 500,
        "top_p": 0.95,
    }
try:
response = client.chat_completion(
messages=[
{"role": "user", "content": full_prompt}
],
max_tokens=generation_config.get("max_new_tokens", 512),
temperature=generation_config.get("temperature", 0.7),
top_p=generation_config.get("top_p", 0.95)
).choices[0].message.content or ""
print("LLM generation successful using chat_completion.")
return response
except Exception as e:
print(f"Error during final LLM generation: {e}")
print(traceback.format_exc())
return "An error occurred while generating the final response."
# Refactored core chat logic into a function
def process_query_with_tools(query: str):
"""
Processes user queries by breaking down multi-part queries, determining and
executing appropriate tools for each question, and synthesizing results
using the LLM. Prioritizes business information retrieval.
This function is designed to be called by the API endpoint.
"""
print(f"Processing query with tools: {query}")
# Ensure LLM client is initialized before proceeding with any LLM calls
if not llm_client_initialized or client is None:
print("LLM client not initialized. Cannot process query.")
return "Error: The language model is not available. Please try again later."
print("\n--- Breaking down query ---")
prompt_for_question_breakdown = f"""
Analyze the following query and list each distinct question found within it.
Present each question on a new line, starting with a hyphen.
Query: {query}
"""
try:
messages_question_breakdown = [{"role": "user", "content": prompt_for_question_breakdown}]
question_breakdown_response = client.chat_completion(
messages=messages_question_breakdown,
max_tokens=100,
temperature=0.1,
top_p=0.9
).choices[0].message.content or ""
individual_questions = [line.strip() for line in question_breakdown_response.split('\n') if line.strip()]
cleaned_questions = [re.sub(r'^[-*]?\s*', '', q) for q in individual_questions]
print("Individual questions identified:")
for q in cleaned_questions:
print(f"- {q}")
except Exception as e:
print(f"Error during LLM call for question breakdown: {e}")
print(traceback.format_exc())
cleaned_questions = [query] # Fallback to treating the whole query as one question
print("\n--- Determining tools per question ---")
determined_tools = {}
for question in cleaned_questions:
print(f"\nAnalyzing question for tool determination: '{question}'")
determined_tools[question] = determine_tool_usage(question)
print(f"Determined tool for '{question}': '{determined_tools[question]}'") # Corrected print statement
print("\nSummary of determined tools per question:")
for question, tool in determined_tools.items():
print(f"'{question}': '{tool}'")
print("\n--- Executing tools and collecting results ---")
tool_results = {}
for question, tool in determined_tools.items():
print(f"\nExecuting tool '{tool}' for question: '{question}'")
result = None
if tool == "date_calculation":
result = perform_date_calculation(question)
elif tool == "duckduckgo_search":
result = perform_duckduckgo_search(question)
elif tool == "business_info_retrieval":
result = retrieve_business_info(question)
elif tool == "none":
print(f"Skipping tool execution for question: '{question}' as tool is 'none'. LLM will handle.")
result = None
if result is not None:
tool_results[question] = result
print("\n--- Collected Tool Results ---")
if tool_results:
for question, result in tool_results.items():
print(f"\nQuestion: {question}")
print(f"Result: {result}")
else:
print("No tool results were collected.")
print("\n--------------------------")
print("\n--- Generating final response ---")
final_response = generate_text(query, tool_results)
print("\n--- Final Response from LLM ---")
print(final_response)
print("\n----------------------------")
return final_response
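# End-to-end sketch (LLM-dependent, so treat as indicative):
#   process_query_with_tools("Who is Salum and what is tomorrow's date?")
#   breaks the query into two questions, routes one to business info retrieval and
#   one to date calculation, then asks the LLM to synthesize a single reply.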
# --- FastAPI Application Setup ---
app = FastAPI()
# API key header scheme used for authentication
api_key_header = APIKeyHeader(name="x-api-key", auto_error=True)
# API Key Authentication Dependency
def get_api_key(api_key_header_value: str = Security(api_key_header)):
    # If API_KEY is unset, authentication is effectively disabled and all requests
    # are accepted; set the API_KEY environment variable to enforce authentication.
    if API_KEY is None or api_key_header_value == API_KEY:
return api_key_header_value
else:
raise HTTPException(status_code=403, detail="Could not validate credentials")
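# A minimal client-side sketch using the requests package (URL/port are assumptions
# for local testing; requests is not a dependency of this module):
#   import requests
#   r = requests.post(
#       "http://localhost:8000/chat/",
#       headers={"x-api-key": "<your API_KEY>"},
#       json={"query": "What services do you offer?"},
#   )
#   print(r.json()["response"])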
# API Endpoint
@app.post("/chat/")
async def chat_endpoint(request: Request, api_key: str = Depends(get_api_key)):
"""
API endpoint to process user chat queries using the LLM and tools.
Requires API key authentication in the 'x-api-key' header.
"""
try:
body = await request.json()
query = body.get("query")
if not query:
raise HTTPException(status_code=400, detail="Query parameter is required.")
# Ensure client is initialized before processing query
if not llm_client_initialized or client is None:
raise HTTPException(status_code=503, detail="LLM client not initialized. Please wait or check logs.")
response = process_query_with_tools(query)
return {"response": response}
except Exception as e:
print(f"Error in chat_endpoint: {e}")
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
# Health Check Endpoint
@app.get("/health")
async def health_check():
"""
Health check endpoint to verify the application is running and essential components are loaded.
Returns 200 OK if all critical components are loaded, 503 Service Unavailable otherwise.
"""
status = {
"status": "unhealthy",
"llm_client_initialized": llm_client_initialized,
"business_info_loaded": business_info_loaded,
"spacy_loaded": spacy_loaded,
"embedder_loaded": embedder_loaded,
"reranker_loaded": reranker_loaded,
"secrets_loaded": {
"HF_TOKEN": HF_TOKEN is not None,
"SHEET_ID": SHEET_ID is not None,
"GOOGLE_BASE64_CREDENTIALS": GOOGLE_BASE64_CREDENTIALS is not None,
"API_KEY": API_KEY is not None,
}
}
# Check if all critical components are loaded
all_critical_loaded = (
llm_client_initialized and
spacy_loaded and
embedder_loaded and
reranker_loaded and
(business_info_loaded if (SHEET_ID and GOOGLE_BASE64_CREDENTIALS) else True) # Business info is critical only if secrets are set
)
if all_critical_loaded:
status["status"] = "ok"
return JSONResponse(status_code=200, content=status)
else:
unhealthy_components = [key for key, value in status.items() if isinstance(value, bool) and not value]
if status["secrets_loaded"] and not all(status["secrets_loaded"].values()):
unhealthy_components.append("secrets_loaded (partial)")
status["unhealthy_components"] = unhealthy_components
return JSONResponse(status_code=503, content=status)
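# A healthy response has the shape (values depend on what actually loaded):
#   {"status": "ok", "llm_client_initialized": true, "business_info_loaded": true, ...}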
# Optional: Root endpoint for basic info
@app.get("/")
async def read_root():
"""
Root endpoint providing basic application information and status.
"""
status = {
"message": "LLM with Tools API is running",
"llm_client_initialized": llm_client_initialized,
"business_info_loaded": business_info_loaded,
"spacy_loaded": spacy_loaded,
"embedder_loaded": embedder_loaded,
"reranker_loaded": reranker_loaded,
"secrets_loaded": {
"HF_TOKEN": HF_TOKEN is not None,
"SHEET_ID": SHEET_ID is not None,
"GOOGLE_BASE64_CREDENTIALS": GOOGLE_BASE64_CREDENTIALS is not None,
"API_KEY": API_KEY is not None,
}
}
if not all(status["secrets_loaded"].values()):
status["warning"] = status.get("warning", "") + " Not all secrets are loaded."
if not status["llm_client_initialized"]:
status["warning"] = status.get("warning", "") + " LLM client not initialized."
if not status["business_info_loaded"] and (SHEET_ID and GOOGLE_BASE64_CREDENTIALS):
status["warning"] = status.get("warning", "") + " Business info (RAG) not loaded."
if not status["spacy_loaded"]:
status["warning"] = status.get("warning", "") + " SpaCy model not loaded."
if not status["embedder_loaded"]:
status["warning"] = status.get("warning", "") + " Embedder not loaded."
if not status["reranker_loaded"]:
status["warning"] = status.get("warning", "") + " Reranker not loaded."
return status
# Initialize components on startup
# This will run when the script is imported or executed directly
print("Starting component initialization...")
authenticate_google_sheets() # Authenticate first as it's needed for load_business_info
load_spacy_model()
load_embedder_model()
load_reranker_model()
load_business_info() # Load business info after authentication and embedder are ready
initialize_llm_client() # Initialize LLM client last as it might be the largest model
print("Component initialization sequence complete.")
# For local testing, start the app with uvicorn from the command line:
#   uvicorn api:app --host 0.0.0.0 --port 8000
# For production deployment, use a proper ASGI server setup
# (e.g. uvicorn or gunicorn behind a reverse proxy).
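# A minimal, optional entry point sketch for running this file directly
# (assumes uvicorn is installed; ignored when served by an external ASGI server):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)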