Spaces:
Sleeping
Sleeping
Said Lfagrouche
Fix model artifacts copying in Dockerfile and add fallback functionality for missing models
c13c6ef
# api_mental_health.py | |
from fastapi import FastAPI, HTTPException, UploadFile, File, Form | |
from pydantic import BaseModel | |
import pandas as pd | |
import numpy as np | |
import joblib | |
import re | |
import nltk | |
from nltk.tokenize import word_tokenize | |
from nltk.stem import WordNetLemmatizer | |
from nltk.corpus import stopwords | |
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer | |
import chromadb | |
from chromadb.config import Settings | |
from langchain_openai import OpenAIEmbeddings, ChatOpenAI | |
from langchain_chroma import Chroma | |
from openai import OpenAI | |
import os | |
from dotenv import load_dotenv | |
from langsmith import Client, traceable | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_core.prompts import ChatPromptTemplate | |
import logging | |
from typing import List, Dict, Optional, Any, Union, Annotated | |
from datetime import datetime | |
from uuid import uuid4, UUID | |
import json | |
import requests | |
from fastapi.responses import StreamingResponse | |
from io import BytesIO | |
import base64 | |
# --- Logging -----------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Environment -------------------------------------------------------------
# Read key/value pairs from a local .env file into os.environ.
load_dotenv()

# --- NLTK --------------------------------------------------------------------
# Register a package-local directory as an extra NLTK search location.
# Nothing is downloaded here (avoids permission errors on read-only hosts).
nltk_data_path = os.path.join(os.path.dirname(__file__), "nltk_data")
os.makedirs(nltk_data_path, exist_ok=True)
nltk.data.path.append(nltk_data_path)
logger.info(f"Using NLTK data from {nltk_data_path} if available")

# --- Application -------------------------------------------------------------
app = FastAPI(title="Mental Health Counselor API")

# --- File-backed storage (to be replaced with proper database) ---------------
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
os.makedirs(DATA_DIR, exist_ok=True)
for _subdir in ("users", "sessions", "conversations", "feedback"):
    os.makedirs(os.path.join(DATA_DIR, _subdir), exist_ok=True)

# --- Text-processing components ----------------------------------------------
STOPWORDS = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()
analyzer = SentimentIntensityAnalyzer()

# Directory (relative to the working directory) holding the model artifacts.
output_dir = "mental_health_model_artifacts"

# --- Lazily populated globals (filled in by startup_event) -------------------
response_clf = crisis_clf = vectorizer = le = selector = lda = None
vector_store = None
llm = None
openai_client = None
langsmith_client = None
# Load models and initialize ChromaDB at startup | |
async def startup_event():
    """Load model artifacts and initialise external clients.

    Populates the module-level globals ``response_clf``, ``crisis_clf``,
    ``vectorizer``, ``le``, ``selector``, ``lda``, ``vector_store``, ``llm``,
    ``openai_client``, ``langsmith_client`` and ``models_available``.

    Every component is optional: when something cannot be loaded the
    corresponding global is left as ``None`` so that callers can detect the
    situation with an ``is None`` check instead of hitting an unfitted
    estimator at request time.

    NOTE(review): no registration of this coroutine with ``app`` is visible
    in this part of the file; confirm it is hooked into application startup.
    """
    # A single global declaration up front: re-declaring a name as global
    # after it has been assigned in the function body is a SyntaxError.
    global response_clf, crisis_clf, vectorizer, le, selector, lda
    global vector_store, llm, openai_client, langsmith_client, models_available

    # Check environment variables
    openai_api_key = os.environ.get("OPENAI_API_KEY")
    if not openai_api_key:
        logger.warning("OPENAI_API_KEY not set in .env file. Some functionality will be limited.")
    if not os.environ.get("LANGCHAIN_API_KEY"):
        logger.warning("LANGCHAIN_API_KEY not set in .env file. Some functionality will be limited.")
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "MentalHealthCounselorPOC"

    # Initialize LangSmith client if possible
    try:
        logger.info("Initializing LangSmith client")
        langsmith_client = Client()
    except Exception as e:
        logger.warning(f"Failed to initialize LangSmith client: {e}")
        langsmith_client = None

    # Try to load saved components, continue with limited functionality if not available
    logger.info("Loading model artifacts")
    models_available = True
    try:
        response_clf = joblib.load(f"{output_dir}/response_type_classifier.pkl")
        crisis_clf = joblib.load(f"{output_dir}/crisis_classifier.pkl")
        vectorizer = joblib.load(f"{output_dir}/tfidf_vectorizer.pkl")
        le = joblib.load(f"{output_dir}/label_encoder.pkl")
        selector = joblib.load(f"{output_dir}/feature_selector.pkl")
        try:
            lda = joblib.load(f"{output_dir}/lda_model.pkl")
        except Exception as lda_error:
            # An untrained placeholder would raise on the first transform(),
            # so mark the model as absent and let callers fall back instead.
            logger.warning(
                f"Failed to load LDA model: {lda_error}. "
                "Predictions will use the fallback path until it is restored."
            )
            lda = None
            models_available = False
    except FileNotFoundError as e:
        logger.warning(f"Missing model artifact: {e}. Running with limited functionality.")
        models_available = False
        # Discard any partially loaded set; None is the "not available" signal.
        response_clf = crisis_clf = vectorizer = le = selector = lda = None

    # Initialize ChromaDB if possible
    vector_store = None
    chroma_db_path = f"{output_dir}/chroma_db"
    if not os.path.exists(chroma_db_path):
        logger.warning(f"ChromaDB not found at {chroma_db_path}. Vector search will be unavailable.")
    elif not openai_api_key:
        # Embeddings are computed through OpenAI, so the store is unusable without a key.
        logger.warning("Skipping ChromaDB initialization as OPENAI_API_KEY is not set")
    else:
        try:
            logger.info("Initializing ChromaDB")
            chroma_client = chromadb.PersistentClient(
                path=chroma_db_path,
                settings=Settings(anonymized_telemetry=False),
            )
            embeddings = OpenAIEmbeddings(
                model="text-embedding-ada-002",
                api_key=openai_api_key,
                disallowed_special=(),
                chunk_size=1000,
            )
            try:
                vector_store = Chroma(
                    client=chroma_client,
                    collection_name="mental_health_conversations",
                    embedding_function=embeddings,
                )
            except Exception as chroma_error:
                logger.warning(f"Error initializing Chroma collection: {chroma_error}")
                vector_store = None
        except Exception as e:
            logger.warning(f"Error initializing ChromaDB: {e}")
            vector_store = None

    # Initialize OpenAI client and LLM if API key is available
    logger.info("Initializing OpenAI client and LLM")
    openai_client = None
    llm = None
    if openai_api_key:
        try:
            openai_client = OpenAI(api_key=openai_api_key)
            llm = ChatOpenAI(
                model="gpt-4o-mini",
                temperature=0.7,
                api_key=openai_api_key,
            )
        except Exception as e:
            logger.warning(f"Error initializing OpenAI client: {e}")
            openai_client = None
            llm = None
    else:
        logger.warning("OpenAI client not initialized as OPENAI_API_KEY is not set")
# Add route to check model availability | |
async def model_status():
    """Report which optional components are currently usable.

    Returns:
        dict with boolean flags for the classical models, the vector store,
        the LLM, and whether the two API-key environment variables are set.
    """
    return {
        # ``models_available`` only exists once startup has assigned it as a
        # module global; report False rather than raising NameError before then.
        "models_available": bool(globals().get("models_available", False)),
        "vector_store_available": vector_store is not None,
        "llm_available": llm is not None,
        "openai_api_key_set": os.environ.get("OPENAI_API_KEY") is not None,
        "langchain_api_key_set": os.environ.get("LANGCHAIN_API_KEY") is not None,
    }
# Pydantic model for request | |
class PatientContext(BaseModel):
    # Request body for the suggestion endpoint.
    # Free-text description of the patient's situation.
    context: str
# New Pydantic models for expanded API functionality | |
class UserProfile(BaseModel):
    # Counselor profile as stored under DATA_DIR/users.

    # Assigned by save_user_profile when not supplied.
    user_id: Optional[str] = None
    username: str
    name: str
    role: str = "counselor"
    specializations: List[str] = []
    years_experience: Optional[int] = None
    # Copied into a session's crisis keywords when the counselor opens one.
    custom_crisis_keywords: List[str] = []
    preferences: Dict[str, Any] = {}
    # Timestamps are filled in by save_user_profile.
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
class SessionData(BaseModel):
    # Counseling session record as stored under DATA_DIR/sessions.

    # Assigned by save_session when not supplied.
    session_id: Optional[str] = None
    counselor_id: str
    patient_identifier: str  # Anonymized ID
    session_notes: str = ""
    session_preferences: Dict[str, Any] = {}
    # Extra keywords that raise the crisis flag for this session.
    crisis_keywords: List[str] = []
    # Timestamps are filled in by save_session.
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
class ConversationEntry(BaseModel):
    # One message in a session's conversation history, optionally annotated
    # with the suggestion and classification produced for it.

    session_id: str
    message: str
    sender: str  # 'patient' or 'counselor'
    timestamp: Optional[datetime] = None
    suggested_response: Optional[str] = None
    response_type: Optional[str] = None
    crisis_flag: bool = False
    risk_level: Optional[str] = None
class FeedbackData(BaseModel):
    # Counselor feedback on a previously generated suggestion.

    suggestion_id: str
    counselor_id: str
    # 1-5 scale. NOTE(review): the range is not enforced by the model.
    rating: int
    was_effective: bool
    comments: Optional[str] = None
class AnalysisRequest(BaseModel):
    # Request body for the tone / sensitivity analysis endpoint.

    # Text to be analysed.
    text: str
    patient_background: Optional[Dict[str, Any]] = None
    # Forwarded to the age-appropriateness analysis.
    patient_age: Optional[int] = None
    # Forwarded to the cultural-sensitivity analysis.
    cultural_context: Optional[str] = None
class MultiModalInput(BaseModel):
    # Input payload that may reference text, audio or video content.

    session_id: str
    counselor_id: str
    input_type: str  # 'text', 'audio', 'video'
    content: str  # Text content or file path/url
    metadata: Dict[str, Any] = {}
class InterventionRequest(BaseModel):
    # Request body for the guided intervention workflow endpoint.

    patient_issue: str
    patient_background: Optional[Dict[str, Any]] = None
    intervention_type: Optional[str] = None  # e.g., 'CBT', 'DBT', 'mindfulness'
# Text preprocessing function | |
def clean_text(text):
    """Normalise raw text into a space-joined string of lemmatised tokens.

    Missing values (as judged by ``pd.isna``) yield an empty string. Otherwise
    the text is lower-cased, every character other than a letter or an
    apostrophe becomes a space, and the result is tokenised. Stopwords and
    tokens of two characters or fewer are dropped; the rest are lemmatised.
    """
    if pd.isna(text):
        return ""

    letters_only = re.sub(r"[^a-zA-Z']", " ", str(text).lower())

    kept = []
    for token in word_tokenize(letters_only):
        if len(token) <= 2 or token in STOPWORDS:
            continue
        kept.append(lemmatizer.lemmatize(token))
    return " ".join(kept)
# Feature engineering function | |
def engineer_features(context, response=""):
    """Build the single-row feature frame used by the classifiers.

    Args:
        context: Raw patient text.
        response: Accepted for signature compatibility; not used here. The
            response TF-IDF block is always filled with zeros.

    Returns:
        ``(features, feature_cols)`` where ``features`` is a one-row DataFrame
        and ``feature_cols`` lists the column order handed to the selector.
    """
    topic_names = [f"topic_{i}" for i in range(10)]
    base_names = ["context_len", "context_vader", "context_questions"]

    # Scalar descriptors of the message.
    cleaned = clean_text(context)
    n_tokens = len(cleaned.split())
    sentiment = analyzer.polarity_scores(context)['compound']
    n_questions = context.count('?')

    # Keyword screen: substring match against the lower-cased raw text.
    crisis_keywords = ['suicide', 'hopeless', 'worthless', 'kill', 'harm', 'desperate', 'overwhelmed', 'alone']
    lowered = context.lower()
    keyword_hits = sum(1 for kw in crisis_keywords if kw in lowered)

    # Vectorised representations.
    context_tfidf = vectorizer.transform([cleaned]).toarray()
    response_tfidf = np.zeros_like(context_tfidf)
    topic_weights = lda.transform(context_tfidf)

    n_terms = context_tfidf.shape[1]
    ctx_names = [f"tfidf_context_{i}" for i in range(n_terms)]
    resp_names = [f"tfidf_response_{i}" for i in range(response_tfidf.shape[1])]
    feature_cols = base_names + ["crisis_flag"] + topic_names + ctx_names + resp_names

    # Assemble the row in the same column order as before.
    row = {
        "context_len": [n_tokens],
        "context_vader": [sentiment],
        "context_questions": [n_questions],
    }
    for i, name in enumerate(topic_names):
        row[name] = [topic_weights[0][i]]
    for i, name in enumerate(ctx_names):
        row[name] = [context_tfidf[0][i]]
    for i, name in enumerate(resp_names):
        row[name] = [response_tfidf[0][i]]
    features = pd.DataFrame(row)

    # The crisis classifier sees only the scalar and topic features;
    # any keyword hit overrides its verdict.
    crisis_flag = crisis_clf.predict(features[base_names + topic_names])[0]
    if keyword_hits > 0:
        crisis_flag = 1
    features["crisis_flag"] = crisis_flag
    return features, feature_cols
# Prediction function | |
def predict_response_type(context):
    """Predict the counselor response type and crisis flag for ``context``.

    When any required model global is ``None`` a fixed placeholder prediction
    is returned with ``models_available`` set to ``False``. Otherwise the
    classifier's label is post-processed by a few keyword heuristics, applied
    in order so that later rules win.
    """
    required = (response_clf, vectorizer, le, selector, lda)
    if any(component is None for component in required):
        logger.warning("Models not available, returning dummy prediction")
        return {
            "response_type": "Empathetic Listening",
            "crisis_flag": False,
            "confidence": 0.5,
            "features": {},
            "models_available": False
        }

    features, feature_cols = engineer_features(context)
    selected = selector.transform(features[feature_cols])
    encoded = response_clf.predict(selected)[0]
    label = le.inverse_transform([encoded])[0]
    # Confidence is the classifier's top probability, before heuristic overrides.
    confidence = response_clf.predict_proba(selected)[0].max()

    lowered = context.lower()
    asks_for_help = any(kw in lowered for kw in ["how", "what", "help"])
    if "?" in context:
        label = "Question"
    if "trying" in lowered and "hard" in lowered and not asks_for_help:
        label = "Validation"
    if "trying" in lowered and "positive" in lowered and not asks_for_help:
        label = "Question"

    return {
        "response_type": label,
        "crisis_flag": bool(features["crisis_flag"].iloc[0]),
        "confidence": confidence,
        "features": features.to_dict(),
        "models_available": True
    }
# RAG suggestion function | |
def generate_suggestion_rag(context, response_type, crisis_flag):
    """Generate a counselor suggestion grounded in similar past conversations.

    Retrieves the three nearest stored conversations from the vector store,
    includes them in the prompt, and asks the LLM for a JSON reply.

    Args:
        context: Patient text.
        response_type: Predicted response type label.
        crisis_flag: Truthy when the situation was flagged as a crisis.

    Returns:
        dict parsed from the model's JSON reply (expected keys:
        ``suggested_response`` and ``risk_level``).

    Raises:
        HTTPException: 503 when the vector store or LLM is not initialised;
            500 when the LLM call or reply parsing fails.
    """
    if vector_store is None or llm is None:
        raise HTTPException(
            status_code=503,
            detail="RAG suggestions unavailable: vector store or LLM not initialized",
        )

    results = vector_store.similarity_search_with_score(context, k=3)
    retrieved_contexts = [
        f"Patient: {res[0].page_content}\nCounselor: {res[0].metadata['response']} (Type: {res[0].metadata['response_type']}, Crisis: {res[0].metadata['crisis_flag']}, Score: {res[1]:.2f})"
        for res in results
    ]
    prompt_template = ChatPromptTemplate.from_template(
        """
You are an expert mental health counseling assistant. A counselor has provided the following patient situation:
Patient Situation: {context}
Predicted Response Type: {response_type}
Crisis Flag: {crisis_flag}
Based on the predicted response type and crisis flag, provide a suggested response for the counselor to use with the patient. The response should align with the response type ({response_type}) and be sensitive to the crisis level.
For reference, here are similar cases from past conversations:
{retrieved_contexts}
Guidelines:
- If Crisis Flag is True, prioritize safety, empathy, and suggest immediate resources (e.g., National Suicide Prevention Lifeline at 988).
- For 'Empathetic Listening', focus on validating feelings without giving direct advice or questions.
- For 'Advice', provide practical, actionable suggestions.
- For 'Question', pose an open-ended question to encourage further discussion.
- For 'Validation', affirm the patient's efforts or feelings.
Output in the following format:
```json
{{
"suggested_response": "Your suggested response here",
"risk_level": "Low/Moderate/High"
}}
```
"""
    )
    rag_chain = (
        {
            "context": RunnablePassthrough(),
            "response_type": lambda x: response_type,
            "crisis_flag": lambda x: "Crisis" if crisis_flag else "No Crisis",
            "retrieved_contexts": lambda x: "\n".join(retrieved_contexts)
        }
        | prompt_template
        | llm
    )
    try:
        response = rag_chain.invoke(context)
        # SECURITY: the reply is derived from untrusted patient text, so it is
        # parsed as JSON rather than evaluated as Python. Take the fenced block
        # if there is one, otherwise treat the whole reply as JSON.
        raw = response.content.strip()
        fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL)
        return json.loads(fenced.group(1) if fenced else raw)
    except Exception as e:
        logger.error(f"Error generating RAG suggestion: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating RAG suggestion: {str(e)}") from e
# Direct suggestion function | |
def generate_suggestion_direct(context, response_type, crisis_flag):
    """Generate a counselor suggestion from the LLM alone (no retrieval).

    Args:
        context: Patient text.
        response_type: Predicted response type label.
        crisis_flag: Truthy when the situation was flagged as a crisis.

    Returns:
        dict parsed from the model's JSON reply (expected keys:
        ``suggested_response`` and ``risk_level``).

    Raises:
        HTTPException: 503 when the LLM is not initialised; 500 when the LLM
            call or reply parsing fails.
    """
    if llm is None:
        raise HTTPException(
            status_code=503,
            detail="Direct suggestions unavailable: LLM not initialized",
        )

    prompt_template = ChatPromptTemplate.from_template(
        """
You are an expert mental health counseling assistant. A counselor has provided the following patient situation:
Patient Situation: {context}
Predicted Response Type: {response_type}
Crisis Flag: {crisis_flag}
Provide a suggested response for the counselor to use with the patient, aligned with the response type ({response_type}) and sensitive to the crisis level.
Guidelines:
- If Crisis Flag is True, prioritize safety, empathy, and suggest immediate resources (e.g., National Suicide Prevention Lifeline at 988).
- For 'Empathetic Listening', focus on validating feelings without giving direct advice or questions.
- For 'Advice', provide practical, actionable suggestions.
- For 'Question', pose an open-ended question to encourage further discussion.
- For 'Validation', affirm the patient's efforts or feelings.
- Strictly adhere to the response type. For 'Empathetic Listening', do not include questions or advice.
Output in the following format:
```json
{{
"suggested_response": "Your suggested response here",
"risk_level": "Low/Moderate/High"
}}
```
"""
    )
    direct_chain = (
        {
            "context": RunnablePassthrough(),
            "response_type": lambda x: response_type,
            "crisis_flag": lambda x: "Crisis" if crisis_flag else "No Crisis"
        }
        | prompt_template
        | llm
    )
    try:
        response = direct_chain.invoke(context)
        # SECURITY: the reply is derived from untrusted patient text, so it is
        # parsed as JSON rather than evaluated as Python. Take the fenced block
        # if there is one, otherwise treat the whole reply as JSON.
        raw = response.content.strip()
        fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL)
        return json.loads(fenced.group(1) if fenced else raw)
    except Exception as e:
        logger.error(f"Error generating direct suggestion: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating direct suggestion: {str(e)}") from e
# User Profile Endpoints | |
async def create_user(profile: UserProfile):
    """Persist a new counselor profile and return the stored record.

    Any failure while saving is logged and reported as an HTTP 500.
    """
    try:
        stored = save_user_profile(profile)
        logger.info(f"Created user profile: {stored.user_id}")
        return stored
    except Exception as e:
        logger.error(f"Error creating user profile: {e}")
        raise HTTPException(status_code=500, detail=f"Error creating user profile: {str(e)}")
async def get_user(user_id: str):
    """Return the counselor profile for ``user_id`` or respond with HTTP 404."""
    found = get_user_profile(user_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail=f"User profile not found: {user_id}")
async def update_user(user_id: str, profile_update: UserProfile):
    """Overwrite an existing counselor profile.

    The stored ``user_id`` and ``created_at`` always win over whatever the
    request body carried. Responds with 404 when the profile does not exist
    and 500 when saving fails.
    """
    current = get_user_profile(user_id)
    if not current:
        raise HTTPException(status_code=404, detail=f"User profile not found: {user_id}")

    # Identity and creation time are immutable across updates.
    profile_update.user_id = user_id
    profile_update.created_at = current.created_at

    try:
        stored = save_user_profile(profile_update)
        logger.info(f"Updated user profile: {user_id}")
        return stored
    except Exception as e:
        logger.error(f"Error updating user profile: {e}")
        raise HTTPException(status_code=500, detail=f"Error updating user profile: {str(e)}")
# Session Management Endpoints | |
async def create_session(session_data: SessionData):
    """Open a new session for an existing counselor.

    The counselor's custom crisis keywords, if any, are appended to the
    session's own keyword list before saving. Responds with 404 when the
    counselor is unknown and 500 on any other failure.
    """
    try:
        owner = get_user_profile(session_data.counselor_id)
        if not owner:
            raise HTTPException(status_code=404, detail=f"Counselor not found: {session_data.counselor_id}")

        extra_keywords = owner.custom_crisis_keywords
        if extra_keywords:
            session_data.crisis_keywords.extend(extra_keywords)

        stored = save_session(session_data)
        logger.info(f"Created session: {stored.session_id}")
        return stored
    except HTTPException:
        # Let deliberate HTTP errors through untouched.
        raise
    except Exception as e:
        logger.error(f"Error creating session: {e}")
        raise HTTPException(status_code=500, detail=f"Error creating session: {str(e)}")
async def get_session_by_id(session_id: str):
    """Return the session for ``session_id`` or respond with HTTP 404."""
    found = get_session(session_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
async def get_counselor_sessions(counselor_id: str):
    """Return whatever ``get_user_sessions`` yields for ``counselor_id``."""
    return get_user_sessions(counselor_id)
async def update_session(session_id: str, session_update: SessionData):
    """Overwrite an existing session.

    The stored ``session_id`` and ``created_at`` are carried over regardless
    of the request body. Responds with 404 when the session does not exist
    and 500 when saving fails.
    """
    current = get_session(session_id)
    if not current:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

    # Identity and creation time are immutable across updates.
    session_update.session_id = session_id
    session_update.created_at = current.created_at

    try:
        stored = save_session(session_update)
        logger.info(f"Updated session: {session_id}")
        return stored
    except Exception as e:
        logger.error(f"Error updating session: {e}")
        raise HTTPException(status_code=500, detail=f"Error updating session: {str(e)}")
# Conversation History Endpoints | |
async def add_conversation_entry(entry: ConversationEntry):
    """Append ``entry`` to its session's conversation and return the new id.

    Responds with 404 when the referenced session does not exist and 500 on
    any other failure.
    """
    try:
        if not get_session(entry.session_id):
            raise HTTPException(status_code=404, detail=f"Session not found: {entry.session_id}")

        new_id = save_conversation_entry(entry)
        logger.info(f"Added conversation entry: {new_id}")
        return new_id
    except HTTPException:
        # Let deliberate HTTP errors through untouched.
        raise
    except Exception as e:
        logger.error(f"Error adding conversation entry: {e}")
        raise HTTPException(status_code=500, detail=f"Error adding conversation entry: {str(e)}")
async def get_conversation(session_id: str):
    """Return the stored conversation history for ``session_id``.

    Responds with 404 when the session does not exist and 500 on any other
    failure.
    """
    try:
        if not get_session(session_id):
            raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
        return get_conversation_history(session_id)
    except HTTPException:
        # Let deliberate HTTP errors through untouched.
        raise
    except Exception as e:
        logger.error(f"Error retrieving conversation history: {e}")
        raise HTTPException(status_code=500, detail=f"Error retrieving conversation history: {str(e)}")
# API Endpoints | |
async def get_suggestion(context: PatientContext):
    """Classify the patient text and return both RAG and direct suggestions."""
    patient_text = context.context
    logger.info(f"Received suggestion request for context: {patient_text}")

    prediction = predict_response_type(patient_text)
    label = prediction["response_type"]
    in_crisis = prediction["crisis_flag"]

    from_rag = generate_suggestion_rag(patient_text, label, in_crisis)
    from_llm = generate_suggestion_direct(patient_text, label, in_crisis)

    return {
        "context": patient_text,
        "response_type": label,
        "crisis_flag": in_crisis,
        "confidence": prediction["confidence"],
        "rag_suggestion": from_rag["suggested_response"],
        "rag_risk_level": from_rag["risk_level"],
        "direct_suggestion": from_llm["suggested_response"],
        "direct_risk_level": from_llm["risk_level"]
    }
async def get_session_suggestion(request: dict):
    """Produce suggestions for a message inside an existing session.

    On top of the regular prediction, the session's own crisis keywords are
    matched (case-insensitive substring) against the message and can raise
    the crisis flag. The message and the RAG suggestion are then stored as a
    patient conversation entry.

    Responds with 400 when ``session_id`` or ``context`` is missing, 404 when
    the session is unknown, and 500 on any other failure.
    """
    try:
        session_id = request.get("session_id")
        if not session_id:
            raise HTTPException(status_code=400, detail="session_id is required")
        context = request.get("context")
        if not context:
            raise HTTPException(status_code=400, detail="context is required")

        session = get_session(session_id)
        if not session:
            raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

        # Loaded but not consumed below; the call is kept as-is.
        conversation_history = get_conversation_history(session_id)

        prediction = predict_response_type(context)
        label = prediction["response_type"]
        crisis_flag = prediction["crisis_flag"]

        # Session-specific keywords can only raise the flag, never clear it.
        if not crisis_flag and session.crisis_keywords:
            lowered = context.lower()
            hit = next((kw for kw in session.crisis_keywords if kw.lower() in lowered), None)
            if hit is not None:
                crisis_flag = True
                logger.info(f"Crisis flag triggered by custom keyword: {hit}")

        from_rag = generate_suggestion_rag(context, label, crisis_flag)
        from_llm = generate_suggestion_direct(context, label, crisis_flag)

        response = {
            "context": context,
            "response_type": label,
            "crisis_flag": crisis_flag,
            "confidence": prediction["confidence"],
            "rag_suggestion": from_rag["suggested_response"],
            "rag_risk_level": from_rag["risk_level"],
            "direct_suggestion": from_llm["suggested_response"],
            "direct_risk_level": from_llm["risk_level"],
            "session_id": session_id
        }

        # Record the exchange using the RAG suggestion as the stored reply.
        save_conversation_entry(
            ConversationEntry(
                session_id=session_id,
                message=context,
                sender="patient",
                suggested_response=from_rag["suggested_response"],
                response_type=label,
                crisis_flag=crisis_flag,
                risk_level=from_rag["risk_level"]
            )
        )
        return response
    except HTTPException:
        # Let deliberate HTTP errors through untouched.
        raise
    except Exception as e:
        logger.error(f"Error getting session suggestion: {e}")
        raise HTTPException(status_code=500, detail=f"Error getting session suggestion: {str(e)}")
# Feedback Endpoints | |
async def add_feedback(feedback: FeedbackData):
    """Store counselor feedback on a suggestion and return its new id.

    Any failure while saving is logged and reported as an HTTP 500.
    """
    try:
        new_id = save_feedback(feedback)
        logger.info(f"Added feedback: {new_id}")
        return {"feedback_id": new_id}
    except Exception as e:
        logger.error(f"Error adding feedback: {e}")
        raise HTTPException(status_code=500, detail=f"Error adding feedback: {str(e)}")
# Tone & Cultural Sensitivity Analysis | |
def analyze_cultural_sensitivity(text: str, cultural_context: Optional[str] = None):
    """Analyze text for cultural appropriateness and sensitivity.

    Args:
        text: Text to analyse.
        cultural_context: Optional description of the cultural setting;
            "General" is used when empty or omitted.

    Returns:
        dict parsed from the model's JSON reply.

    Raises:
        HTTPException: 503 when the LLM is not initialised; 500 when the LLM
            call or reply parsing fails.
    """
    if llm is None:
        raise HTTPException(
            status_code=503,
            detail="Cultural sensitivity analysis unavailable: LLM not initialized",
        )

    prompt_template = ChatPromptTemplate.from_template(
        """
You are a cultural sensitivity expert. Analyze the following text for cultural appropriateness:
Text: {text}
Cultural Context: {cultural_context}
Provide an analysis of:
1. Cultural appropriateness
2. Potential bias or insensitivity
3. Suggestions for improvement
Output in the following format:
```json
{{
"cultural_appropriateness_score": 0-10,
"issues_detected": ["issue1", "issue2"],
"suggestions": ["suggestion1", "suggestion2"],
"explanation": "Brief explanation of analysis"
}}
```
"""
    )
    analysis_chain = (
        {
            "text": RunnablePassthrough(),
            "cultural_context": lambda x: cultural_context if cultural_context else "General"
        }
        | prompt_template
        | llm
    )
    try:
        response = analysis_chain.invoke(text)
        # SECURITY: the reply is derived from caller-supplied text, so it is
        # parsed as JSON rather than evaluated as Python. Take the fenced block
        # if there is one, otherwise treat the whole reply as JSON.
        raw = response.content.strip()
        fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL)
        return json.loads(fenced.group(1) if fenced else raw)
    except Exception as e:
        logger.error(f"Error analyzing cultural sensitivity: {e}")
        raise HTTPException(status_code=500, detail=f"Error analyzing cultural sensitivity: {str(e)}") from e
def analyze_age_appropriateness(text: str, age: Optional[int] = None):
    """Analyze text for age-appropriate language.

    Args:
        text: Text to analyse.
        age: Target age; "Adult" is used when omitted.

    Returns:
        dict parsed from the model's JSON reply.

    Raises:
        HTTPException: 503 when the LLM is not initialised; 500 when the LLM
            call or reply parsing fails.
    """
    if llm is None:
        raise HTTPException(
            status_code=503,
            detail="Age appropriateness analysis unavailable: LLM not initialized",
        )

    prompt_template = ChatPromptTemplate.from_template(
        """
You are an expert in age-appropriate communication. Analyze the following text for age appropriateness:
Text: {text}
Target Age: {age}
Provide an analysis of:
1. Age appropriateness
2. Complexity level
3. Suggestions for improvement
Output in the following format:
```json
{{
"age_appropriateness_score": 0-10,
"complexity_level": "Simple/Moderate/Complex",
"issues_detected": ["issue1", "issue2"],
"suggestions": ["suggestion1", "suggestion2"],
"explanation": "Brief explanation of analysis"
}}
```
"""
    )
    analysis_chain = (
        {
            "text": RunnablePassthrough(),
            # Compare against None so that an explicit age of 0 is not
            # silently replaced by "Adult".
            "age": lambda x: str(age) if age is not None else "Adult"
        }
        | prompt_template
        | llm
    )
    try:
        response = analysis_chain.invoke(text)
        # SECURITY: the reply is derived from caller-supplied text, so it is
        # parsed as JSON rather than evaluated as Python. Take the fenced block
        # if there is one, otherwise treat the whole reply as JSON.
        raw = response.content.strip()
        fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL)
        return json.loads(fenced.group(1) if fenced else raw)
    except Exception as e:
        logger.error(f"Error analyzing age appropriateness: {e}")
        raise HTTPException(status_code=500, detail=f"Error analyzing age appropriateness: {str(e)}") from e
async def analyze_text_sensitivity(request: AnalysisRequest):
    """Analyze text for cultural sensitivity and age appropriateness.

    Returns:
        dict with the original ``text`` plus the ``cultural_analysis`` and
        ``age_analysis`` results.

    Raises:
        HTTPException: whatever the underlying analyses raise, unchanged;
            500 for any other failure.
    """
    try:
        cultural_analysis = analyze_cultural_sensitivity(request.text, request.cultural_context)
        age_analysis = analyze_age_appropriateness(request.text, request.patient_age)
        return {
            "text": request.text,
            "cultural_analysis": cultural_analysis,
            "age_analysis": age_analysis
        }
    except HTTPException:
        # The helpers already produce a well-formed HTTP error; wrapping it
        # again would hide its status code and duplicate the message.
        raise
    except Exception as e:
        logger.error(f"Error analyzing text sensitivity: {e}")
        raise HTTPException(status_code=500, detail=f"Error analyzing text sensitivity: {str(e)}") from e
# Guided Intervention Workflows | |
def generate_intervention_workflow(issue: str, intervention_type: Optional[str] = None, background: Optional[Dict] = None):
    """Generate a structured intervention workflow for a specific issue.

    Args:
        issue: Description of the patient's issue.
        intervention_type: Preferred modality; "Best fit" is used when empty
            or omitted.
        background: Optional patient background, passed to the prompt as its
            ``str()`` form; "Not provided" is used when empty or omitted.

    Returns:
        dict parsed from the model's JSON reply.

    Raises:
        HTTPException: 503 when the LLM is not initialised; 500 when the LLM
            call or reply parsing fails.
    """
    if llm is None:
        raise HTTPException(
            status_code=503,
            detail="Intervention workflows unavailable: LLM not initialized",
        )

    prompt_template = ChatPromptTemplate.from_template(
        """
You are an expert mental health counselor. Generate a structured intervention workflow for the following patient issue:
Patient Issue: {issue}
Intervention Type: {intervention_type}
Patient Background: {background}
Provide a step-by-step intervention plan based on evidence-based practices. Include:
1. Initial assessment questions
2. Specific techniques to apply
3. Homework or practice exercises
4. Follow-up guidance
Output in the following format:
```json
{{
"intervention_type": "CBT/DBT/ACT/Mindfulness/etc.",
"assessment_questions": ["question1", "question2", "question3"],
"techniques": [
{{
"name": "technique name",
"description": "brief description",
"instructions": "step-by-step instructions"
}}
],
"exercises": [
{{
"name": "exercise name",
"description": "brief description",
"instructions": "step-by-step instructions"
}}
],
"follow_up": ["follow-up step 1", "follow-up step 2"],
"resources": ["resource1", "resource2"]
}}
```
"""
    )
    intervention_chain = (
        {
            "issue": RunnablePassthrough(),
            "intervention_type": lambda x: intervention_type if intervention_type else "Best fit",
            "background": lambda x: str(background) if background else "Not provided"
        }
        | prompt_template
        | llm
    )
    try:
        response = intervention_chain.invoke(issue)
        # SECURITY: the reply is derived from caller-supplied text, so it is
        # parsed as JSON rather than evaluated as Python. Take the fenced block
        # if there is one, otherwise treat the whole reply as JSON.
        raw = response.content.strip()
        fenced = re.search(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL)
        return json.loads(fenced.group(1) if fenced else raw)
    except Exception as e:
        logger.error(f"Error generating intervention workflow: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating intervention workflow: {str(e)}") from e
async def get_intervention_workflow(request: InterventionRequest):
    """Get a structured intervention workflow for a specific patient issue.

    Returns:
        dict with the original ``patient_issue`` and the generated
        ``intervention``.

    Raises:
        HTTPException: whatever the generator raises, unchanged; 500 for any
            other failure.
    """
    try:
        intervention = generate_intervention_workflow(
            request.patient_issue,
            request.intervention_type,
            request.patient_background
        )
        return {
            "patient_issue": request.patient_issue,
            "intervention": intervention
        }
    except HTTPException:
        # The generator already produces a well-formed HTTP error; wrapping
        # it again would hide its status code and duplicate the message.
        raise
    except Exception as e:
        logger.error(f"Error generating intervention workflow: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating intervention workflow: {str(e)}") from e
async def health_check():
    """Report whether every model, the vector store and the LLM are loaded.

    Returns:
        A "healthy" status dict when all components are present.

    Raises:
        HTTPException: 500 when at least one component is ``None``.
    """
    components = {
        "response_clf": response_clf,
        "crisis_clf": crisis_clf,
        "vectorizer": vectorizer,
        "le": le,
        "selector": selector,
        "lda": lda,
        "vector_store": vector_store,
        "llm": llm,
    }
    # Test for presence, not truthiness: a loaded object may define
    # __len__/__bool__ and evaluate as false (or raise) when truth-tested.
    missing = [name for name, component in components.items() if component is None]
    if not missing:
        return {"status": "healthy", "message": "All models and vector store loaded successfully"}
    logger.error("Health check failed: One or more components not loaded (%s)", ", ".join(missing))
    raise HTTPException(status_code=500, detail="One or more components failed to load")
async def get_metadata():
    """Return the name and document count of the conversations collection.

    Any failure (including an uninitialised vector store) is logged and
    reported as an HTTP 500.
    """
    try:
        # Reaches into the vector store's underlying client object.
        conversations = vector_store._client.get_collection("mental_health_conversations")
        n_documents = conversations.count()
        return {"collection_name": "mental_health_conversations", "document_count": n_documents}
    except Exception as e:
        logger.error(f"Error retrieving metadata: {e}")
        raise HTTPException(status_code=500, detail=f"Error retrieving metadata: {str(e)}")
# Database utility functions | |
def save_user_profile(profile: UserProfile):
    """Persist a user profile as ``<DATA_DIR>/users/<user_id>.json``.

    Assigns a fresh UUID when ``profile.user_id`` is empty, stamps
    ``created_at`` on first save and always refreshes ``updated_at``.

    Returns:
        The same ``profile`` object, with any generated fields filled in.
    """
    if not profile.user_id:
        profile.user_id = str(uuid4())
    if not profile.created_at:
        profile.created_at = datetime.now()
    profile.updated_at = datetime.now()

    # Serialize completely BEFORE opening the file: opening with "w" truncates
    # immediately, so a serialization error would otherwise wipe an existing
    # profile and leave an empty, unreadable JSON file behind.
    profile_dict = profile.dict()
    for key in ("created_at", "updated_at"):
        if profile_dict[key]:
            # datetime is not JSON serializable; store it as an ISO string.
            profile_dict[key] = profile_dict[key].isoformat()
    payload = json.dumps(profile_dict, indent=2)

    users_dir = os.path.join(DATA_DIR, "users")
    os.makedirs(users_dir, exist_ok=True)
    # NOTE(review): user_id is used verbatim as a file name; confirm callers
    # never pass client-controlled values containing path separators.
    with open(os.path.join(users_dir, f"{profile.user_id}.json"), "w") as f:
        f.write(payload)
    return profile
def get_user_profile(user_id: str) -> Optional[UserProfile]:
    """Load the stored profile for ``user_id``.

    Reads ``<DATA_DIR>/users/<user_id>.json`` and turns the ISO-formatted
    ``created_at`` / ``updated_at`` strings back into ``datetime`` objects.

    Returns:
        The reconstructed ``UserProfile``, or ``None`` when no file exists
        for the id.
    """
    profile_path = os.path.join(DATA_DIR, "users", f"{user_id}.json")
    if os.path.exists(profile_path):
        with open(profile_path, "r") as handle:
            record = json.load(handle)
        # Timestamps are stored as ISO strings; restore them when present.
        for field in ("created_at", "updated_at"):
            stored_value = record[field]
            if stored_value:
                record[field] = datetime.fromisoformat(stored_value)
        return UserProfile(**record)
    return None
def save_session(session: SessionData):
    """Persist a session as ``<DATA_DIR>/sessions/<session_id>.json``.

    Assigns a fresh UUID when ``session.session_id`` is empty, stamps
    ``created_at`` on first save and always refreshes ``updated_at``.

    Returns:
        The same ``session`` object, with any generated fields filled in.
    """
    if not session.session_id:
        session.session_id = str(uuid4())
    if not session.created_at:
        session.created_at = datetime.now()
    session.updated_at = datetime.now()

    # Build the JSON text before touching the file: "w" truncates on open, so
    # failing mid-serialization would otherwise destroy the previous contents.
    session_dict = session.dict()
    for key in ("created_at", "updated_at"):
        if session_dict[key]:
            # datetime is not JSON serializable; store it as an ISO string.
            session_dict[key] = session_dict[key].isoformat()
    payload = json.dumps(session_dict, indent=2)

    sessions_dir = os.path.join(DATA_DIR, "sessions")
    os.makedirs(sessions_dir, exist_ok=True)
    with open(os.path.join(sessions_dir, f"{session.session_id}.json"), "w") as f:
        f.write(payload)
    return session
def get_session(session_id: str) -> Optional[SessionData]:
    """Load the stored session identified by ``session_id``.

    Reads ``<DATA_DIR>/sessions/<session_id>.json`` and restores the
    ``created_at`` / ``updated_at`` ISO strings to ``datetime`` objects.

    Returns:
        The reconstructed ``SessionData``, or ``None`` when the file is absent.
    """
    session_path = os.path.join(DATA_DIR, "sessions", f"{session_id}.json")
    if not os.path.exists(session_path):
        return None

    with open(session_path, "r") as handle:
        stored = json.load(handle)

    # Undo the ISO-string encoding applied when the session was saved.
    for date_field in ("created_at", "updated_at"):
        iso_text = stored[date_field]
        if iso_text:
            stored[date_field] = datetime.fromisoformat(iso_text)
    return SessionData(**stored)
def get_user_sessions(counselor_id: str) -> List[SessionData]:
    """Return every stored session whose ``counselor_id`` matches.

    Scans all ``*.json`` files in ``<DATA_DIR>/sessions``; matching records
    get their ``created_at`` / ``updated_at`` ISO strings restored to
    ``datetime`` before being turned into ``SessionData``.

    Returns:
        The matching sessions in directory-listing order, or an empty list
        when the sessions directory does not exist yet.
    """
    sessions_dir = os.path.join(DATA_DIR, "sessions")
    # No directory simply means nothing has been saved yet; return an empty
    # result (as get_conversation_history does) instead of letting
    # os.listdir raise FileNotFoundError.
    if not os.path.isdir(sessions_dir):
        return []

    sessions = []
    for filename in os.listdir(sessions_dir):
        if not filename.endswith(".json"):
            continue
        with open(os.path.join(sessions_dir, filename), "r") as f:
            data = json.load(f)
        if data["counselor_id"] != counselor_id:
            continue
        for key in ("created_at", "updated_at"):
            if data[key]:
                data[key] = datetime.fromisoformat(data[key])
        sessions.append(SessionData(**data))
    return sessions
def save_conversation_entry(entry: ConversationEntry):
    """Persist one conversation entry under its session's directory.

    The entry is written to
    ``<DATA_DIR>/conversations/<session_id>/<entry_id>.json`` with a newly
    generated ``entry_id``. ``entry.timestamp`` is set to the current time
    when it is empty.

    Returns:
        The generated ``entry_id`` string.
    """
    if not entry.timestamp:
        entry.timestamp = datetime.now()
    entry_id = str(uuid4())

    # Serialize first so a failure cannot leave a truncated, unparseable file
    # that would later break get_conversation_history for the whole session.
    entry_dict = entry.dict()
    entry_dict["entry_id"] = entry_id
    if entry_dict["timestamp"]:
        # datetime is not JSON serializable; store it as an ISO string.
        entry_dict["timestamp"] = entry_dict["timestamp"].isoformat()
    payload = json.dumps(entry_dict, indent=2)

    conversation_dir = os.path.join(DATA_DIR, "conversations", entry.session_id)
    os.makedirs(conversation_dir, exist_ok=True)
    with open(os.path.join(conversation_dir, f"{entry_id}.json"), "w") as f:
        f.write(payload)
    return entry_id
def get_conversation_history(session_id: str) -> List[ConversationEntry]:
    """Load all stored conversation entries for a session, oldest first.

    Reads every ``*.json`` file in ``<DATA_DIR>/conversations/<session_id>``
    and restores each ISO ``timestamp`` string to a ``datetime``.

    Returns:
        The entries ordered by ``timestamp``; an empty list when the session
        has no conversation directory.
    """
    history_dir = os.path.join(DATA_DIR, "conversations", session_id)
    if not os.path.exists(history_dir):
        return []

    history = []
    json_names = [name for name in os.listdir(history_dir) if name.endswith(".json")]
    for name in json_names:
        with open(os.path.join(history_dir, name), "r") as handle:
            record = json.load(handle)
        iso_timestamp = record["timestamp"]
        if iso_timestamp:
            record["timestamp"] = datetime.fromisoformat(iso_timestamp)
        history.append(ConversationEntry(**record))

    # File names are random UUIDs, so chronological order comes from the data.
    return sorted(history, key=lambda item: item.timestamp)
def save_feedback(feedback: FeedbackData):
    """Persist a feedback record as ``<DATA_DIR>/feedback/<feedback_id>.json``.

    A new ``feedback_id`` and the current time (ISO string, key
    ``timestamp``) are added to the stored record.

    Returns:
        The generated ``feedback_id`` string.
    """
    feedback_id = str(uuid4())

    # Serialize before opening the file so a failure never leaves an empty
    # JSON file behind, and make sure the target directory exists.
    feedback_dict = feedback.dict()
    feedback_dict["feedback_id"] = feedback_id
    feedback_dict["timestamp"] = datetime.now().isoformat()
    payload = json.dumps(feedback_dict, indent=2)

    feedback_dir = os.path.join(DATA_DIR, "feedback")
    os.makedirs(feedback_dir, exist_ok=True)
    with open(os.path.join(feedback_dir, f"{feedback_id}.json"), "w") as f:
        f.write(payload)
    return feedback_id
# Multi-modal Input Support | |
async def process_multimodal_input(input_data: MultiModalInput):
    """Process multi-modal input (text, audio, video).

    Text is classified directly with ``predict_response_type``. Audio and
    video are not decoded: the LLM is asked to *simulate* a transcription and
    an emotion / non-verbal analysis from ``input_data.content``, and the
    simulated transcription is then classified.

    Returns:
        A dict with ``input_type``, ``processed_content`` and ``analysis``
        (response type, crisis flag, confidence). Audio adds
        ``emotion_analysis``; video adds ``nonverbal_analysis``.

    Raises:
        HTTPException: Status 400 for an unsupported ``input_type``; status
            500 for any other failure.
    """
    try:
        if input_data.input_type not in ["text", "audio", "video"]:
            raise HTTPException(status_code=400, detail=f"Unsupported input type: {input_data.input_type}")

        # For now, handle text directly and simulate processing for audio/video
        if input_data.input_type == "text":
            prediction = predict_response_type(input_data.content)
            return {
                "input_type": "text",
                "processed_content": input_data.content,
                "analysis": {
                    "response_type": prediction["response_type"],
                    "crisis_flag": prediction["crisis_flag"],
                    "confidence": prediction["confidence"],
                },
            }

        if input_data.input_type == "audio":
            # Simulate audio transcription and emotion detection
            # In a production system, this would use a speech-to-text API and emotion analysis
            prompt_template = ChatPromptTemplate.from_template(
                """
Simulate audio processing for this description: {content}
Generate a simulated transcription and emotion detection as if this were real audio.
Output in the following format:
```json
{{
"transcription": "Simulated transcription of the audio",
"emotion_detected": "primary emotion",
"secondary_emotions": ["emotion1", "emotion2"],
"confidence": 0.85
}}
```
"""
            )
            process_chain = prompt_template | llm
            response = process_chain.invoke({"content": input_data.content})
            audio_result = _parse_multimodal_llm_json(response.content)
            # Now process the transcription
            prediction = predict_response_type(audio_result["transcription"])
            return {
                "input_type": "audio",
                "processed_content": audio_result["transcription"],
                "emotion_analysis": {
                    "primary_emotion": audio_result["emotion_detected"],
                    "secondary_emotions": audio_result["secondary_emotions"],
                    "confidence": audio_result["confidence"],
                },
                "analysis": {
                    "response_type": prediction["response_type"],
                    "crisis_flag": prediction["crisis_flag"],
                    "confidence": prediction["confidence"],
                },
            }

        # Only "video" remains after the validation above.
        # Simulate video analysis
        # In a production system, this would use video analytics API
        prompt_template = ChatPromptTemplate.from_template(
            """
Simulate video processing for this description: {content}
Generate a simulated analysis as if this were real video with facial expressions and body language.
Output in the following format:
```json
{{
"transcription": "Simulated transcription of speech in the video",
"facial_expressions": ["expression1", "expression2"],
"body_language": ["posture observation", "gesture observation"],
"primary_emotion": "primary emotion",
"confidence": 0.80
}}
```
"""
        )
        process_chain = prompt_template | llm
        response = process_chain.invoke({"content": input_data.content})
        video_result = _parse_multimodal_llm_json(response.content)
        # Process the transcription
        prediction = predict_response_type(video_result["transcription"])
        return {
            "input_type": "video",
            "processed_content": video_result["transcription"],
            "nonverbal_analysis": {
                "facial_expressions": video_result["facial_expressions"],
                "body_language": video_result["body_language"],
                "primary_emotion": video_result["primary_emotion"],
                "confidence": video_result["confidence"],
            },
            "analysis": {
                "response_type": prediction["response_type"],
                "crisis_flag": prediction["crisis_flag"],
                "confidence": prediction["confidence"],
            },
        }
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) intact instead of
        # letting the generic handler below turn them into a 500.
        raise
    except Exception as e:
        logger.error(f"Error processing multimodal input: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing multimodal input: {str(e)}") from e


def _parse_multimodal_llm_json(raw_text: str) -> dict:
    """Decode the JSON object embedded in an LLM reply.

    The reply is expected to wrap a single JSON object in a Markdown code
    fence; everything outside the outermost braces is ignored. The text is
    parsed with ``json.loads`` rather than ``eval``: model output is untrusted
    and must never be executed, and ``eval`` cannot read the JSON literals
    ``true`` / ``false`` / ``null``.

    Raises:
        ValueError: If no JSON object is present or it cannot be decoded
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    match = re.search(r"\{.*\}", raw_text, re.DOTALL)
    if match is None:
        raise ValueError("LLM response did not contain a JSON object")
    return json.loads(match.group(0))
# Therapeutic Technique Suggestions | |
def suggest_therapeutic_techniques(context: str, technique_type: Optional[str] = None):
    """Suggest specific therapeutic techniques based on the patient context.

    Args:
        context: Patient context text, passed through to the prompt.
        technique_type: Approach to restrict the suggestions to. The prompt
            receives "Any appropriate" when this is ``None`` or empty.

    Returns:
        The JSON object produced by the LLM, decoded into a dict. The prompt
        requests the keys ``primary_approach``, ``techniques`` and
        ``rationale``.

    Raises:
        HTTPException: Status 500 when the LLM call fails or its reply does
            not contain decodable JSON.
    """
    prompt_template = ChatPromptTemplate.from_template(
        """
You are an expert mental health professional with extensive knowledge of therapeutic techniques. Based on the following patient context, suggest therapeutic techniques that would be appropriate:
Patient Context: {context}
Technique Type (if specified): {technique_type}
Suggest specific therapeutic techniques, exercises, or interventions that would be helpful for this patient. Include:
1. Name of technique
2. Brief description
3. How to apply it in this specific case
4. Expected benefits
Provide a range of options from different therapeutic approaches (CBT, DBT, ACT, mindfulness, motivational interviewing, etc.) unless a specific technique type was requested.
Output in the following format:
```json
{{
"primary_approach": "The most appropriate therapeutic approach",
"techniques": [
{{
"name": "Technique name",
"approach": "CBT/DBT/ACT/etc.",
"description": "Brief description",
"application": "How to apply to this specific case",
"benefits": "Expected benefits"
}}
],
"rationale": "Brief explanation of why these techniques were selected"
}}
```
"""
    )
    technique_chain = (
        {
            "context": RunnablePassthrough(),
            "technique_type": lambda x: technique_type if technique_type else "Any appropriate",
        }
        | prompt_template
        | llm
    )
    try:
        response = technique_chain.invoke(context)
        # Parse the reply as data instead of eval()-ing it: model output is
        # untrusted and must not be executed, and eval() fails on the JSON
        # literals true/false/null. Taking the outermost {...} also copes
        # with the surrounding Markdown fence.
        match = re.search(r"\{.*\}", response.content, re.DOTALL)
        if match is None:
            raise ValueError("LLM response did not contain a JSON object")
        return json.loads(match.group(0))
    except Exception as e:
        logger.error(f"Error suggesting therapeutic techniques: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Error suggesting therapeutic techniques: {str(e)}",
        ) from e
async def get_therapeutic_techniques(request: dict):
    """Get suggested therapeutic techniques for a patient context.

    Reads ``context`` (required) and ``technique_type`` (optional) from the
    ``request`` dict and delegates to ``suggest_therapeutic_techniques``.

    Returns:
        A dict with the submitted ``context`` and the suggested ``techniques``.

    Raises:
        HTTPException: Status 400 when ``context`` is missing or empty;
            an HTTPException raised by the suggestion helper is propagated
            unchanged; status 500 for any other failure.
    """
    try:
        context = request.get("context")
        if not context:
            raise HTTPException(status_code=400, detail="context is required")
        technique_type = request.get("technique_type")
        techniques = suggest_therapeutic_techniques(context, technique_type)
        return {
            "context": context,
            "techniques": techniques,
        }
    except HTTPException:
        # Without this clause the generic handler below caught the 400 raised
        # above and reported a client mistake as a 500 server error.
        raise
    except Exception as e:
        logger.error(f"Error getting therapeutic techniques: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Error getting therapeutic techniques: {str(e)}",
        ) from e
# Ethical AI Guardrails - Confidence Indicator | |
async def get_suggestion_with_confidence(context: PatientContext):
    """Get suggestion with detailed confidence indicators and uncertainty flags.

    Steps:
        1. Classify ``context.context`` with ``predict_response_type``.
        2. Bucket the confidence: >= 0.8 "High", >= 0.6 "Medium", else "Low".
        3. Ask the LLM for a bias analysis of the context / predicted type.
        4. Generate both the RAG-based and the direct suggestion.

    Returns:
        A dict combining the prediction, the confidence block (with
        ``uncertainty_flag`` set for "Low"), the bias analysis, both
        suggestions with their risk levels and an attribution block.

    Raises:
        HTTPException: An HTTPException raised by a helper is propagated
            unchanged; any other failure becomes status 500.
    """
    try:
        # Get standard prediction
        prediction = predict_response_type(context.context)

        # Set confidence thresholds
        high_confidence = 0.8
        medium_confidence = 0.6

        # Determine confidence level
        confidence_value = prediction["confidence"]
        if confidence_value >= high_confidence:
            confidence_level = "High"
        elif confidence_value >= medium_confidence:
            confidence_level = "Medium"
        else:
            confidence_level = "Low"

        # Analyze for potential biases
        bias_prompt = ChatPromptTemplate.from_template(
            """
You are an AI ethics expert. Analyze the following patient context and proposed response type for potential biases:
Patient Context: {context}
Predicted Response Type: {response_type}
Identify any potential biases in interpretation or response. Consider gender, cultural, socioeconomic, and other potential biases.
Output in the following format:
```json
{{
"bias_detected": true/false,
"bias_types": ["bias type 1", "bias type 2"],
"explanation": "Brief explanation of potential biases"
}}
```
"""
        )
        bias_chain = (
            {
                "context": lambda x: context.context,
                "response_type": lambda x: prediction["response_type"],
            }
            | bias_prompt
            | llm
        )
        bias_reply = bias_chain.invoke({}).content
        # The prompt asks for JSON booleans (true/false). eval() raised
        # NameError on those and would execute arbitrary model output, so the
        # reply is decoded as JSON instead. Taking the outermost {...} also
        # skips the Markdown fence around it.
        bias_match = re.search(r"\{.*\}", bias_reply, re.DOTALL)
        if bias_match is None:
            raise ValueError("Bias analysis response did not contain a JSON object")
        bias_analysis = json.loads(bias_match.group(0))

        # Generate suggestions
        suggestion_rag = generate_suggestion_rag(
            context.context, prediction["response_type"], prediction["crisis_flag"]
        )
        suggestion_direct = generate_suggestion_direct(
            context.context, prediction["response_type"], prediction["crisis_flag"]
        )

        return {
            "context": context.context,
            "response_type": prediction["response_type"],
            "crisis_flag": prediction["crisis_flag"],
            "confidence": {
                "value": prediction["confidence"],
                "level": confidence_level,
                "uncertainty_flag": confidence_level == "Low",
            },
            "bias_analysis": bias_analysis,
            "rag_suggestion": suggestion_rag["suggested_response"],
            "rag_risk_level": suggestion_rag["risk_level"],
            "direct_suggestion": suggestion_direct["suggested_response"],
            "direct_risk_level": suggestion_direct["risk_level"],
            "attribution": {
                "ai_generated": True,
                "model_version": "Mental Health Counselor API v2.0",
                "human_reviewed": False,
            },
        }
    except HTTPException:
        # Preserve status/detail chosen by helpers rather than re-wrapping.
        raise
    except Exception as e:
        logger.error(f"Error getting suggestion with confidence: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Error getting suggestion with confidence: {str(e)}",
        ) from e
# Text to Speech with Eleven Labs API | |
async def text_to_speech(request: dict):
    """Convert text to speech using Eleven Labs API.

    Reads ``text`` (required) and ``voice_id`` (optional) from ``request``,
    posts them to the Eleven Labs text-to-speech endpoint and streams the
    returned MPEG audio back.

    Raises:
        HTTPException: Status 400 when ``text`` is missing; status 500 when
            ``ELEVEN_API_KEY`` is not set or an unexpected error occurs; the
            upstream status code when Eleven Labs answers with a non-200.
    """
    # Local import: only this function needs it.
    from urllib.parse import quote

    # Upper bound (seconds) for the upstream call; without a timeout a stalled
    # connection would block this request indefinitely.
    upstream_timeout = 60

    try:
        text = request.get("text")
        voice_id = request.get("voice_id", "pNInz6obpgDQGcFmaJgB")  # Default to "Adam" voice
        if not text:
            raise HTTPException(status_code=400, detail="Text is required")

        # Get API key from environment
        api_key = os.getenv("ELEVEN_API_KEY")
        if not api_key:
            raise HTTPException(status_code=500, detail="Eleven Labs API key not configured")

        # voice_id comes from the caller and lands in the URL path; escape it
        # (including "/") so it cannot redirect the request to another
        # endpoint. Ordinary alphanumeric ids are unchanged by quoting.
        url = f"https://api.elevenlabs.io/v1/text-to-speech/{quote(str(voice_id), safe='')}"
        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "xi-api-key": api_key,
        }
        payload = {
            "text": text,
            "model_id": "eleven_multilingual_v2",
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": 0.75,
            },
        }

        # Make the request to Eleven Labs
        response = requests.post(url, json=payload, headers=headers, timeout=upstream_timeout)
        if response.status_code != 200:
            logger.error(f"Eleven Labs API error: {response.text}")
            raise HTTPException(
                status_code=response.status_code,
                detail=f"Eleven Labs API error: {response.text}",
            )

        # Return audio as streaming response
        return StreamingResponse(
            BytesIO(response.content),
            media_type="audio/mpeg",
        )
    except HTTPException as e:
        # Deliberate HTTP errors keep their status; just record them.
        logger.error(f"Error in text-to-speech: {str(e)}")
        raise
    except Exception as e:
        logger.error(f"Error in text-to-speech: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Text-to-speech error: {str(e)}") from e
# Multimedia file processing (speech to text) | |
async def process_audio_input(
    audio: UploadFile = File(...),
    session_id: str = Form(...)
):
    """Process audio input for speech-to-text using Eleven Labs.

    Reads the uploaded file fully into memory, posts it as multipart form
    data to the Eleven Labs speech-to-text endpoint and returns the ``text``
    field of the JSON reply (empty string when absent).

    Returns:
        A dict with the transcribed ``text`` and the submitted ``session_id``.

    Raises:
        HTTPException: Status 500 when ``ELEVEN_API_KEY`` is not set or an
            unexpected error occurs; the upstream status code when Eleven
            Labs answers with a non-200.
    """
    # Upper bound (seconds) for the upstream call; without a timeout a stalled
    # connection would block this request indefinitely.
    upstream_timeout = 60

    try:
        # Get API key from environment
        api_key = os.getenv("ELEVEN_API_KEY")
        if not api_key:
            raise HTTPException(status_code=500, detail="Eleven Labs API key not configured")

        # Read the audio file content
        audio_content = await audio.read()

        # Call Eleven Labs Speech-to-Text API
        url = "https://api.elevenlabs.io/v1/speech-to-text"
        headers = {
            "xi-api-key": api_key
        }
        # Create form data with the audio file. The upload is always sent as
        # 'audio.webm' / 'audio/webm', regardless of the client's file name.
        # NOTE(review): the multipart field name ('audio') and the model id
        # ('whisper-1') should be checked against the current Eleven Labs
        # speech-to-text reference. TODO confirm.
        files = {
            'audio': ('audio.webm', audio_content, 'audio/webm')
        }
        data = {
            'model_id': 'whisper-1'  # Using Whisper model
        }

        # Make the request to Eleven Labs
        response = requests.post(
            url, headers=headers, files=files, data=data, timeout=upstream_timeout
        )
        if response.status_code != 200:
            logger.error(f"Eleven Labs API error: {response.text}")
            raise HTTPException(
                status_code=response.status_code,
                detail=f"Eleven Labs API error: {response.text}",
            )

        result = response.json()
        # Extract the transcribed text
        text = result.get('text', '')
        # Return the transcribed text
        return {
            "text": text,
            "session_id": session_id
        }
    except HTTPException as e:
        # Deliberate HTTP errors keep their status; just record them.
        logger.error(f"Error processing audio: {str(e)}")
        raise
    except Exception as e:
        logger.error(f"Error processing audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Audio processing error: {str(e)}") from e
# Add a custom encoder for bytes objects to prevent UTF-8 decode errors | |
def custom_encoder(obj):
    """Turn a ``bytes`` value into a JSON-safe string.

    Valid UTF-8 is decoded to text; anything else is returned as its
    base64 representation (ASCII).

    Raises:
        TypeError: If ``obj`` is not a ``bytes`` instance.
    """
    if not isinstance(obj, bytes):
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
    try:
        encoded = obj.decode('utf-8')
    except UnicodeDecodeError:
        # Not text: fall back to a lossless ASCII encoding of the raw bytes.
        encoded = base64.b64encode(obj).decode('ascii')
    return encoded
# Override the jsonable_encoder function to handle bytes properly | |
from fastapi.encoders import jsonable_encoder as original_jsonable_encoder


def safe_jsonable_encoder(*args, **kwargs):
    """Call FastAPI's ``jsonable_encoder`` with a fallback for undecodable bytes.

    All arguments are forwarded unchanged. If the stock encoder raises
    ``UnicodeDecodeError`` and the first positional argument is ``bytes``,
    that value is encoded with ``custom_encoder`` instead; in every other
    case the error is re-raised.
    """
    try:
        return original_jsonable_encoder(*args, **kwargs)
    except UnicodeDecodeError:
        first_is_bytes = bool(args) and isinstance(args[0], bytes)
        if not first_is_bytes:
            raise
        # Only a top-level bytes payload can be rescued here.
        return custom_encoder(args[0])


# Install the wrapper in place of FastAPI's encoder (monkey patch).
# NOTE(review): modules that imported ``jsonable_encoder`` by name before this
# assignment keep the original function. Confirm the patch reaches the call
# sites that matter.
import fastapi.encoders

fastapi.encoders.jsonable_encoder = safe_jsonable_encoder