Spaces:
Sleeping
Sleeping
Said Lfagrouche
commited on
Commit
·
e62f11c
1
Parent(s):
a059233
Prepare for Hugging Face Spaces deployment with simplified configuration
Browse files- Dockerfile +1 -1
- app.py +45 -1333
- requirements.txt +22 -23
Dockerfile
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
FROM python:3.
|
2 |
|
3 |
WORKDIR /app
|
4 |
|
|
|
1 |
+
FROM python:3.9
|
2 |
|
3 |
WORKDIR /app
|
4 |
|
app.py
CHANGED
@@ -1,34 +1,8 @@
|
|
1 |
-
|
2 |
-
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
3 |
from pydantic import BaseModel
|
4 |
-
import pandas as pd
|
5 |
-
import numpy as np
|
6 |
-
import joblib
|
7 |
-
import re
|
8 |
-
import nltk
|
9 |
-
from nltk.tokenize import word_tokenize
|
10 |
-
from nltk.stem import WordNetLemmatizer
|
11 |
-
from nltk.corpus import stopwords
|
12 |
-
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
|
13 |
-
import chromadb
|
14 |
-
from chromadb.config import Settings
|
15 |
-
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
16 |
-
from langchain_chroma import Chroma
|
17 |
-
from openai import OpenAI
|
18 |
import os
|
19 |
from dotenv import load_dotenv
|
20 |
-
from langsmith import Client, traceable
|
21 |
-
from langchain_core.runnables import RunnablePassthrough
|
22 |
-
from langchain_core.prompts import ChatPromptTemplate
|
23 |
import logging
|
24 |
-
from typing import List, Dict, Optional, Any, Union, Annotated
|
25 |
-
from datetime import datetime
|
26 |
-
from uuid import uuid4, UUID
|
27 |
-
import json
|
28 |
-
import requests
|
29 |
-
from fastapi.responses import StreamingResponse
|
30 |
-
from io import BytesIO
|
31 |
-
import base64
|
32 |
|
33 |
# Set up logging
|
34 |
logging.basicConfig(level=logging.INFO)
|
@@ -37,15 +11,10 @@ logger = logging.getLogger(__name__)
|
|
37 |
# Load environment variables
|
38 |
load_dotenv()
|
39 |
|
40 |
-
# Download required NLTK data
|
41 |
-
nltk.download('punkt')
|
42 |
-
nltk.download('wordnet')
|
43 |
-
nltk.download('stopwords')
|
44 |
-
|
45 |
# Initialize FastAPI app
|
46 |
app = FastAPI(title="Mental Health Counselor API")
|
47 |
|
48 |
-
# Initialize global storage
|
49 |
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
50 |
os.makedirs(DATA_DIR, exist_ok=True)
|
51 |
os.makedirs(os.path.join(DATA_DIR, "users"), exist_ok=True)
|
@@ -53,1318 +22,61 @@ os.makedirs(os.path.join(DATA_DIR, "sessions"), exist_ok=True)
|
|
53 |
os.makedirs(os.path.join(DATA_DIR, "conversations"), exist_ok=True)
|
54 |
os.makedirs(os.path.join(DATA_DIR, "feedback"), exist_ok=True)
|
55 |
|
56 |
-
#
|
57 |
-
|
58 |
-
|
59 |
-
analyzer = SentimentIntensityAnalyzer()
|
60 |
-
output_dir = "mental_health_model_artifacts"
|
61 |
-
|
62 |
-
# Global variables for models and vector store
|
63 |
-
response_clf = None
|
64 |
-
crisis_clf = None
|
65 |
-
vectorizer = None
|
66 |
-
le = None
|
67 |
-
selector = None
|
68 |
-
lda = None
|
69 |
-
vector_store = None
|
70 |
-
llm = None
|
71 |
-
openai_client = None
|
72 |
-
langsmith_client = None
|
73 |
-
|
74 |
-
# Load models and initialize ChromaDB at startup
|
75 |
-
@app.on_event("startup")
|
76 |
-
async def startup_event():
|
77 |
-
global response_clf, crisis_clf, vectorizer, le, selector, lda, vector_store, llm, openai_client, langsmith_client
|
78 |
-
|
79 |
-
# Check environment variables
|
80 |
-
if not os.environ.get("OPENAI_API_KEY"):
|
81 |
-
logger.error("OPENAI_API_KEY not set in .env file")
|
82 |
-
raise HTTPException(status_code=500, detail="OPENAI_API_KEY not set in .env file")
|
83 |
-
if not os.environ.get("LANGCHAIN_API_KEY"):
|
84 |
-
logger.error("LANGCHAIN_API_KEY not set in .env file")
|
85 |
-
raise HTTPException(status_code=500, detail="LANGCHAIN_API_KEY not set in .env file")
|
86 |
-
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
87 |
-
os.environ["LANGCHAIN_PROJECT"] = "MentalHealthCounselorPOC"
|
88 |
-
|
89 |
-
# Initialize LangSmith client
|
90 |
-
logger.info("Initializing LangSmith client")
|
91 |
-
langsmith_client = Client()
|
92 |
-
|
93 |
-
# Load saved components
|
94 |
-
logger.info("Loading model artifacts")
|
95 |
-
try:
|
96 |
-
response_clf = joblib.load(f"{output_dir}/response_type_classifier.pkl")
|
97 |
-
crisis_clf = joblib.load(f"{output_dir}/crisis_classifier.pkl")
|
98 |
-
vectorizer = joblib.load(f"{output_dir}/tfidf_vectorizer.pkl")
|
99 |
-
le = joblib.load(f"{output_dir}/label_encoder.pkl")
|
100 |
-
selector = joblib.load(f"{output_dir}/feature_selector.pkl")
|
101 |
-
|
102 |
-
try:
|
103 |
-
lda = joblib.load(f"{output_dir}/lda_model.pkl")
|
104 |
-
except Exception as lda_error:
|
105 |
-
logger.warning(f"Failed to load LDA model: {lda_error}. Creating placeholder model.")
|
106 |
-
from sklearn.decomposition import LatentDirichletAllocation
|
107 |
-
lda = LatentDirichletAllocation(n_components=10, random_state=42)
|
108 |
-
# Note: Placeholder is untrained; retrain for accurate results
|
109 |
-
|
110 |
-
except FileNotFoundError as e:
|
111 |
-
logger.error(f"Missing model artifact: {e}")
|
112 |
-
raise HTTPException(status_code=500, detail=f"Missing model artifact: {e}")
|
113 |
-
|
114 |
-
# Initialize ChromaDB
|
115 |
-
chroma_db_path = f"{output_dir}/chroma_db"
|
116 |
-
if not os.path.exists(chroma_db_path):
|
117 |
-
logger.error(f"ChromaDB not found at {chroma_db_path}. Run create_vector_db.py first.")
|
118 |
-
raise HTTPException(status_code=500, detail=f"ChromaDB not found at {chroma_db_path}. Run create_vector_db.py first.")
|
119 |
-
|
120 |
-
try:
|
121 |
-
logger.info("Initializing ChromaDB")
|
122 |
-
chroma_client = chromadb.PersistentClient(
|
123 |
-
path=chroma_db_path,
|
124 |
-
settings=Settings(anonymized_telemetry=False)
|
125 |
-
)
|
126 |
-
|
127 |
-
embeddings = OpenAIEmbeddings(
|
128 |
-
model="text-embedding-ada-002",
|
129 |
-
api_key=os.environ["OPENAI_API_KEY"],
|
130 |
-
disallowed_special=(),
|
131 |
-
chunk_size=1000
|
132 |
-
)
|
133 |
-
global vector_store
|
134 |
-
vector_store = Chroma(
|
135 |
-
client=chroma_client,
|
136 |
-
collection_name="mental_health_conversations",
|
137 |
-
embedding_function=embeddings
|
138 |
-
)
|
139 |
-
except Exception as e:
|
140 |
-
logger.error(f"Error initializing ChromaDB: {e}")
|
141 |
-
raise HTTPException(status_code=500, detail=f"Error initializing ChromaDB: {e}")
|
142 |
-
|
143 |
-
# Initialize OpenAI client and LLM
|
144 |
-
logger.info("Initializing OpenAI client and LLM")
|
145 |
-
global openai_client, llm
|
146 |
-
openai_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
147 |
-
llm = ChatOpenAI(
|
148 |
-
model="gpt-4o-mini",
|
149 |
-
temperature=0.7,
|
150 |
-
api_key=os.environ["OPENAI_API_KEY"]
|
151 |
-
)
|
152 |
-
|
153 |
-
# Pydantic model for request
|
154 |
-
class PatientContext(BaseModel):
|
155 |
-
context: str
|
156 |
-
|
157 |
-
# New Pydantic models for expanded API functionality
|
158 |
-
class UserProfile(BaseModel):
|
159 |
-
user_id: Optional[str] = None
|
160 |
-
username: str
|
161 |
-
name: str
|
162 |
-
role: str = "counselor"
|
163 |
-
specializations: List[str] = []
|
164 |
-
years_experience: Optional[int] = None
|
165 |
-
custom_crisis_keywords: List[str] = []
|
166 |
-
preferences: Dict[str, Any] = {}
|
167 |
-
created_at: Optional[datetime] = None
|
168 |
-
updated_at: Optional[datetime] = None
|
169 |
-
|
170 |
-
class SessionData(BaseModel):
|
171 |
-
session_id: Optional[str] = None
|
172 |
-
counselor_id: str
|
173 |
-
patient_identifier: str # Anonymized ID
|
174 |
-
session_notes: str = ""
|
175 |
-
session_preferences: Dict[str, Any] = {}
|
176 |
-
crisis_keywords: List[str] = []
|
177 |
-
created_at: Optional[datetime] = None
|
178 |
-
updated_at: Optional[datetime] = None
|
179 |
-
|
180 |
-
class ConversationEntry(BaseModel):
|
181 |
-
session_id: str
|
182 |
-
message: str
|
183 |
-
sender: str # 'patient' or 'counselor'
|
184 |
-
timestamp: Optional[datetime] = None
|
185 |
-
suggested_response: Optional[str] = None
|
186 |
-
response_type: Optional[str] = None
|
187 |
-
crisis_flag: bool = False
|
188 |
-
risk_level: Optional[str] = None
|
189 |
-
|
190 |
-
class FeedbackData(BaseModel):
|
191 |
-
suggestion_id: str
|
192 |
-
counselor_id: str
|
193 |
-
rating: int # 1-5 scale
|
194 |
-
was_effective: bool
|
195 |
-
comments: Optional[str] = None
|
196 |
-
|
197 |
-
class AnalysisRequest(BaseModel):
|
198 |
-
text: str
|
199 |
-
patient_background: Optional[Dict[str, Any]] = None
|
200 |
-
patient_age: Optional[int] = None
|
201 |
-
cultural_context: Optional[str] = None
|
202 |
-
|
203 |
-
class MultiModalInput(BaseModel):
|
204 |
-
session_id: str
|
205 |
-
counselor_id: str
|
206 |
-
input_type: str # 'text', 'audio', 'video'
|
207 |
-
content: str # Text content or file path/url
|
208 |
-
metadata: Dict[str, Any] = {}
|
209 |
-
|
210 |
-
class InterventionRequest(BaseModel):
|
211 |
-
patient_issue: str
|
212 |
-
patient_background: Optional[Dict[str, Any]] = None
|
213 |
-
intervention_type: Optional[str] = None # e.g., 'CBT', 'DBT', 'mindfulness'
|
214 |
-
|
215 |
-
# Text preprocessing function
|
216 |
-
@traceable(run_type="tool", name="Clean Text")
|
217 |
-
def clean_text(text):
|
218 |
-
if pd.isna(text):
|
219 |
-
return ""
|
220 |
-
text = str(text).lower()
|
221 |
-
text = re.sub(r"[^a-zA-Z']", " ", text)
|
222 |
-
tokens = word_tokenize(text)
|
223 |
-
tokens = [lemmatizer.lemmatize(tok) for tok in tokens if tok not in STOPWORDS and len(tok) > 2]
|
224 |
-
return " ".join(tokens)
|
225 |
-
|
226 |
-
# Feature engineering function
|
227 |
-
@traceable(run_type="tool", name="Engineer Features")
|
228 |
-
def engineer_features(context, response=""):
|
229 |
-
context_clean = clean_text(context)
|
230 |
-
context_len = len(context_clean.split())
|
231 |
-
context_vader = analyzer.polarity_scores(context)['compound']
|
232 |
-
context_questions = context.count('?')
|
233 |
-
crisis_keywords = ['suicide', 'hopeless', 'worthless', 'kill', 'harm', 'desperate', 'overwhelmed', 'alone']
|
234 |
-
context_crisis_score = sum(1 for word in crisis_keywords if word in context.lower())
|
235 |
-
|
236 |
-
context_tfidf = vectorizer.transform([context_clean]).toarray()
|
237 |
-
tfidf_cols = [f"tfidf_context_{i}" for i in range(context_tfidf.shape[1])]
|
238 |
-
response_tfidf = np.zeros_like(context_tfidf)
|
239 |
-
|
240 |
-
lda_topics = lda.transform(context_tfidf)
|
241 |
-
|
242 |
-
feature_cols = ["context_len", "context_vader", "context_questions", "crisis_flag"] + \
|
243 |
-
[f"topic_{i}" for i in range(10)] + tfidf_cols + \
|
244 |
-
[f"tfidf_response_{i}" for i in range(response_tfidf.shape[1])]
|
245 |
-
|
246 |
-
features = pd.DataFrame({
|
247 |
-
"context_len": [context_len],
|
248 |
-
"context_vader": [context_vader],
|
249 |
-
"context_questions": [context_questions],
|
250 |
-
**{f"topic_{i}": [lda_topics[0][i]] for i in range(10)},
|
251 |
-
**{f"tfidf_context_{i}": [context_tfidf[0][i]] for i in range(context_tfidf.shape[1])},
|
252 |
-
**{f"tfidf_response_{i}": [response_tfidf[0][i]] for i in range(response_tfidf.shape[1])},
|
253 |
-
})
|
254 |
-
|
255 |
-
crisis_features = features[["context_len", "context_vader", "context_questions"] + [f"topic_{i}" for i in range(10)]]
|
256 |
-
crisis_flag = crisis_clf.predict(crisis_features)[0]
|
257 |
-
if context_crisis_score > 0:
|
258 |
-
crisis_flag = 1
|
259 |
-
features["crisis_flag"] = crisis_flag
|
260 |
-
|
261 |
-
return features, feature_cols
|
262 |
-
|
263 |
-
# Prediction function
|
264 |
-
@traceable(run_type="chain", name="Predict Response Type")
|
265 |
-
def predict_response_type(context):
|
266 |
-
features, feature_cols = engineer_features(context)
|
267 |
-
selected_features = selector.transform(features[feature_cols])
|
268 |
-
pred_encoded = response_clf.predict(selected_features)[0]
|
269 |
-
pred_label = le.inverse_transform([pred_encoded])[0]
|
270 |
-
confidence = response_clf.predict_proba(selected_features)[0].max()
|
271 |
-
|
272 |
-
if "?" in context and context.count("?") > 0:
|
273 |
-
pred_label = "Question"
|
274 |
-
if "trying" in context.lower() and "hard" in context.lower() and not any(kw in context.lower() for kw in ["how", "what", "help"]):
|
275 |
-
pred_label = "Validation"
|
276 |
-
if "trying" in context.lower() and "positive" in context.lower() and not any(kw in context.lower() for kw in ["how", "what", "help"]):
|
277 |
-
pred_label = "Question"
|
278 |
-
|
279 |
-
crisis_flag = bool(features["crisis_flag"].iloc[0])
|
280 |
-
|
281 |
-
return {
|
282 |
-
"response_type": pred_label,
|
283 |
-
"crisis_flag": crisis_flag,
|
284 |
-
"confidence": confidence,
|
285 |
-
"features": features.to_dict()
|
286 |
-
}
|
287 |
-
|
288 |
-
# RAG suggestion function
|
289 |
-
@traceable(run_type="chain", name="RAG Suggestion")
|
290 |
-
def generate_suggestion_rag(context, response_type, crisis_flag):
|
291 |
-
results = vector_store.similarity_search_with_score(context, k=3)
|
292 |
-
retrieved_contexts = [
|
293 |
-
f"Patient: {res[0].page_content}\nCounselor: {res[0].metadata['response']} (Type: {res[0].metadata['response_type']}, Crisis: {res[0].metadata['crisis_flag']}, Score: {res[1]:.2f})"
|
294 |
-
for res in results
|
295 |
-
]
|
296 |
-
|
297 |
-
prompt_template = ChatPromptTemplate.from_template(
|
298 |
-
"""
|
299 |
-
You are an expert mental health counseling assistant. A counselor has provided the following patient situation:
|
300 |
-
|
301 |
-
Patient Situation: {context}
|
302 |
-
|
303 |
-
Predicted Response Type: {response_type}
|
304 |
-
Crisis Flag: {crisis_flag}
|
305 |
-
|
306 |
-
Based on the predicted response type and crisis flag, provide a suggested response for the counselor to use with the patient. The response should align with the response type ({response_type}) and be sensitive to the crisis level.
|
307 |
-
|
308 |
-
For reference, here are similar cases from past conversations:
|
309 |
-
{retrieved_contexts}
|
310 |
-
|
311 |
-
Guidelines:
|
312 |
-
- If Crisis Flag is True, prioritize safety, empathy, and suggest immediate resources (e.g., National Suicide Prevention Lifeline at 988).
|
313 |
-
- For 'Empathetic Listening', focus on validating feelings without giving direct advice or questions.
|
314 |
-
- For 'Advice', provide practical, actionable suggestions.
|
315 |
-
- For 'Question', pose an open-ended question to encourage further discussion.
|
316 |
-
- For 'Validation', affirm the patient's efforts or feelings.
|
317 |
-
|
318 |
-
Output in the following format:
|
319 |
-
```json
|
320 |
-
{{
|
321 |
-
"suggested_response": "Your suggested response here",
|
322 |
-
"risk_level": "Low/Moderate/High"
|
323 |
-
}}
|
324 |
-
```
|
325 |
-
"""
|
326 |
-
)
|
327 |
-
|
328 |
-
rag_chain = (
|
329 |
-
{
|
330 |
-
"context": RunnablePassthrough(),
|
331 |
-
"response_type": lambda x: response_type,
|
332 |
-
"crisis_flag": lambda x: "Crisis" if crisis_flag else "No Crisis",
|
333 |
-
"retrieved_contexts": lambda x: "\n".join(retrieved_contexts)
|
334 |
-
}
|
335 |
-
| prompt_template
|
336 |
-
| llm
|
337 |
-
)
|
338 |
-
|
339 |
-
try:
|
340 |
-
response = rag_chain.invoke(context)
|
341 |
-
return eval(response.content.strip("```json\n").strip("\n```"))
|
342 |
-
except Exception as e:
|
343 |
-
logger.error(f"Error generating RAG suggestion: {e}")
|
344 |
-
raise HTTPException(status_code=500, detail=f"Error generating RAG suggestion: {str(e)}")
|
345 |
-
|
346 |
-
# Direct suggestion function
|
347 |
-
@traceable(run_type="chain", name="Direct Suggestion")
|
348 |
-
def generate_suggestion_direct(context, response_type, crisis_flag):
|
349 |
-
prompt_template = ChatPromptTemplate.from_template(
|
350 |
-
"""
|
351 |
-
You are an expert mental health counseling assistant. A counselor has provided the following patient situation:
|
352 |
-
|
353 |
-
Patient Situation: {context}
|
354 |
-
|
355 |
-
Predicted Response Type: {response_type}
|
356 |
-
Crisis Flag: {crisis_flag}
|
357 |
-
|
358 |
-
Provide a suggested response for the counselor to use with the patient, aligned with the response type ({response_type}) and sensitive to the crisis level.
|
359 |
-
|
360 |
-
Guidelines:
|
361 |
-
- If Crisis Flag is True, prioritize safety, empathy, and suggest immediate resources (e.g., National Suicide Prevention Lifeline at 988).
|
362 |
-
- For 'Empathetic Listening', focus on validating feelings without giving direct advice or questions.
|
363 |
-
- For 'Advice', provide practical, actionable suggestions.
|
364 |
-
- For 'Question', pose an open-ended question to encourage further discussion.
|
365 |
-
- For 'Validation', affirm the patient's efforts or feelings.
|
366 |
-
- Strictly adhere to the response type. For 'Empathetic Listening', do not include questions or advice.
|
367 |
-
|
368 |
-
Output in the following format:
|
369 |
-
```json
|
370 |
-
{{
|
371 |
-
"suggested_response": "Your suggested response here",
|
372 |
-
"risk_level": "Low/Moderate/High"
|
373 |
-
}}
|
374 |
-
```
|
375 |
-
"""
|
376 |
-
)
|
377 |
-
|
378 |
-
direct_chain = (
|
379 |
-
{
|
380 |
-
"context": RunnablePassthrough(),
|
381 |
-
"response_type": lambda x: response_type,
|
382 |
-
"crisis_flag": lambda x: "Crisis" if crisis_flag else "No Crisis"
|
383 |
-
}
|
384 |
-
| prompt_template
|
385 |
-
| llm
|
386 |
-
)
|
387 |
-
|
388 |
-
try:
|
389 |
-
response = direct_chain.invoke(context)
|
390 |
-
return eval(response.content.strip("```json\n").strip("\n```"))
|
391 |
-
except Exception as e:
|
392 |
-
logger.error(f"Error generating direct suggestion: {e}")
|
393 |
-
raise HTTPException(status_code=500, detail=f"Error generating direct suggestion: {str(e)}")
|
394 |
-
|
395 |
-
# User Profile Endpoints
|
396 |
-
@app.post("/users/create", response_model=UserProfile)
|
397 |
-
async def create_user(profile: UserProfile):
|
398 |
-
"""Create a new counselor profile with preferences and specializations."""
|
399 |
-
try:
|
400 |
-
saved_profile = save_user_profile(profile)
|
401 |
-
logger.info(f"Created user profile: {saved_profile.user_id}")
|
402 |
-
return saved_profile
|
403 |
-
except Exception as e:
|
404 |
-
logger.error(f"Error creating user profile: {e}")
|
405 |
-
raise HTTPException(status_code=500, detail=f"Error creating user profile: {str(e)}")
|
406 |
-
|
407 |
-
@app.get("/users/{user_id}", response_model=UserProfile)
|
408 |
-
async def get_user(user_id: str):
|
409 |
-
"""Get a counselor profile by user ID."""
|
410 |
-
profile = get_user_profile(user_id)
|
411 |
-
if not profile:
|
412 |
-
raise HTTPException(status_code=404, detail=f"User profile not found: {user_id}")
|
413 |
-
return profile
|
414 |
-
|
415 |
-
@app.put("/users/{user_id}", response_model=UserProfile)
|
416 |
-
async def update_user(user_id: str, profile_update: UserProfile):
|
417 |
-
"""Update a counselor profile."""
|
418 |
-
existing_profile = get_user_profile(user_id)
|
419 |
-
if not existing_profile:
|
420 |
-
raise HTTPException(status_code=404, detail=f"User profile not found: {user_id}")
|
421 |
-
|
422 |
-
# Preserve the original user_id
|
423 |
-
profile_update.user_id = user_id
|
424 |
-
# Preserve the original created_at timestamp
|
425 |
-
profile_update.created_at = existing_profile.created_at
|
426 |
-
|
427 |
-
try:
|
428 |
-
updated_profile = save_user_profile(profile_update)
|
429 |
-
logger.info(f"Updated user profile: {user_id}")
|
430 |
-
return updated_profile
|
431 |
-
except Exception as e:
|
432 |
-
logger.error(f"Error updating user profile: {e}")
|
433 |
-
raise HTTPException(status_code=500, detail=f"Error updating user profile: {str(e)}")
|
434 |
-
|
435 |
-
# Session Management Endpoints
|
436 |
-
@app.post("/sessions/create", response_model=SessionData)
|
437 |
-
async def create_session(session_data: SessionData):
|
438 |
-
"""Create a new session with patient identifier (anonymized)."""
|
439 |
-
try:
|
440 |
-
# Verify counselor exists
|
441 |
-
counselor = get_user_profile(session_data.counselor_id)
|
442 |
-
if not counselor:
|
443 |
-
raise HTTPException(status_code=404, detail=f"Counselor not found: {session_data.counselor_id}")
|
444 |
-
|
445 |
-
# If counselor has custom crisis keywords, add them to the session
|
446 |
-
if counselor.custom_crisis_keywords:
|
447 |
-
session_data.crisis_keywords.extend(counselor.custom_crisis_keywords)
|
448 |
-
|
449 |
-
saved_session = save_session(session_data)
|
450 |
-
logger.info(f"Created session: {saved_session.session_id}")
|
451 |
-
return saved_session
|
452 |
-
except HTTPException:
|
453 |
-
raise
|
454 |
-
except Exception as e:
|
455 |
-
logger.error(f"Error creating session: {e}")
|
456 |
-
raise HTTPException(status_code=500, detail=f"Error creating session: {str(e)}")
|
457 |
-
|
458 |
-
@app.get("/sessions/{session_id}", response_model=SessionData)
|
459 |
-
async def get_session_by_id(session_id: str):
|
460 |
-
"""Get a session by ID."""
|
461 |
-
session = get_session(session_id)
|
462 |
-
if not session:
|
463 |
-
raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
|
464 |
-
return session
|
465 |
-
|
466 |
-
@app.get("/sessions/counselor/{counselor_id}", response_model=List[SessionData])
|
467 |
-
async def get_counselor_sessions(counselor_id: str):
|
468 |
-
"""Get all sessions for a counselor."""
|
469 |
-
sessions = get_user_sessions(counselor_id)
|
470 |
-
return sessions
|
471 |
-
|
472 |
-
@app.put("/sessions/{session_id}", response_model=SessionData)
|
473 |
-
async def update_session(session_id: str, session_update: SessionData):
|
474 |
-
"""Update a session."""
|
475 |
-
existing_session = get_session(session_id)
|
476 |
-
if not existing_session:
|
477 |
-
raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
|
478 |
-
|
479 |
-
# Preserve the original session_id and created_at
|
480 |
-
session_update.session_id = session_id
|
481 |
-
session_update.created_at = existing_session.created_at
|
482 |
-
|
483 |
-
try:
|
484 |
-
updated_session = save_session(session_update)
|
485 |
-
logger.info(f"Updated session: {session_id}")
|
486 |
-
return updated_session
|
487 |
-
except Exception as e:
|
488 |
-
logger.error(f"Error updating session: {e}")
|
489 |
-
raise HTTPException(status_code=500, detail=f"Error updating session: {str(e)}")
|
490 |
-
|
491 |
-
# Conversation History Endpoints
|
492 |
-
@app.post("/conversations/add", response_model=str)
|
493 |
-
async def add_conversation_entry(entry: ConversationEntry):
|
494 |
-
"""Add a new entry to a conversation."""
|
495 |
-
try:
|
496 |
-
# Verify session exists
|
497 |
-
session = get_session(entry.session_id)
|
498 |
-
if not session:
|
499 |
-
raise HTTPException(status_code=404, detail=f"Session not found: {entry.session_id}")
|
500 |
-
|
501 |
-
entry_id = save_conversation_entry(entry)
|
502 |
-
logger.info(f"Added conversation entry: {entry_id}")
|
503 |
-
return entry_id
|
504 |
-
except HTTPException:
|
505 |
-
raise
|
506 |
-
except Exception as e:
|
507 |
-
logger.error(f"Error adding conversation entry: {e}")
|
508 |
-
raise HTTPException(status_code=500, detail=f"Error adding conversation entry: {str(e)}")
|
509 |
-
|
510 |
-
@app.get("/conversations/{session_id}", response_model=List[ConversationEntry])
|
511 |
-
async def get_conversation(session_id: str):
|
512 |
-
"""Get conversation history for a session."""
|
513 |
-
try:
|
514 |
-
# Verify session exists
|
515 |
-
session = get_session(session_id)
|
516 |
-
if not session:
|
517 |
-
raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
|
518 |
-
|
519 |
-
entries = get_conversation_history(session_id)
|
520 |
-
return entries
|
521 |
-
except HTTPException:
|
522 |
-
raise
|
523 |
-
except Exception as e:
|
524 |
-
logger.error(f"Error retrieving conversation history: {e}")
|
525 |
-
raise HTTPException(status_code=500, detail=f"Error retrieving conversation history: {str(e)}")
|
526 |
-
|
527 |
-
# API Endpoints
|
528 |
-
@app.post("/suggest")
|
529 |
-
async def get_suggestion(context: PatientContext):
|
530 |
-
logger.info(f"Received suggestion request for context: {context.context}")
|
531 |
-
prediction = predict_response_type(context.context)
|
532 |
-
suggestion_rag = generate_suggestion_rag(context.context, prediction["response_type"], prediction["crisis_flag"])
|
533 |
-
suggestion_direct = generate_suggestion_direct(context.context, prediction["response_type"], prediction["crisis_flag"])
|
534 |
-
|
535 |
return {
|
536 |
-
"
|
537 |
-
"
|
538 |
-
"
|
539 |
-
"
|
540 |
-
"rag_suggestion": suggestion_rag["suggested_response"],
|
541 |
-
"rag_risk_level": suggestion_rag["risk_level"],
|
542 |
-
"direct_suggestion": suggestion_direct["suggested_response"],
|
543 |
-
"direct_risk_level": suggestion_direct["risk_level"]
|
544 |
}
|
545 |
|
546 |
-
|
547 |
-
async def get_session_suggestion(request: dict):
|
548 |
-
"""Get suggestion within a session context, with enhanced crisis detection based on session keywords."""
|
549 |
-
try:
|
550 |
-
session_id = request.get("session_id")
|
551 |
-
if not session_id:
|
552 |
-
raise HTTPException(status_code=400, detail="session_id is required")
|
553 |
-
|
554 |
-
context = request.get("context")
|
555 |
-
if not context:
|
556 |
-
raise HTTPException(status_code=400, detail="context is required")
|
557 |
-
|
558 |
-
# Get session for custom crisis keywords
|
559 |
-
session = get_session(session_id)
|
560 |
-
if not session:
|
561 |
-
raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
|
562 |
-
|
563 |
-
# Get conversation history for context
|
564 |
-
conversation_history = get_conversation_history(session_id)
|
565 |
-
|
566 |
-
# Regular prediction
|
567 |
-
prediction = predict_response_type(context)
|
568 |
-
crisis_flag = prediction["crisis_flag"]
|
569 |
-
|
570 |
-
# Enhanced crisis detection with custom keywords
|
571 |
-
if not crisis_flag and session.crisis_keywords:
|
572 |
-
for keyword in session.crisis_keywords:
|
573 |
-
if keyword.lower() in context.lower():
|
574 |
-
crisis_flag = True
|
575 |
-
logger.info(f"Crisis flag triggered by custom keyword: {keyword}")
|
576 |
-
break
|
577 |
-
|
578 |
-
# Generate suggestions
|
579 |
-
suggestion_rag = generate_suggestion_rag(context, prediction["response_type"], crisis_flag)
|
580 |
-
suggestion_direct = generate_suggestion_direct(context, prediction["response_type"], crisis_flag)
|
581 |
-
|
582 |
-
# Create response
|
583 |
-
response = {
|
584 |
-
"context": context,
|
585 |
-
"response_type": prediction["response_type"],
|
586 |
-
"crisis_flag": crisis_flag,
|
587 |
-
"confidence": prediction["confidence"],
|
588 |
-
"rag_suggestion": suggestion_rag["suggested_response"],
|
589 |
-
"rag_risk_level": suggestion_rag["risk_level"],
|
590 |
-
"direct_suggestion": suggestion_direct["suggested_response"],
|
591 |
-
"direct_risk_level": suggestion_direct["risk_level"],
|
592 |
-
"session_id": session_id
|
593 |
-
}
|
594 |
-
|
595 |
-
# Save the conversation entry
|
596 |
-
entry = ConversationEntry(
|
597 |
-
session_id=session_id,
|
598 |
-
message=context,
|
599 |
-
sender="patient",
|
600 |
-
suggested_response=suggestion_rag["suggested_response"],
|
601 |
-
response_type=prediction["response_type"],
|
602 |
-
crisis_flag=crisis_flag,
|
603 |
-
risk_level=suggestion_rag["risk_level"]
|
604 |
-
)
|
605 |
-
save_conversation_entry(entry)
|
606 |
-
|
607 |
-
return response
|
608 |
-
except HTTPException:
|
609 |
-
raise
|
610 |
-
except Exception as e:
|
611 |
-
logger.error(f"Error getting session suggestion: {e}")
|
612 |
-
raise HTTPException(status_code=500, detail=f"Error getting session suggestion: {str(e)}")
|
613 |
-
|
614 |
-
# Feedback Endpoints
|
615 |
-
@app.post("/feedback")
|
616 |
-
async def add_feedback(feedback: FeedbackData):
|
617 |
-
"""Add feedback about a suggestion's effectiveness."""
|
618 |
-
try:
|
619 |
-
feedback_id = save_feedback(feedback)
|
620 |
-
logger.info(f"Added feedback: {feedback_id}")
|
621 |
-
return {"feedback_id": feedback_id}
|
622 |
-
except Exception as e:
|
623 |
-
logger.error(f"Error adding feedback: {e}")
|
624 |
-
raise HTTPException(status_code=500, detail=f"Error adding feedback: {str(e)}")
|
625 |
-
|
626 |
-
# Tone & Cultural Sensitivity Analysis
|
627 |
-
@traceable(run_type="chain", name="Cultural Sensitivity Analysis")
|
628 |
-
def analyze_cultural_sensitivity(text: str, cultural_context: Optional[str] = None):
|
629 |
-
"""Analyze text for cultural appropriateness and sensitivity."""
|
630 |
-
prompt_template = ChatPromptTemplate.from_template(
|
631 |
-
"""
|
632 |
-
You are a cultural sensitivity expert. Analyze the following text for cultural appropriateness:
|
633 |
-
|
634 |
-
Text: {text}
|
635 |
-
|
636 |
-
Cultural Context: {cultural_context}
|
637 |
-
|
638 |
-
Provide an analysis of:
|
639 |
-
1. Cultural appropriateness
|
640 |
-
2. Potential bias or insensitivity
|
641 |
-
3. Suggestions for improvement
|
642 |
-
|
643 |
-
Output in the following format:
|
644 |
-
```json
|
645 |
-
{{
|
646 |
-
"cultural_appropriateness_score": 0-10,
|
647 |
-
"issues_detected": ["issue1", "issue2"],
|
648 |
-
"suggestions": ["suggestion1", "suggestion2"],
|
649 |
-
"explanation": "Brief explanation of analysis"
|
650 |
-
}}
|
651 |
-
```
|
652 |
-
"""
|
653 |
-
)
|
654 |
-
|
655 |
-
analysis_chain = (
|
656 |
-
{
|
657 |
-
"text": RunnablePassthrough(),
|
658 |
-
"cultural_context": lambda x: cultural_context if cultural_context else "General"
|
659 |
-
}
|
660 |
-
| prompt_template
|
661 |
-
| llm
|
662 |
-
)
|
663 |
-
|
664 |
-
try:
|
665 |
-
response = analysis_chain.invoke(text)
|
666 |
-
return eval(response.content.strip("```json\n").strip("\n```"))
|
667 |
-
except Exception as e:
|
668 |
-
logger.error(f"Error analyzing cultural sensitivity: {e}")
|
669 |
-
raise HTTPException(status_code=500, detail=f"Error analyzing cultural sensitivity: {str(e)}")
|
670 |
-
|
671 |
-
@traceable(run_type="chain", name="Age Appropriate Analysis")
|
672 |
-
def analyze_age_appropriateness(text: str, age: Optional[int] = None):
|
673 |
-
"""Analyze text for age-appropriate language."""
|
674 |
-
prompt_template = ChatPromptTemplate.from_template(
|
675 |
-
"""
|
676 |
-
You are an expert in age-appropriate communication. Analyze the following text for age appropriateness:
|
677 |
-
|
678 |
-
Text: {text}
|
679 |
-
|
680 |
-
Target Age: {age}
|
681 |
-
|
682 |
-
Provide an analysis of:
|
683 |
-
1. Age appropriateness
|
684 |
-
2. Complexity level
|
685 |
-
3. Suggestions for improvement
|
686 |
-
|
687 |
-
Output in the following format:
|
688 |
-
```json
|
689 |
-
{{
|
690 |
-
"age_appropriateness_score": 0-10,
|
691 |
-
"complexity_level": "Simple/Moderate/Complex",
|
692 |
-
"issues_detected": ["issue1", "issue2"],
|
693 |
-
"suggestions": ["suggestion1", "suggestion2"],
|
694 |
-
"explanation": "Brief explanation of analysis"
|
695 |
-
}}
|
696 |
-
```
|
697 |
-
"""
|
698 |
-
)
|
699 |
-
|
700 |
-
analysis_chain = (
|
701 |
-
{
|
702 |
-
"text": RunnablePassthrough(),
|
703 |
-
"age": lambda x: str(age) if age else "Adult"
|
704 |
-
}
|
705 |
-
| prompt_template
|
706 |
-
| llm
|
707 |
-
)
|
708 |
-
|
709 |
-
try:
|
710 |
-
response = analysis_chain.invoke(text)
|
711 |
-
return eval(response.content.strip("```json\n").strip("\n```"))
|
712 |
-
except Exception as e:
|
713 |
-
logger.error(f"Error analyzing age appropriateness: {e}")
|
714 |
-
raise HTTPException(status_code=500, detail=f"Error analyzing age appropriateness: {str(e)}")
|
715 |
-
|
716 |
-
@app.post("/analyze/sensitivity")
|
717 |
-
async def analyze_text_sensitivity(request: AnalysisRequest):
|
718 |
-
"""Analyze text for cultural sensitivity and age appropriateness."""
|
719 |
-
try:
|
720 |
-
cultural_analysis = analyze_cultural_sensitivity(request.text, request.cultural_context)
|
721 |
-
age_analysis = analyze_age_appropriateness(request.text, request.patient_age)
|
722 |
-
|
723 |
-
return {
|
724 |
-
"text": request.text,
|
725 |
-
"cultural_analysis": cultural_analysis,
|
726 |
-
"age_analysis": age_analysis
|
727 |
-
}
|
728 |
-
except Exception as e:
|
729 |
-
logger.error(f"Error analyzing text sensitivity: {e}")
|
730 |
-
raise HTTPException(status_code=500, detail=f"Error analyzing text sensitivity: {str(e)}")
|
731 |
-
|
732 |
-
# Guided Intervention Workflows
|
733 |
-
@traceable(run_type="chain", name="Generate Intervention")
|
734 |
-
def generate_intervention_workflow(issue: str, intervention_type: Optional[str] = None, background: Optional[Dict] = None):
|
735 |
-
"""Generate a structured intervention workflow for a specific issue."""
|
736 |
-
prompt_template = ChatPromptTemplate.from_template(
|
737 |
-
"""
|
738 |
-
You are an expert mental health counselor. Generate a structured intervention workflow for the following patient issue:
|
739 |
-
|
740 |
-
Patient Issue: {issue}
|
741 |
-
|
742 |
-
Intervention Type: {intervention_type}
|
743 |
-
Patient Background: {background}
|
744 |
-
|
745 |
-
Provide a step-by-step intervention plan based on evidence-based practices. Include:
|
746 |
-
1. Initial assessment questions
|
747 |
-
2. Specific techniques to apply
|
748 |
-
3. Homework or practice exercises
|
749 |
-
4. Follow-up guidance
|
750 |
-
|
751 |
-
Output in the following format:
|
752 |
-
```json
|
753 |
-
{{
|
754 |
-
"intervention_type": "CBT/DBT/ACT/Mindfulness/etc.",
|
755 |
-
"assessment_questions": ["question1", "question2", "question3"],
|
756 |
-
"techniques": [
|
757 |
-
{{
|
758 |
-
"name": "technique name",
|
759 |
-
"description": "brief description",
|
760 |
-
"instructions": "step-by-step instructions"
|
761 |
-
}}
|
762 |
-
],
|
763 |
-
"exercises": [
|
764 |
-
{{
|
765 |
-
"name": "exercise name",
|
766 |
-
"description": "brief description",
|
767 |
-
"instructions": "step-by-step instructions"
|
768 |
-
}}
|
769 |
-
],
|
770 |
-
"follow_up": ["follow-up step 1", "follow-up step 2"],
|
771 |
-
"resources": ["resource1", "resource2"]
|
772 |
-
}}
|
773 |
-
```
|
774 |
-
"""
|
775 |
-
)
|
776 |
-
|
777 |
-
intervention_chain = (
|
778 |
-
{
|
779 |
-
"issue": RunnablePassthrough(),
|
780 |
-
"intervention_type": lambda x: intervention_type if intervention_type else "Best fit",
|
781 |
-
"background": lambda x: str(background) if background else "Not provided"
|
782 |
-
}
|
783 |
-
| prompt_template
|
784 |
-
| llm
|
785 |
-
)
|
786 |
-
|
787 |
-
try:
|
788 |
-
response = intervention_chain.invoke(issue)
|
789 |
-
return eval(response.content.strip("```json\n").strip("\n```"))
|
790 |
-
except Exception as e:
|
791 |
-
logger.error(f"Error generating intervention workflow: {e}")
|
792 |
-
raise HTTPException(status_code=500, detail=f"Error generating intervention workflow: {str(e)}")
|
793 |
-
|
794 |
-
@app.post("/interventions/generate")
|
795 |
-
async def get_intervention_workflow(request: InterventionRequest):
|
796 |
-
"""Get a structured intervention workflow for a specific patient issue."""
|
797 |
-
try:
|
798 |
-
intervention = generate_intervention_workflow(
|
799 |
-
request.patient_issue,
|
800 |
-
request.intervention_type,
|
801 |
-
request.patient_background
|
802 |
-
)
|
803 |
-
|
804 |
-
return {
|
805 |
-
"patient_issue": request.patient_issue,
|
806 |
-
"intervention": intervention
|
807 |
-
}
|
808 |
-
except Exception as e:
|
809 |
-
logger.error(f"Error generating intervention workflow: {e}")
|
810 |
-
raise HTTPException(status_code=500, detail=f"Error generating intervention workflow: {str(e)}")
|
811 |
-
|
812 |
@app.get("/health")
|
813 |
async def health_check():
|
814 |
-
|
815 |
-
return {"status": "healthy", "message": "All models and vector store loaded successfully"}
|
816 |
-
logger.error("Health check failed: One or more components not loaded")
|
817 |
-
raise HTTPException(status_code=500, detail="One or more components failed to load")
|
818 |
|
|
|
819 |
@app.get("/metadata")
|
820 |
async def get_metadata():
|
821 |
-
|
822 |
-
|
823 |
-
|
824 |
-
|
825 |
-
|
826 |
-
|
827 |
-
|
828 |
-
|
829 |
-
|
830 |
-
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
if not profile.created_at:
|
835 |
-
profile.created_at = datetime.now()
|
836 |
-
|
837 |
-
profile.updated_at = datetime.now()
|
838 |
-
|
839 |
-
file_path = os.path.join(DATA_DIR, "users", f"{profile.user_id}.json")
|
840 |
-
with open(file_path, "w") as f:
|
841 |
-
# Convert datetime to string for JSON serialization
|
842 |
-
profile_dict = profile.dict()
|
843 |
-
for key in ["created_at", "updated_at"]:
|
844 |
-
if profile_dict[key]:
|
845 |
-
profile_dict[key] = profile_dict[key].isoformat()
|
846 |
-
f.write(json.dumps(profile_dict, indent=2))
|
847 |
-
|
848 |
-
return profile
|
849 |
-
|
850 |
-
def get_user_profile(user_id: str) -> Optional[UserProfile]:
|
851 |
-
file_path = os.path.join(DATA_DIR, "users", f"{user_id}.json")
|
852 |
-
if not os.path.exists(file_path):
|
853 |
-
return None
|
854 |
-
|
855 |
-
with open(file_path, "r") as f:
|
856 |
-
data = json.loads(f.read())
|
857 |
-
# Convert string dates back to datetime
|
858 |
-
for key in ["created_at", "updated_at"]:
|
859 |
-
if data[key]:
|
860 |
-
data[key] = datetime.fromisoformat(data[key])
|
861 |
-
return UserProfile(**data)
|
862 |
-
|
863 |
-
def save_session(session: SessionData):
|
864 |
-
if not session.session_id:
|
865 |
-
session.session_id = str(uuid4())
|
866 |
-
|
867 |
-
if not session.created_at:
|
868 |
-
session.created_at = datetime.now()
|
869 |
-
|
870 |
-
session.updated_at = datetime.now()
|
871 |
-
|
872 |
-
file_path = os.path.join(DATA_DIR, "sessions", f"{session.session_id}.json")
|
873 |
-
with open(file_path, "w") as f:
|
874 |
-
# Convert datetime to string for JSON serialization
|
875 |
-
session_dict = session.dict()
|
876 |
-
for key in ["created_at", "updated_at"]:
|
877 |
-
if session_dict[key]:
|
878 |
-
session_dict[key] = session_dict[key].isoformat()
|
879 |
-
f.write(json.dumps(session_dict, indent=2))
|
880 |
-
|
881 |
-
return session
|
882 |
-
|
883 |
-
def get_session(session_id: str) -> Optional[SessionData]:
|
884 |
-
file_path = os.path.join(DATA_DIR, "sessions", f"{session_id}.json")
|
885 |
-
if not os.path.exists(file_path):
|
886 |
-
return None
|
887 |
-
|
888 |
-
with open(file_path, "r") as f:
|
889 |
-
data = json.loads(f.read())
|
890 |
-
# Convert string dates back to datetime
|
891 |
-
for key in ["created_at", "updated_at"]:
|
892 |
-
if data[key]:
|
893 |
-
data[key] = datetime.fromisoformat(data[key])
|
894 |
-
return SessionData(**data)
|
895 |
-
|
896 |
-
def get_user_sessions(counselor_id: str) -> List[SessionData]:
|
897 |
-
sessions = []
|
898 |
-
sessions_dir = os.path.join(DATA_DIR, "sessions")
|
899 |
-
for filename in os.listdir(sessions_dir):
|
900 |
-
if not filename.endswith(".json"):
|
901 |
-
continue
|
902 |
-
|
903 |
-
file_path = os.path.join(sessions_dir, filename)
|
904 |
-
with open(file_path, "r") as f:
|
905 |
-
data = json.loads(f.read())
|
906 |
-
if data["counselor_id"] == counselor_id:
|
907 |
-
for key in ["created_at", "updated_at"]:
|
908 |
-
if data[key]:
|
909 |
-
data[key] = datetime.fromisoformat(data[key])
|
910 |
-
sessions.append(SessionData(**data))
|
911 |
-
|
912 |
-
return sessions
|
913 |
-
|
914 |
-
def save_conversation_entry(entry: ConversationEntry):
|
915 |
-
conversation_dir = os.path.join(DATA_DIR, "conversations", entry.session_id)
|
916 |
-
os.makedirs(conversation_dir, exist_ok=True)
|
917 |
-
|
918 |
-
if not entry.timestamp:
|
919 |
-
entry.timestamp = datetime.now()
|
920 |
-
|
921 |
-
entry_id = str(uuid4())
|
922 |
-
file_path = os.path.join(conversation_dir, f"{entry_id}.json")
|
923 |
-
|
924 |
-
with open(file_path, "w") as f:
|
925 |
-
# Convert datetime to string for JSON serialization
|
926 |
-
entry_dict = entry.dict()
|
927 |
-
entry_dict["entry_id"] = entry_id
|
928 |
-
if entry_dict["timestamp"]:
|
929 |
-
entry_dict["timestamp"] = entry_dict["timestamp"].isoformat()
|
930 |
-
f.write(json.dumps(entry_dict, indent=2))
|
931 |
-
|
932 |
-
return entry_id
|
933 |
-
|
934 |
-
def get_conversation_history(session_id: str) -> List[ConversationEntry]:
|
935 |
-
conversation_dir = os.path.join(DATA_DIR, "conversations", session_id)
|
936 |
-
if not os.path.exists(conversation_dir):
|
937 |
-
return []
|
938 |
-
|
939 |
-
entries = []
|
940 |
-
for filename in os.listdir(conversation_dir):
|
941 |
-
if not filename.endswith(".json"):
|
942 |
-
continue
|
943 |
-
|
944 |
-
file_path = os.path.join(conversation_dir, filename)
|
945 |
-
with open(file_path, "r") as f:
|
946 |
-
data = json.loads(f.read())
|
947 |
-
if data["timestamp"]:
|
948 |
-
data["timestamp"] = datetime.fromisoformat(data["timestamp"])
|
949 |
-
entries.append(ConversationEntry(**data))
|
950 |
-
|
951 |
-
# Sort by timestamp
|
952 |
-
entries.sort(key=lambda x: x.timestamp)
|
953 |
-
return entries
|
954 |
-
|
955 |
-
def save_feedback(feedback: FeedbackData):
|
956 |
-
feedback_id = str(uuid4())
|
957 |
-
file_path = os.path.join(DATA_DIR, "feedback", f"{feedback_id}.json")
|
958 |
-
|
959 |
-
with open(file_path, "w") as f:
|
960 |
-
feedback_dict = feedback.dict()
|
961 |
-
feedback_dict["feedback_id"] = feedback_id
|
962 |
-
feedback_dict["timestamp"] = datetime.now().isoformat()
|
963 |
-
f.write(json.dumps(feedback_dict, indent=2))
|
964 |
-
|
965 |
-
return feedback_id
|
966 |
-
|
967 |
-
# Multi-modal Input Support
|
968 |
-
@app.post("/input/process")
|
969 |
-
async def process_multimodal_input(input_data: MultiModalInput):
|
970 |
-
"""Process multi-modal input (text, audio, video)."""
|
971 |
-
try:
|
972 |
-
if input_data.input_type not in ["text", "audio", "video"]:
|
973 |
-
raise HTTPException(status_code=400, detail=f"Unsupported input type: {input_data.input_type}")
|
974 |
-
|
975 |
-
# For now, handle text directly and simulate processing for audio/video
|
976 |
-
if input_data.input_type == "text":
|
977 |
-
# Process text normally
|
978 |
-
prediction = predict_response_type(input_data.content)
|
979 |
-
|
980 |
-
return {
|
981 |
-
"input_type": "text",
|
982 |
-
"processed_content": input_data.content,
|
983 |
-
"analysis": {
|
984 |
-
"response_type": prediction["response_type"],
|
985 |
-
"crisis_flag": prediction["crisis_flag"],
|
986 |
-
"confidence": prediction["confidence"]
|
987 |
-
}
|
988 |
-
}
|
989 |
-
elif input_data.input_type == "audio":
|
990 |
-
# Simulate audio transcription and emotion detection
|
991 |
-
# In a production system, this would use a speech-to-text API and emotion analysis
|
992 |
-
prompt_template = ChatPromptTemplate.from_template(
|
993 |
-
"""
|
994 |
-
Simulate audio processing for this description: {content}
|
995 |
-
|
996 |
-
Generate a simulated transcription and emotion detection as if this were real audio.
|
997 |
-
|
998 |
-
Output in the following format:
|
999 |
-
```json
|
1000 |
-
{{
|
1001 |
-
"transcription": "Simulated transcription of the audio",
|
1002 |
-
"emotion_detected": "primary emotion",
|
1003 |
-
"secondary_emotions": ["emotion1", "emotion2"],
|
1004 |
-
"confidence": 0.85
|
1005 |
-
}}
|
1006 |
-
```
|
1007 |
-
"""
|
1008 |
-
)
|
1009 |
-
|
1010 |
-
process_chain = prompt_template | llm
|
1011 |
-
response = process_chain.invoke({"content": input_data.content})
|
1012 |
-
audio_result = eval(response.content.strip("```json\n").strip("\n```"))
|
1013 |
-
|
1014 |
-
# Now process the transcription
|
1015 |
-
prediction = predict_response_type(audio_result["transcription"])
|
1016 |
-
|
1017 |
-
return {
|
1018 |
-
"input_type": "audio",
|
1019 |
-
"processed_content": audio_result["transcription"],
|
1020 |
-
"emotion_analysis": {
|
1021 |
-
"primary_emotion": audio_result["emotion_detected"],
|
1022 |
-
"secondary_emotions": audio_result["secondary_emotions"],
|
1023 |
-
"confidence": audio_result["confidence"]
|
1024 |
-
},
|
1025 |
-
"analysis": {
|
1026 |
-
"response_type": prediction["response_type"],
|
1027 |
-
"crisis_flag": prediction["crisis_flag"],
|
1028 |
-
"confidence": prediction["confidence"]
|
1029 |
-
}
|
1030 |
-
}
|
1031 |
-
elif input_data.input_type == "video":
|
1032 |
-
# Simulate video analysis
|
1033 |
-
# In a production system, this would use video analytics API
|
1034 |
-
prompt_template = ChatPromptTemplate.from_template(
|
1035 |
-
"""
|
1036 |
-
Simulate video processing for this description: {content}
|
1037 |
-
|
1038 |
-
Generate a simulated analysis as if this were real video with facial expressions and body language.
|
1039 |
-
|
1040 |
-
Output in the following format:
|
1041 |
-
```json
|
1042 |
-
{{
|
1043 |
-
"transcription": "Simulated transcription of speech in the video",
|
1044 |
-
"facial_expressions": ["expression1", "expression2"],
|
1045 |
-
"body_language": ["posture observation", "gesture observation"],
|
1046 |
-
"primary_emotion": "primary emotion",
|
1047 |
-
"confidence": 0.80
|
1048 |
-
}}
|
1049 |
-
```
|
1050 |
-
"""
|
1051 |
-
)
|
1052 |
-
|
1053 |
-
process_chain = prompt_template | llm
|
1054 |
-
response = process_chain.invoke({"content": input_data.content})
|
1055 |
-
video_result = eval(response.content.strip("```json\n").strip("\n```"))
|
1056 |
-
|
1057 |
-
# Process the transcription
|
1058 |
-
prediction = predict_response_type(video_result["transcription"])
|
1059 |
-
|
1060 |
-
return {
|
1061 |
-
"input_type": "video",
|
1062 |
-
"processed_content": video_result["transcription"],
|
1063 |
-
"nonverbal_analysis": {
|
1064 |
-
"facial_expressions": video_result["facial_expressions"],
|
1065 |
-
"body_language": video_result["body_language"],
|
1066 |
-
"primary_emotion": video_result["primary_emotion"],
|
1067 |
-
"confidence": video_result["confidence"]
|
1068 |
-
},
|
1069 |
-
"analysis": {
|
1070 |
-
"response_type": prediction["response_type"],
|
1071 |
-
"crisis_flag": prediction["crisis_flag"],
|
1072 |
-
"confidence": prediction["confidence"]
|
1073 |
-
}
|
1074 |
-
}
|
1075 |
-
except Exception as e:
|
1076 |
-
logger.error(f"Error processing multimodal input: {e}")
|
1077 |
-
raise HTTPException(status_code=500, detail=f"Error processing multimodal input: {str(e)}")
|
1078 |
|
1079 |
-
#
|
1080 |
-
|
1081 |
-
|
1082 |
-
|
1083 |
-
|
1084 |
-
"""
|
1085 |
-
You are an expert mental health professional with extensive knowledge of therapeutic techniques. Based on the following patient context, suggest therapeutic techniques that would be appropriate:
|
1086 |
-
|
1087 |
-
Patient Context: {context}
|
1088 |
-
|
1089 |
-
Technique Type (if specified): {technique_type}
|
1090 |
-
|
1091 |
-
Suggest specific therapeutic techniques, exercises, or interventions that would be helpful for this patient. Include:
|
1092 |
-
1. Name of technique
|
1093 |
-
2. Brief description
|
1094 |
-
3. How to apply it in this specific case
|
1095 |
-
4. Expected benefits
|
1096 |
-
|
1097 |
-
Provide a range of options from different therapeutic approaches (CBT, DBT, ACT, mindfulness, motivational interviewing, etc.) unless a specific technique type was requested.
|
1098 |
-
|
1099 |
-
Output in the following format:
|
1100 |
-
```json
|
1101 |
-
{{
|
1102 |
-
"primary_approach": "The most appropriate therapeutic approach",
|
1103 |
-
"techniques": [
|
1104 |
-
{{
|
1105 |
-
"name": "Technique name",
|
1106 |
-
"approach": "CBT/DBT/ACT/etc.",
|
1107 |
-
"description": "Brief description",
|
1108 |
-
"application": "How to apply to this specific case",
|
1109 |
-
"benefits": "Expected benefits"
|
1110 |
-
}}
|
1111 |
-
],
|
1112 |
-
"rationale": "Brief explanation of why these techniques were selected"
|
1113 |
-
}}
|
1114 |
-
```
|
1115 |
-
"""
|
1116 |
-
)
|
1117 |
|
1118 |
-
|
1119 |
-
|
1120 |
-
"context": RunnablePassthrough(),
|
1121 |
-
"technique_type": lambda x: technique_type if technique_type else "Any appropriate"
|
1122 |
-
}
|
1123 |
-
| prompt_template
|
1124 |
-
| llm
|
1125 |
-
)
|
1126 |
|
1127 |
-
|
1128 |
-
|
1129 |
-
|
1130 |
-
except Exception as e:
|
1131 |
-
logger.error(f"Error suggesting therapeutic techniques: {e}")
|
1132 |
-
raise HTTPException(status_code=500, detail=f"Error suggesting therapeutic techniques: {str(e)}")
|
1133 |
-
|
1134 |
-
@app.post("/techniques/suggest")
|
1135 |
-
async def get_therapeutic_techniques(request: dict):
|
1136 |
-
"""Get suggested therapeutic techniques for a patient context."""
|
1137 |
-
try:
|
1138 |
-
context = request.get("context")
|
1139 |
-
if not context:
|
1140 |
-
raise HTTPException(status_code=400, detail="context is required")
|
1141 |
-
|
1142 |
-
technique_type = request.get("technique_type")
|
1143 |
-
|
1144 |
-
techniques = suggest_therapeutic_techniques(context, technique_type)
|
1145 |
-
|
1146 |
return {
|
1147 |
-
"
|
1148 |
-
"
|
|
|
1149 |
}
|
1150 |
-
|
1151 |
-
|
1152 |
-
|
1153 |
-
|
1154 |
-
|
1155 |
-
@app.post("/suggest/with_confidence")
|
1156 |
-
async def get_suggestion_with_confidence(context: PatientContext):
|
1157 |
-
"""Get suggestion with detailed confidence indicators and uncertainty flags."""
|
1158 |
-
try:
|
1159 |
-
# Get standard prediction
|
1160 |
-
prediction = predict_response_type(context.context)
|
1161 |
-
|
1162 |
-
# Set confidence thresholds
|
1163 |
-
high_confidence = 0.8
|
1164 |
-
medium_confidence = 0.6
|
1165 |
-
|
1166 |
-
# Determine confidence level
|
1167 |
-
confidence_value = prediction["confidence"]
|
1168 |
-
if confidence_value >= high_confidence:
|
1169 |
-
confidence_level = "High"
|
1170 |
-
elif confidence_value >= medium_confidence:
|
1171 |
-
confidence_level = "Medium"
|
1172 |
-
else:
|
1173 |
-
confidence_level = "Low"
|
1174 |
-
|
1175 |
-
# Analyze for potential biases
|
1176 |
-
bias_prompt = ChatPromptTemplate.from_template(
|
1177 |
-
"""
|
1178 |
-
You are an AI ethics expert. Analyze the following patient context and proposed response type for potential biases:
|
1179 |
-
|
1180 |
-
Patient Context: {context}
|
1181 |
-
Predicted Response Type: {response_type}
|
1182 |
-
|
1183 |
-
Identify any potential biases in interpretation or response. Consider gender, cultural, socioeconomic, and other potential biases.
|
1184 |
-
|
1185 |
-
Output in the following format:
|
1186 |
-
```json
|
1187 |
-
{{
|
1188 |
-
"bias_detected": true/false,
|
1189 |
-
"bias_types": ["bias type 1", "bias type 2"],
|
1190 |
-
"explanation": "Brief explanation of potential biases"
|
1191 |
-
}}
|
1192 |
-
```
|
1193 |
-
"""
|
1194 |
-
)
|
1195 |
-
|
1196 |
-
bias_chain = (
|
1197 |
-
{
|
1198 |
-
"context": lambda x: context.context,
|
1199 |
-
"response_type": lambda x: prediction["response_type"]
|
1200 |
-
}
|
1201 |
-
| bias_prompt
|
1202 |
-
| llm
|
1203 |
-
)
|
1204 |
-
|
1205 |
-
bias_analysis = eval(bias_chain.invoke({}).content.strip("```json\n").strip("\n```"))
|
1206 |
-
|
1207 |
-
# Generate suggestions
|
1208 |
-
suggestion_rag = generate_suggestion_rag(context.context, prediction["response_type"], prediction["crisis_flag"])
|
1209 |
-
suggestion_direct = generate_suggestion_direct(context.context, prediction["response_type"], prediction["crisis_flag"])
|
1210 |
-
|
1211 |
-
return {
|
1212 |
-
"context": context.context,
|
1213 |
-
"response_type": prediction["response_type"],
|
1214 |
-
"crisis_flag": prediction["crisis_flag"],
|
1215 |
-
"confidence": {
|
1216 |
-
"value": prediction["confidence"],
|
1217 |
-
"level": confidence_level,
|
1218 |
-
"uncertainty_flag": confidence_level == "Low"
|
1219 |
-
},
|
1220 |
-
"bias_analysis": bias_analysis,
|
1221 |
-
"rag_suggestion": suggestion_rag["suggested_response"],
|
1222 |
-
"rag_risk_level": suggestion_rag["risk_level"],
|
1223 |
-
"direct_suggestion": suggestion_direct["suggested_response"],
|
1224 |
-
"direct_risk_level": suggestion_direct["risk_level"],
|
1225 |
-
"attribution": {
|
1226 |
-
"ai_generated": True,
|
1227 |
-
"model_version": "Mental Health Counselor API v2.0",
|
1228 |
-
"human_reviewed": False
|
1229 |
-
}
|
1230 |
-
}
|
1231 |
-
except Exception as e:
|
1232 |
-
logger.error(f"Error getting suggestion with confidence: {e}")
|
1233 |
-
raise HTTPException(status_code=500, detail=f"Error getting suggestion with confidence: {str(e)}")
|
1234 |
-
|
1235 |
-
# Text to Speech with Eleven Labs API
|
1236 |
-
@app.post("/api/text-to-speech")
|
1237 |
-
async def text_to_speech(request: dict):
|
1238 |
-
"""Convert text to speech using Eleven Labs API."""
|
1239 |
-
try:
|
1240 |
-
text = request.get("text")
|
1241 |
-
voice_id = request.get("voice_id", "pNInz6obpgDQGcFmaJgB") # Default to "Adam" voice
|
1242 |
-
|
1243 |
-
if not text:
|
1244 |
-
raise HTTPException(status_code=400, detail="Text is required")
|
1245 |
-
|
1246 |
-
# Get API key from environment
|
1247 |
-
api_key = os.getenv("ELEVEN_API_KEY")
|
1248 |
-
if not api_key:
|
1249 |
-
raise HTTPException(status_code=500, detail="Eleven Labs API key not configured")
|
1250 |
-
|
1251 |
-
# Prepare the request to Eleven Labs
|
1252 |
-
url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"
|
1253 |
-
|
1254 |
-
headers = {
|
1255 |
-
"Accept": "audio/mpeg",
|
1256 |
-
"Content-Type": "application/json",
|
1257 |
-
"xi-api-key": api_key
|
1258 |
-
}
|
1259 |
-
|
1260 |
-
payload = {
|
1261 |
-
"text": text,
|
1262 |
-
"model_id": "eleven_multilingual_v2",
|
1263 |
-
"voice_settings": {
|
1264 |
-
"stability": 0.5,
|
1265 |
-
"similarity_boost": 0.75
|
1266 |
-
}
|
1267 |
-
}
|
1268 |
-
|
1269 |
-
# Make the request to Eleven Labs
|
1270 |
-
response = requests.post(url, json=payload, headers=headers)
|
1271 |
-
|
1272 |
-
if response.status_code != 200:
|
1273 |
-
logger.error(f"Eleven Labs API error: {response.text}")
|
1274 |
-
raise HTTPException(status_code=response.status_code,
|
1275 |
-
detail=f"Eleven Labs API error: {response.text}")
|
1276 |
-
|
1277 |
-
# Return audio as streaming response
|
1278 |
-
return StreamingResponse(
|
1279 |
-
BytesIO(response.content),
|
1280 |
-
media_type="audio/mpeg"
|
1281 |
-
)
|
1282 |
-
|
1283 |
-
except Exception as e:
|
1284 |
-
logger.error(f"Error in text-to-speech: {str(e)}")
|
1285 |
-
if not isinstance(e, HTTPException):
|
1286 |
-
raise HTTPException(status_code=500, detail=f"Text-to-speech error: {str(e)}")
|
1287 |
-
raise e
|
1288 |
-
|
1289 |
-
# Multimedia file processing (speech to text)
|
1290 |
-
@app.post("/api/input/process")
|
1291 |
-
async def process_audio_input(
|
1292 |
-
audio: UploadFile = File(...),
|
1293 |
-
session_id: str = Form(...)
|
1294 |
-
):
|
1295 |
-
"""Process audio input for speech-to-text using Eleven Labs."""
|
1296 |
-
try:
|
1297 |
-
# Get API key from environment
|
1298 |
-
api_key = os.getenv("ELEVEN_API_KEY")
|
1299 |
-
if not api_key:
|
1300 |
-
raise HTTPException(status_code=500, detail="Eleven Labs API key not configured")
|
1301 |
-
|
1302 |
-
# Read the audio file content
|
1303 |
-
audio_content = await audio.read()
|
1304 |
-
|
1305 |
-
# Call Eleven Labs Speech-to-Text API
|
1306 |
-
url = "https://api.elevenlabs.io/v1/speech-to-text"
|
1307 |
-
|
1308 |
-
headers = {
|
1309 |
-
"xi-api-key": api_key
|
1310 |
-
}
|
1311 |
-
|
1312 |
-
# Create form data with the audio file
|
1313 |
-
files = {
|
1314 |
-
'audio': ('audio.webm', audio_content, 'audio/webm')
|
1315 |
-
}
|
1316 |
-
|
1317 |
-
data = {
|
1318 |
-
'model_id': 'whisper-1' # Using Whisper model
|
1319 |
-
}
|
1320 |
-
|
1321 |
-
# Make the request to Eleven Labs
|
1322 |
-
response = requests.post(url, headers=headers, files=files, data=data)
|
1323 |
-
|
1324 |
-
if response.status_code != 200:
|
1325 |
-
logger.error(f"Eleven Labs API error: {response.text}")
|
1326 |
-
raise HTTPException(status_code=response.status_code,
|
1327 |
-
detail=f"Eleven Labs API error: {response.text}")
|
1328 |
-
|
1329 |
-
result = response.json()
|
1330 |
-
|
1331 |
-
# Extract the transcribed text
|
1332 |
-
text = result.get('text', '')
|
1333 |
-
|
1334 |
-
# Return the transcribed text
|
1335 |
return {
|
1336 |
-
"
|
1337 |
-
"
|
|
|
1338 |
}
|
1339 |
-
|
1340 |
-
except Exception as e:
|
1341 |
-
logger.error(f"Error processing audio: {str(e)}")
|
1342 |
-
if not isinstance(e, HTTPException):
|
1343 |
-
raise HTTPException(status_code=500, detail=f"Audio processing error: {str(e)}")
|
1344 |
-
raise e
|
1345 |
-
|
1346 |
-
# Add a custom encoder for bytes objects to prevent UTF-8 decode errors
|
1347 |
-
def custom_encoder(obj):
|
1348 |
-
if isinstance(obj, bytes):
|
1349 |
-
try:
|
1350 |
-
return obj.decode('utf-8')
|
1351 |
-
except UnicodeDecodeError:
|
1352 |
-
return base64.b64encode(obj).decode('ascii')
|
1353 |
-
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
|
1354 |
-
|
1355 |
-
# Override the jsonable_encoder function to handle bytes properly
|
1356 |
-
from fastapi.encoders import jsonable_encoder as original_jsonable_encoder
|
1357 |
-
|
1358 |
-
def safe_jsonable_encoder(*args, **kwargs):
|
1359 |
-
try:
|
1360 |
-
return original_jsonable_encoder(*args, **kwargs)
|
1361 |
-
except UnicodeDecodeError:
|
1362 |
-
# If the standard encoder fails with a decode error,
|
1363 |
-
# ensure all bytes are properly handled
|
1364 |
-
if args and isinstance(args[0], bytes):
|
1365 |
-
return custom_encoder(args[0])
|
1366 |
-
raise
|
1367 |
-
|
1368 |
-
# Monkey patch the jsonable_encoder in FastAPI
|
1369 |
-
import fastapi.encoders
|
1370 |
-
fastapi.encoders.jsonable_encoder = safe_jsonable_encoder
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException
|
|
|
2 |
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
import os
|
4 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
5 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
# Set up logging
|
8 |
logging.basicConfig(level=logging.INFO)
|
|
|
11 |
# Load environment variables
|
12 |
load_dotenv()
|
13 |
|
|
|
|
|
|
|
|
|
|
|
14 |
# Initialize FastAPI app
|
15 |
app = FastAPI(title="Mental Health Counselor API")
|
16 |
|
17 |
+
# Initialize global storage
|
18 |
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
19 |
os.makedirs(DATA_DIR, exist_ok=True)
|
20 |
os.makedirs(os.path.join(DATA_DIR, "users"), exist_ok=True)
|
|
|
22 |
os.makedirs(os.path.join(DATA_DIR, "conversations"), exist_ok=True)
|
23 |
os.makedirs(os.path.join(DATA_DIR, "feedback"), exist_ok=True)
|
24 |
|
25 |
+
# Simple health check route
|
26 |
+
@app.get("/")
|
27 |
+
async def root():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
return {
|
29 |
+
"status": "ok",
|
30 |
+
"message": "Mental Health Counselor API is running",
|
31 |
+
"api_version": "1.0.0",
|
32 |
+
"backend_info": "FastAPI on Hugging Face Spaces"
|
|
|
|
|
|
|
|
|
33 |
}
|
34 |
|
35 |
+
# Health check endpoint
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
@app.get("/health")
|
37 |
async def health_check():
|
38 |
+
return {"status": "ok", "message": "Mental Health Counselor API is running"}
|
|
|
|
|
|
|
39 |
|
40 |
+
# Metadata endpoint
|
41 |
@app.get("/metadata")
|
42 |
async def get_metadata():
|
43 |
+
return {
|
44 |
+
"api_version": "1.0.0",
|
45 |
+
"endpoints": [
|
46 |
+
"/",
|
47 |
+
"/health",
|
48 |
+
"/metadata"
|
49 |
+
],
|
50 |
+
"provider": "Mental Health Counselor API on Hugging Face Spaces",
|
51 |
+
"deployment_type": "Hugging Face Spaces Docker",
|
52 |
+
"description": "This API provides functionality for a mental health counseling application.",
|
53 |
+
"frontend": "Deployed separately on Vercel"
|
54 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
+
# Try to import the full API if available
|
57 |
+
try:
|
58 |
+
import api_mental_health
|
59 |
+
# If the import succeeds, mount all routes from the API
|
60 |
+
logger.info("Successfully imported full API module")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
+
# Include all routes from the API
|
63 |
+
app.include_router(api_mental_health.app, prefix="")
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
+
# Add a status endpoint for the full API
|
66 |
+
@app.get("/full-api-status")
|
67 |
+
async def full_api_status():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
return {
|
69 |
+
"status": "active",
|
70 |
+
"message": "Full API module is active and all routes are mounted",
|
71 |
+
"routes_count": len(api_mental_health.app.routes)
|
72 |
}
|
73 |
+
except ImportError as e:
|
74 |
+
logger.warning(f"Could not import full API module: {e}")
|
75 |
+
|
76 |
+
@app.get("/full-api-status")
|
77 |
+
async def full_api_status():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
return {
|
79 |
+
"status": "unavailable",
|
80 |
+
"message": "Full API module could not be imported",
|
81 |
+
"error": str(e)
|
82 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -1,38 +1,37 @@
|
|
1 |
# Core dependencies for data processing and ML
|
2 |
-
pandas
|
3 |
-
numpy
|
4 |
-
scikit-learn
|
5 |
-
joblib
|
6 |
|
7 |
# NLP and sentiment analysis
|
8 |
-
nltk
|
9 |
-
vaderSentiment
|
10 |
|
11 |
# Dataset downloading
|
12 |
-
kagglehub
|
13 |
|
14 |
# Vector database and embeddings
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
langchain
|
19 |
-
langchain-
|
20 |
-
|
21 |
|
22 |
# API and tracing
|
23 |
-
fastapi
|
24 |
-
uvicorn
|
25 |
-
pydantic
|
26 |
-
langsmith
|
27 |
-
python-dotenv
|
28 |
-
|
29 |
-
lightgbm>=4.0.0
|
30 |
|
31 |
# New dependencies for additional features
|
32 |
-
python-multipart
|
33 |
fastapi-cors # For CORS support
|
34 |
-
aiofiles
|
35 |
-
jinja2
|
36 |
python-jose[cryptography] # For JWT tokens (authentication)
|
37 |
passlib[bcrypt] # For password hashing
|
38 |
pydub # For audio processing
|
|
|
1 |
# Core dependencies for data processing and ML
|
2 |
+
pandas
|
3 |
+
numpy
|
4 |
+
scikit-learn # Used for ML models
|
5 |
+
joblib
|
6 |
|
7 |
# NLP and sentiment analysis
|
8 |
+
nltk
|
9 |
+
vaderSentiment
|
10 |
|
11 |
# Dataset downloading
|
12 |
+
kagglehub
|
13 |
|
14 |
# Vector database and embeddings
|
15 |
+
chromadb
|
16 |
+
openai
|
17 |
+
langchain
|
18 |
+
langchain-openai
|
19 |
+
langchain-chroma
|
20 |
+
httpx # Required for API calls
|
21 |
|
22 |
# API and tracing
|
23 |
+
fastapi
|
24 |
+
uvicorn
|
25 |
+
pydantic
|
26 |
+
langsmith
|
27 |
+
python-dotenv
|
28 |
+
lightgbm
|
|
|
29 |
|
30 |
# New dependencies for additional features
|
31 |
+
python-multipart # For file uploads
|
32 |
fastapi-cors # For CORS support
|
33 |
+
aiofiles # For async file operations
|
34 |
+
jinja2 # For template rendering
|
35 |
python-jose[cryptography] # For JWT tokens (authentication)
|
36 |
passlib[bcrypt] # For password hashing
|
37 |
pydub # For audio processing
|