# NOTE: removed non-code scrape residue ("Spaces:" / "Sleeping" hosting-status lines) that was not part of the program.
import os
import sys

# Make the project root importable BEFORE pulling in the local `utils` package.
# (Previously this ran after the utils imports, which only worked when the
# process was already started from the project root.)
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import requests
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from utils.helpers.chat_mapper import map_answer_to_chat_response
from utils.vector_store import get_vector_store

# Load .env before any os.environ reads below (GEMINI_API_KEY, PORT).
load_dotenv()

app = FastAPI()
# ---------------------------
# CORS — simplified wide-open config for debugging.
# ---------------------------
# NOTE(review): per the CORS spec, browsers reject a wildcard
# Access-Control-Allow-Origin when credentials are allowed; with
# allow_credentials=True this config will break credentialed requests.
# Replace ["*"] with an explicit origin list before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------
# Vector store mapping for different domains.
# Keys double as query prefixes (e.g. "mes: how do I ...") — see
# load_vector_store_by_prefix; "default" is the fallback when no prefix matches.
VECTOR_STORE_PATHS = {
    "mes": "./vector_stores/mes_db",
    "technical": "./vector_stores/tech_db",
    "general": "./vector_stores/general_db",
    "default": "./vector_stores/general_db",  # same store as "general"
}
class QueryRequest(BaseModel):
    """Request body for POST /ask: a single free-text query string."""

    query: str
# ---------------------------
# Gemini API setup
# ---------------------------
# Fail fast at import time if the key is missing — every request needs it.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY environment variable required")
# REST endpoint for the gemini-2.0-flash model (v1beta generateContent).
GEMINI_API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
# ---------------------------
# Vector store loader
# ---------------------------
def load_vector_store_by_prefix(query: str):
    """Select a vector store based on an optional ``<prefix>:`` in the query.

    Returns ``(vector_store, cleaned_query, store_key)``. Prefix matching is
    case-insensitive; when no known prefix is present the "default" store and
    the untouched query are returned.

    Bug fix: the prefixed branch previously returned the query fully
    lower-cased (the default branch did not) — original casing is now
    preserved after stripping the prefix.
    """
    stripped = query.strip()
    lowered = stripped.lower()
    for prefix, path in VECTOR_STORE_PATHS.items():
        if prefix != "default" and lowered.startswith(f"{prefix}:"):
            # Slice the original-cased text; lower() preserves offsets here.
            cleaned_query = stripped[len(prefix) + 1:].strip()
            return get_vector_store(persist_directory=path), cleaned_query, prefix
    return get_vector_store(persist_directory=VECTOR_STORE_PATHS["default"]), query, "default"
def generate_answer_with_gemini(query: str, context_docs: list):
    """Call the Gemini REST API with retrieved docs as grounding context.

    Args:
        query: the user's (prefix-stripped) question.
        context_docs: retrieved documents exposing ``page_content``.

    Returns the generated answer text, or a human-readable error string on
    any failure — this function never raises, callers always get a str.
    """
    # Build the knowledge-base string from ALL retrieved documents, untruncated.
    knowledge_parts = [
        f"Data Source {i}: {doc.page_content.strip()}"
        for i, doc in enumerate(context_docs, 1)
    ]
    knowledge_base = "\n\n".join(knowledge_parts)

    # The updated prompt is more direct and forceful
    prompt = (
        "You are an expert AI assistant that uses a provided knowledge base to answer questions. "
        "Your responses must always be based on this knowledge base, which is the ultimate source of truth. "
        "You will only use your internal knowledge to supplement the answer, never to contradict it. "
        "If and only if the knowledge base contains absolutely nothing relevant to the user's question, "
        "you will respond with a polite and concise statement saying you cannot answer the question from the information you have. "
        "You must never answer 'I don't know' if there is any information in the knowledge base that is even tangentially related to the question. "
        "Always try your best to construct a useful answer by synthesizing the provided information. "
        "Do not refer to the 'knowledge base' or 'sources' directly; instead, use phrases like 'based on the information I have'.\n\n"
        f"My knowledge base:\n{knowledge_base}\n\n"
        f"User's Question: {query}\n\nAnswer:"
    )

    # print the prompt for debugging
    print("π Prompt sent to Gemini API:", prompt)

    try:
        response = requests.post(
            f"{GEMINI_API_URL}?key={GEMINI_API_KEY}",
            json={
                "contents": [
                    {"role": "user", "parts": [{"text": prompt}]}
                ],
                "generationConfig": {
                    "temperature": 0.7,
                    "maxOutputTokens": 300,
                },
            },
            timeout=300,
        )
        if response.status_code != 200:
            return f"API Error: {response.status_code} - {response.text}"
        data = response.json()
        # Defensively walk the response shape. Previously an empty
        # "candidates"/"parts" list raised IndexError and surfaced as
        # "Error: list index out of range"; now it falls back cleanly.
        candidates = data.get("candidates") or [{}]
        content = candidates[0].get("content") or {}
        parts = content.get("parts") or [{}]
        text = (parts[0].get("text") or "").strip()
        return text or "I couldn't generate an answer."
    except Exception as e:
        # Broad catch is deliberate: callers expect a string, never an exception.
        return f"Error: {str(e)}"
# Middleware for logging requests.
# NOTE(review): this is never registered (no @app.middleware("http")
# decorator) — confirm whether it was intentionally disabled.
async def log_requests(request: Request, call_next):
    """Print request method/URL/headers, forward it, then print the status."""
    origin = request.headers.get('origin', 'No Origin')
    user_agent = request.headers.get('user-agent', 'No User-Agent')
    print(f"π Request: {request.method} {request.url}")
    print(f"π Headers: {dict(request.headers)}")
    print(f"π Origin: {origin}")
    print(f"π User-Agent: {user_agent}")
    response = await call_next(request)
    print(f"π Response Status: {response.status_code}")
    return response
# ---------------------------
# API Endpoints
# ---------------------------
# NOTE(review): no @app.get("/") decorator here — confirm whether this
# health endpoint is meant to be registered.
def root():
    """Return a static health/config payload for the service."""
    info = {"status": "running"}
    info["model"] = "gemini-2.0-flash"
    info["using_direct_api"] = True
    info["client_ready"] = True
    return info
# @app.post("/ask")
# NOTE(review): the route decorator is commented out, so this handler is not
# registered — confirm whether registration happens elsewhere.
async def ask_question(request: Request):
    """Answer a query via RAG: pick a store, retrieve docs, call Gemini.

    Reads the raw body, validates it as QueryRequest, retrieves up to 5
    deduplicated documents via MMR search, and returns the mapped chat
    response. Raises HTTPException(500) on failure.
    """
    try:
        # Print raw incoming request body for debugging.
        raw_body = await request.body()
        print("π₯ Incoming POST /ask request body:")
        print(raw_body.decode("utf-8"))

        parsed_request = QueryRequest.model_validate_json(raw_body)
        print("β Parsed request object:", parsed_request)

        vector_store, cleaned_query, store_key = load_vector_store_by_prefix(
            parsed_request.query
        )
        if not vector_store:
            raise HTTPException(
                status_code=500, detail="Vector store not ready"
            )

        # MMR search trades pure relevance for diversity among the 20 fetched.
        retriever = vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": 6,
                "fetch_k": 20,
                "lambda_mult": 0.5,
            },
        )
        docs = retriever.get_relevant_documents(cleaned_query)

        # Drop exact-duplicate chunks (first occurrence wins), then cap at 5.
        seen = set()
        unique_docs = []
        for doc in docs:
            snippet = doc.page_content.strip()
            if snippet not in seen:
                seen.add(snippet)
                unique_docs.append(doc)
        docs = unique_docs[:5]

        if not docs:
            return {
                "answer": "I couldn't find any relevant information in the knowledge base to answer your question.",
                "model_used": "gemini-2.0-flash",
                "vector_store_used": VECTOR_STORE_PATHS[store_key],
                "sources": []
            }

        answer = generate_answer_with_gemini(cleaned_query, docs)
        answer_obj = {
            "answer": answer,
            "model_used": "gemini-2.0-flash",
            "vector_store_used": VECTOR_STORE_PATHS[store_key],
            "sources": [
                {
                    "content": doc.page_content[:500] + "...\n",
                    "metadata": doc.metadata
                }
                for doc in docs
            ]
        }
        return map_answer_to_chat_response(answer_obj)
    except HTTPException:
        # Bug fix: a deliberately-raised HTTPException (e.g. "Vector store
        # not ready") was previously swallowed by the generic handler below
        # and re-wrapped as detail "Error: 500: ...". Let it propagate as-is.
        raise
    except Exception as e:
        print(f"Error in ask_question: {e}")
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
if __name__ == "__main__":
    import uvicorn

    # Honor the platform-assigned PORT (e.g. on hosted deploys); 8000 locally.
    serve_port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)