import os
import sys

# Make the project root importable before loading the local utils modules
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import requests
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from utils.helpers.chat_mapper import map_answer_to_chat_response
from utils.vector_store import get_vector_store

load_dotenv()
app = FastAPI()
# Full CORS configuration (disabled while debugging; simplified version below)
# # Get allowed origins from environment variable for flexibility
# ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "").split(
# ",") if os.environ.get("ALLOWED_ORIGINS") else ["*"]
# # For development/testing, you can also hardcode your Vercel domain
# VERCEL_DOMAINS = [
# "https://your-app-name.vercel.app", # Replace with your actual Vercel app name
#     # Git branch deployments (origins must match exactly: scheme required, no trailing slash)
#     "https://mes-chatbot-project-git-simple-w-markdown-mangobutlers-projects.vercel.app",
#     "https://mes-chatbot-project-73znou68u-mangobutlers-projects.vercel.app",
# "http://localhost:3000", # For local frontend development
# "http://localhost:5173", # For Vite dev server
# "http://127.0.0.1:3000", # Alternative localhost
# ]
# # Combine environment origins with Vercel domains
# if ALLOWED_ORIGINS == ["*"]:
# # If no specific origins set, use Vercel domains + wildcard for testing
# final_origins = VERCEL_DOMAINS + ["*"]
# else:
# final_origins = ALLOWED_ORIGINS + VERCEL_DOMAINS
# # Adding the Middleware
# app.add_middleware(
# CORSMiddleware,
# allow_origins=final_origins,
# allow_credentials=True,
# allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
# allow_headers=["*"],
# expose_headers=["*"]
# )
# Simplified CORS for debugging
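# Caveat: the CORS spec forbids "Access-Control-Allow-Origin: *" on
# credentialed requests, so allow_credentials=True with wildcard origins
# may not behave as expected in every browser; prefer an explicit
# allow-list (like the one above) outside of debugging.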
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ---------------------------
# Vector store mapping for different domains
# ---------------------------
VECTOR_STORE_PATHS = {
"mes": "./vector_stores/mes_db",
"technical": "./vector_stores/tech_db",
"general": "./vector_stores/general_db",
"default": "./vector_stores/general_db",
}
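# A query like "technical: how do I reset the line?" is routed to tech_db;
# anything without a recognized "<prefix>:" marker falls back to the default store.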
class QueryRequest(BaseModel):
query: str
# ---------------------------
# Gemini API setup
# ---------------------------
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
raise ValueError("GEMINI_API_KEY environment variable required")
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
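# The API key is appended as a "?key=" query parameter when this endpoint is
# called in generate_answer_with_gemini() below.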
# ---------------------------
# Vector store loader
# ---------------------------
def load_vector_store_by_prefix(query: str):
    """Pick a domain vector store from an optional "<prefix>:" marker in the
    query (e.g. "mes: ...") and return (store, cleaned_query, store_key)."""
    lower_q = query.lower().strip()
    for prefix, path in VECTOR_STORE_PATHS.items():
        if prefix != "default" and lower_q.startswith(f"{prefix}:"):
            cleaned_query = lower_q[len(prefix) + 1:].strip()
            return get_vector_store(persist_directory=path), cleaned_query, prefix
    return get_vector_store(persist_directory=VECTOR_STORE_PATHS["default"]), query, "default"
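# Illustrative call (hypothetical query):
#   load_vector_store_by_prefix("mes: how are work orders created?")
#   -> (mes store, "how are work orders created?", "mes")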
def generate_answer_with_gemini(query: str, context_docs: list):
    """Build a grounded prompt from the retrieved documents and call the
    Gemini REST API, returning the answer text (or an error string)."""
    # Build the context string from the retrieved documents
    knowledge_parts = []
    for i, doc in enumerate(context_docs, 1):
        knowledge_parts.append(f"Data Source {i}: {doc.page_content.strip()}")
    knowledge_base = "\n\n".join(knowledge_parts)
    # Prompt that forces the model to ground its answers in the retrieved knowledge base
prompt = (
"You are an expert AI assistant that uses a provided knowledge base to answer questions. "
"Your responses must always be based on this knowledge base, which is the ultimate source of truth. "
"You will only use your internal knowledge to supplement the answer, never to contradict it. "
"If and only if the knowledge base contains absolutely nothing relevant to the user's question, "
"you will respond with a polite and concise statement saying you cannot answer the question from the information you have. "
"You must never answer 'I don't know' if there is any information in the knowledge base that is even tangentially related to the question. "
"Always try your best to construct a useful answer by synthesizing the provided information. "
"Do not refer to the 'knowledge base' or 'sources' directly; instead, use phrases like 'based on the information I have'.\n\n"
f"My knowledge base:\n{knowledge_base}\n\n"
f"User's Question: {query}\n\nAnswer:"
)
# print the prompt for debugging
print("πŸ” Prompt sent to Gemini API:", prompt)
try:
response = requests.post(
f"{GEMINI_API_URL}?key={GEMINI_API_KEY}",
json={
"contents": [
{
"role": "user",
"parts": [
{"text": prompt}
]
}
],
"generationConfig": {
"temperature": 0.7,
"maxOutputTokens": 300
}
},
timeout=300
)
if response.status_code != 200:
return f"API Error: {response.status_code} - {response.text}"
data = response.json()
# Extract answer text
return (
data.get("candidates", [{}])[0]
.get("content", {})
.get("parts", [{}])[0]
.get("text", "")
.strip()
or "I couldn't generate an answer."
)
except Exception as e:
return f"Error: {str(e)}"
# Middleware for logging requests
@app.middleware("http")
async def log_requests(request: Request, call_next):
print(f"πŸ” Request: {request.method} {request.url}")
print(f"πŸ” Headers: {dict(request.headers)}")
print(f"πŸ” Origin: {request.headers.get('origin', 'No Origin')}")
    print(f"πŸ” User-Agent: {request.headers.get('user-agent', 'No User-Agent')}")
response = await call_next(request)
print(f"πŸ” Response Status: {response.status_code}")
return response
# ---------------------------
# API Endpoints
# ---------------------------
@app.get("/")
def root():
return {
"status": "running",
"model": "gemini-2.0-flash",
"using_direct_api": True,
"client_ready": True
}
# Previously served at "/ask"; the endpoint now lives at the root path
@app.post("/")
async def ask_question(request: Request):
try:
# Print raw incoming request body
raw_body = await request.body()
print("πŸ“₯ Incoming POST /ask request body:")
print(raw_body.decode("utf-8"))
# Parse into your Pydantic model
parsed_request = QueryRequest.model_validate_json(raw_body)
print("βœ… Parsed request object:", parsed_request)
vector_store, cleaned_query, store_key = load_vector_store_by_prefix(
parsed_request.query
)
if not vector_store:
raise HTTPException(
status_code=500, detail="Vector store not ready"
)
        retriever = vector_store.as_retriever(
            search_type="mmr",  # maximal marginal relevance: relevance + diversity
            search_kwargs={
                "k": 6,  # documents returned
                "fetch_k": 20,  # candidates considered before MMR re-ranking
                "lambda_mult": 0.5  # 0 = max diversity, 1 = max relevance
            }
        )
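        # Note: get_relevant_documents() is deprecated in newer LangChain
        # releases in favor of retriever.invoke(); kept here for compatibility.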
docs = retriever.get_relevant_documents(cleaned_query)
        # Deduplicate retrieved chunks by exact text, then keep the top 5
seen = set()
unique_docs = []
for doc in docs:
snippet = doc.page_content.strip()
if snippet not in seen:
seen.add(snippet)
unique_docs.append(doc)
docs = unique_docs[:5]
if not docs:
return {
"answer": "I couldn't find any relevant information in the knowledge base to answer your question.",
"model_used": "gemini-2.0-flash",
"vector_store_used": VECTOR_STORE_PATHS[store_key],
"sources": []
}
answer = generate_answer_with_gemini(cleaned_query, docs)
answer_obj = {
"answer": answer,
"model_used": "gemini-2.0-flash",
"vector_store_used": VECTOR_STORE_PATHS[store_key],
"sources": [
{
"content": doc.page_content[:500] + "...\n",
"metadata": doc.metadata
}
for doc in docs
]
}
# For debugging, print the generated answer object
# print("Generated answer object:",
# map_answer_to_chat_response(answer_obj))
return map_answer_to_chat_response(answer_obj)
except Exception as e:
print(f"Error in ask_question: {e}")
raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
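# Example request against a local run (hypothetical query and port):
#   curl -X POST http://localhost:8000/ \
#        -H "Content-Type: application/json" \
#        -d '{"query": "mes: how do I create a work order?"}'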
if __name__ == "__main__":
import uvicorn
port = int(os.environ.get("PORT", 8000))
uvicorn.run(app, host="0.0.0.0", port=port)