Said Lfagrouche committed on
Commit
e151d55
·
1 Parent(s): 8242927

Prepare for Hugging Face Spaces deployment with simplified configuration

Browse files
Files changed (3) hide show
  1. Dockerfile +5 -4
  2. app.py +1334 -42
  3. requirements.txt +23 -22
Dockerfile CHANGED
@@ -2,17 +2,18 @@ FROM python:3.9-slim
2
 
3
  WORKDIR /app
4
 
5
- # Install git and git-lfs for downloading large files (if needed)
6
  RUN apt-get update && \
7
- apt-get install -y git git-lfs build-essential && \
8
  apt-get clean && \
9
  rm -rf /var/lib/apt/lists/*
10
 
11
  # Copy requirements file
12
  COPY requirements.txt .
13
 
14
- # Install dependencies
15
- RUN pip install --no-cache-dir --upgrade -r requirements.txt
 
16
 
17
  # Download NLTK data
18
  RUN python -c "import nltk; nltk.download('punkt'); nltk.download('wordnet'); nltk.download('stopwords')"
 
2
 
3
  WORKDIR /app
4
 
5
+ # Install git, git-lfs, and build dependencies for native extensions
6
  RUN apt-get update && \
7
+ apt-get install -y git git-lfs build-essential cmake pkg-config libpq-dev gcc g++ && \
8
  apt-get clean && \
9
  rm -rf /var/lib/apt/lists/*
10
 
11
  # Copy requirements file
12
  COPY requirements.txt .
13
 
14
+ # Install dependencies with pip upgrade
15
+ RUN pip install --no-cache-dir --upgrade pip && \
16
+ pip install --no-cache-dir --upgrade -r requirements.txt
17
 
18
  # Download NLTK data
19
  RUN python -c "import nltk; nltk.download('punkt'); nltk.download('wordnet'); nltk.download('stopwords')"
app.py CHANGED
@@ -1,8 +1,34 @@
1
- from fastapi import FastAPI, HTTPException
 
2
  from pydantic import BaseModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import os
4
  from dotenv import load_dotenv
 
 
 
5
  import logging
 
 
 
 
 
 
 
 
6
 
7
  # Set up logging
8
  logging.basicConfig(level=logging.INFO)
@@ -11,10 +37,15 @@ logger = logging.getLogger(__name__)
11
  # Load environment variables
12
  load_dotenv()
13
 
 
 
 
 
 
14
  # Initialize FastAPI app
15
  app = FastAPI(title="Mental Health Counselor API")
16
 
17
- # Initialize global storage
18
  DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
19
  os.makedirs(DATA_DIR, exist_ok=True)
20
  os.makedirs(os.path.join(DATA_DIR, "users"), exist_ok=True)
@@ -22,57 +53,1318 @@ os.makedirs(os.path.join(DATA_DIR, "sessions"), exist_ok=True)
22
  os.makedirs(os.path.join(DATA_DIR, "conversations"), exist_ok=True)
23
  os.makedirs(os.path.join(DATA_DIR, "feedback"), exist_ok=True)
24
 
25
- # Simple health check route
26
- @app.get("/")
27
- async def root():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  return {
29
- "status": "ok",
30
- "message": "Mental Health Counselor API is running",
31
- "api_version": "1.0.0",
32
- "backend_info": "FastAPI on Hugging Face Spaces"
33
  }
34
 
35
- # Health check endpoint
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  @app.get("/health")
37
  async def health_check():
38
- return {"status": "ok", "message": "Mental Health Counselor API is running"}
 
 
 
39
 
40
- # Metadata endpoint
41
  @app.get("/metadata")
42
  async def get_metadata():
43
- return {
44
- "api_version": "1.0.0",
45
- "endpoints": [
46
- "/",
47
- "/health",
48
- "/metadata"
49
- ],
50
- "provider": "Mental Health Counselor API on Hugging Face Spaces",
51
- "deployment_type": "Hugging Face Spaces Docker",
52
- "description": "This API provides functionality for a mental health counseling application.",
53
- "frontend": "Deployed separately on Vercel"
54
- }
55
 
56
- # Try to import the full API if available
57
- try:
58
- import api_mental_health
59
- # If the import succeeds, try to add those routes
60
- logger.info("Successfully imported full API module")
61
 
62
- # Add a placeholder for full functionality
63
- @app.get("/full-api-status")
64
- async def full_api_status():
65
- return {
66
- "status": "imported",
67
- "message": "Full API module was imported successfully, but endpoints may require additional setup"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  }
69
- except ImportError as e:
70
- logger.warning(f"Could not import full API module: {e}")
 
71
 
72
- @app.get("/full-api-status")
73
- async def full_api_status():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  return {
75
- "status": "unavailable",
76
- "message": "Full API module could not be imported",
77
- "error": str(e)
78
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api_mental_health.py
2
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
3
  from pydantic import BaseModel
4
+ import pandas as pd
5
+ import numpy as np
6
+ import joblib
7
+ import re
8
+ import nltk
9
+ from nltk.tokenize import word_tokenize
10
+ from nltk.stem import WordNetLemmatizer
11
+ from nltk.corpus import stopwords
12
+ from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
13
+ import chromadb
14
+ from chromadb.config import Settings
15
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
16
+ from langchain_chroma import Chroma
17
+ from openai import OpenAI
18
  import os
19
  from dotenv import load_dotenv
20
+ from langsmith import Client, traceable
21
+ from langchain_core.runnables import RunnablePassthrough
22
+ from langchain_core.prompts import ChatPromptTemplate
23
  import logging
24
+ from typing import List, Dict, Optional, Any, Union, Annotated
25
+ from datetime import datetime
26
+ from uuid import uuid4, UUID
27
+ import json
28
+ import requests
29
+ from fastapi.responses import StreamingResponse
30
+ from io import BytesIO
31
+ import base64
32
 
33
  # Set up logging
34
  logging.basicConfig(level=logging.INFO)
 
37
  # Load environment variables
38
  load_dotenv()
39
 
40
+ # Download required NLTK data
41
+ nltk.download('punkt')
42
+ nltk.download('wordnet')
43
+ nltk.download('stopwords')
44
+
45
  # Initialize FastAPI app
46
  app = FastAPI(title="Mental Health Counselor API")
47
 
48
+ # Initialize global storage (to be replaced with proper database)
49
  DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
50
  os.makedirs(DATA_DIR, exist_ok=True)
51
  os.makedirs(os.path.join(DATA_DIR, "users"), exist_ok=True)
 
53
  os.makedirs(os.path.join(DATA_DIR, "conversations"), exist_ok=True)
54
  os.makedirs(os.path.join(DATA_DIR, "feedback"), exist_ok=True)
55
 
56
+ # Initialize components
57
# Shared NLP helpers, constructed once when the module is imported.
STOPWORDS = set(stopwords.words("english"))  # tokens dropped by clean_text()
lemmatizer = WordNetLemmatizer()  # applied per token in clean_text()
analyzer = SentimentIntensityAnalyzer()  # VADER scorer read by engineer_features()
output_dir = "mental_health_model_artifacts"  # directory startup_event() loads artifacts and ChromaDB from

# Global variables for models and vector store
# Every name below starts as None and is assigned by startup_event();
# code that runs before startup completes will see None.
response_clf = None
crisis_clf = None
vectorizer = None
le = None
selector = None
lda = None
vector_store = None
llm = None
openai_client = None
langsmith_client = None
73
+
74
+ # Load models and initialize ChromaDB at startup
75
@app.on_event("startup")
async def startup_event():
    """Populate the module-level model, vector-store and client globals.

    Runs, in order: API-key checks, LangSmith tracing setup, loading of the
    joblib artifacts under ``output_dir``, opening the persisted ChromaDB
    collection, and construction of the OpenAI client and chat model.

    Raises:
        HTTPException: status 500 when an API key is unset, a required
            artifact file is missing, the ChromaDB directory is absent, or
            ChromaDB initialisation fails.
    """
    # One declaration covers every global assigned below.
    global response_clf, crisis_clf, vectorizer, le, selector, lda, vector_store, llm, openai_client, langsmith_client

    # Check environment variables (OpenAI first, then LangChain).
    openai_key = os.environ.get("OPENAI_API_KEY")
    if not openai_key:
        logger.error("OPENAI_API_KEY not set in .env file")
        raise HTTPException(status_code=500, detail="OPENAI_API_KEY not set in .env file")
    if not os.environ.get("LANGCHAIN_API_KEY"):
        logger.error("LANGCHAIN_API_KEY not set in .env file")
        raise HTTPException(status_code=500, detail="LANGCHAIN_API_KEY not set in .env file")
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "MentalHealthCounselorPOC"

    # LangSmith client
    logger.info("Initializing LangSmith client")
    langsmith_client = Client()

    # Saved model components
    logger.info("Loading model artifacts")
    try:
        response_clf = joblib.load(f"{output_dir}/response_type_classifier.pkl")
        crisis_clf = joblib.load(f"{output_dir}/crisis_classifier.pkl")
        vectorizer = joblib.load(f"{output_dir}/tfidf_vectorizer.pkl")
        le = joblib.load(f"{output_dir}/label_encoder.pkl")
        selector = joblib.load(f"{output_dir}/feature_selector.pkl")

        # The LDA model is optional: any load failure (including a missing
        # file) falls back to a fresh estimator instead of aborting startup.
        try:
            lda = joblib.load(f"{output_dir}/lda_model.pkl")
        except Exception as lda_error:
            logger.warning(f"Failed to load LDA model: {lda_error}. Creating placeholder model.")
            from sklearn.decomposition import LatentDirichletAllocation
            # NOTE(review): the placeholder is untrained, so lda.transform()
            # in engineer_features() will presumably fail until it is fitted
            # - confirm and retrain for accurate results.
            lda = LatentDirichletAllocation(n_components=10, random_state=42)

    except FileNotFoundError as e:
        logger.error(f"Missing model artifact: {e}")
        raise HTTPException(status_code=500, detail=f"Missing model artifact: {e}")

    # ChromaDB must already exist on disk; it is never created here.
    chroma_db_path = f"{output_dir}/chroma_db"
    if not os.path.exists(chroma_db_path):
        logger.error(f"ChromaDB not found at {chroma_db_path}. Run create_vector_db.py first.")
        raise HTTPException(status_code=500, detail=f"ChromaDB not found at {chroma_db_path}. Run create_vector_db.py first.")

    try:
        logger.info("Initializing ChromaDB")
        chroma_settings = Settings(anonymized_telemetry=False)
        chroma_client = chromadb.PersistentClient(path=chroma_db_path, settings=chroma_settings)

        embeddings = OpenAIEmbeddings(
            model="text-embedding-ada-002",
            api_key=openai_key,
            disallowed_special=(),
            chunk_size=1000,
        )
        vector_store = Chroma(
            client=chroma_client,
            collection_name="mental_health_conversations",
            embedding_function=embeddings,
        )
    except Exception as e:
        logger.error(f"Error initializing ChromaDB: {e}")
        raise HTTPException(status_code=500, detail=f"Error initializing ChromaDB: {e}")

    # OpenAI client and chat model
    logger.info("Initializing OpenAI client and LLM")
    openai_client = OpenAI(api_key=openai_key)
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.7, api_key=openai_key)
152
+
153
+ # Pydantic model for request
154
# Request body of POST /suggest.
class PatientContext(BaseModel):
    context: str  # patient statement passed to predict_response_type()
156
+
157
+ # New Pydantic models for expanded API functionality
158
# Counselor profile: request and response schema of the /users endpoints.
class UserProfile(BaseModel):
    user_id: Optional[str] = None  # update_user() overwrites this with the path parameter
    username: str
    name: str
    role: str = "counselor"
    specializations: List[str] = []
    years_experience: Optional[int] = None
    custom_crisis_keywords: List[str] = []  # copied into new sessions by create_session()
    preferences: Dict[str, Any] = {}
    created_at: Optional[datetime] = None  # preserved from the stored profile on update
    updated_at: Optional[datetime] = None
169
+
170
# Counseling session: request and response schema of the /sessions endpoints.
class SessionData(BaseModel):
    session_id: Optional[str] = None  # update_session() overwrites this with the path parameter
    counselor_id: str  # must resolve via get_user_profile() when a session is created
    patient_identifier: str  # Anonymized ID
    session_notes: str = ""
    session_preferences: Dict[str, Any] = {}
    crisis_keywords: List[str] = []  # extra substrings that raise the crisis flag in /session/suggest
    created_at: Optional[datetime] = None  # preserved from the stored session on update
    updated_at: Optional[datetime] = None
179
+
180
# One message in a session's conversation history (see the /conversations endpoints).
class ConversationEntry(BaseModel):
    session_id: str
    message: str
    sender: str  # 'patient' or 'counselor'
    timestamp: Optional[datetime] = None
    suggested_response: Optional[str] = None  # /session/suggest stores the RAG suggestion here
    response_type: Optional[str] = None
    crisis_flag: bool = False
    risk_level: Optional[str] = None  # /session/suggest stores the RAG risk level here
189
+
190
# Request body of POST /feedback.
class FeedbackData(BaseModel):
    suggestion_id: str
    counselor_id: str
    rating: int  # 1-5 scale (NOTE(review): the range is not enforced by this model)
    was_effective: bool
    comments: Optional[str] = None
196
+
197
# Text-analysis request.
# NOTE(review): presumably consumed by the tone / cultural / age analysis
# endpoints - its users are not visible in this part of the file; confirm.
class AnalysisRequest(BaseModel):
    text: str
    patient_background: Optional[Dict[str, Any]] = None
    patient_age: Optional[int] = None
    cultural_context: Optional[str] = None
202
+
203
# Input that may be text, audio or video, tied to a session and counselor.
# NOTE(review): the consuming endpoint is not visible in this part of the file.
class MultiModalInput(BaseModel):
    session_id: str
    counselor_id: str
    input_type: str  # 'text', 'audio', 'video'
    content: str  # Text content or file path/url
    metadata: Dict[str, Any] = {}
209
+
210
# Request for an intervention recommendation.
# NOTE(review): the consuming endpoint is not visible in this part of the file.
class InterventionRequest(BaseModel):
    patient_issue: str
    patient_background: Optional[Dict[str, Any]] = None
    intervention_type: Optional[str] = None  # e.g., 'CBT', 'DBT', 'mindfulness'
214
+
215
+ # Text preprocessing function
216
@traceable(run_type="tool", name="Clean Text")
def clean_text(text):
    """Normalise free text into a space-joined string of lemmatised tokens.

    Missing values (per ``pd.isna``) yield ``""``. Otherwise the text is
    lower-cased, every character other than a letter or apostrophe becomes a
    space, and tokens that are stopwords or at most two characters long are
    discarded before lemmatisation.
    """
    if pd.isna(text):
        return ""

    lowered = str(text).lower()
    letters_only = re.sub(r"[^a-zA-Z']", " ", lowered)

    kept = []
    for token in word_tokenize(letters_only):
        # Skip stopwords and very short tokens.
        if token in STOPWORDS or len(token) <= 2:
            continue
        kept.append(lemmatizer.lemmatize(token))
    return " ".join(kept)
225
+
226
+ # Feature engineering function
227
@traceable(run_type="tool", name="Engineer Features")
def engineer_features(context, response=""):
    """Build the single-row feature frame consumed by the classifiers.

    Args:
        context: Raw patient statement.
        response: Accepted for signature compatibility; not read here. The
            response TF-IDF columns are always filled with zeros.

    Returns:
        A ``(features, feature_cols)`` tuple: ``features`` is a one-row
        DataFrame (with a final ``crisis_flag`` column) and ``feature_cols``
        is the column order handed to the feature selector.
    """
    cleaned = clean_text(context)
    token_count = len(cleaned.split())
    sentiment = analyzer.polarity_scores(context)["compound"]
    question_marks = context.count("?")

    # Plain substring match against the lower-cased raw text.
    lowered = context.lower()
    crisis_keywords = ['suicide', 'hopeless', 'worthless', 'kill', 'harm', 'desperate', 'overwhelmed', 'alone']
    keyword_hit = any(word in lowered for word in crisis_keywords)

    context_tfidf = vectorizer.transform([cleaned]).toarray()
    n_terms = context_tfidf.shape[1]
    response_tfidf = np.zeros_like(context_tfidf)  # no response text at inference time
    lda_topics = lda.transform(context_tfidf)

    base_cols = ["context_len", "context_vader", "context_questions"]
    topic_cols = [f"topic_{i}" for i in range(10)]
    context_cols = [f"tfidf_context_{i}" for i in range(n_terms)]
    response_cols = [f"tfidf_response_{i}" for i in range(n_terms)]
    feature_cols = base_cols + ["crisis_flag"] + topic_cols + context_cols + response_cols

    # Insertion order fixes the frame's column order: base, topics, context
    # TF-IDF, response TF-IDF.
    row = {
        "context_len": [token_count],
        "context_vader": [sentiment],
        "context_questions": [question_marks],
    }
    for i, col in enumerate(topic_cols):
        row[col] = [lda_topics[0][i]]
    for i, col in enumerate(context_cols):
        row[col] = [context_tfidf[0][i]]
    for i, col in enumerate(response_cols):
        row[col] = [response_tfidf[0][i]]
    features = pd.DataFrame(row)

    # The crisis classifier sees only the base and topic columns; a keyword
    # hit overrides its prediction.
    crisis_flag = crisis_clf.predict(features[base_cols + topic_cols])[0]
    if keyword_hit:
        crisis_flag = 1
    features["crisis_flag"] = crisis_flag

    return features, feature_cols
262
+
263
+ # Prediction function
264
@traceable(run_type="chain", name="Predict Response Type")
def predict_response_type(context):
    """Predict the counselor response type and crisis flag for a statement.

    The classifier's label is then adjusted by three keyword rules applied
    in sequence (later rules win). ``confidence`` is always the classifier's
    top class probability, even when a rule replaces the label.

    Returns:
        dict with keys ``response_type``, ``crisis_flag``, ``confidence``
        and ``features`` (the feature frame as a dict).
    """
    features, feature_cols = engineer_features(context)
    selected = selector.transform(features[feature_cols])

    encoded_label = response_clf.predict(selected)[0]
    pred_label = le.inverse_transform([encoded_label])[0]
    confidence = response_clf.predict_proba(selected)[0].max()

    lowered = context.lower()
    asks_for_help = any(kw in lowered for kw in ["how", "what", "help"])

    # Rule 1: any question mark makes it a question.
    if "?" in context:
        pred_label = "Question"
    # Rules 2 and 3 apply only to "trying ..." statements without a request cue.
    if "trying" in lowered and not asks_for_help:
        if "hard" in lowered:
            pred_label = "Validation"
        if "positive" in lowered:
            pred_label = "Question"

    return {
        "response_type": pred_label,
        "crisis_flag": bool(features["crisis_flag"].iloc[0]),
        "confidence": confidence,
        "features": features.to_dict()
    }
287
 
288
+ # RAG suggestion function
289
@traceable(run_type="chain", name="RAG Suggestion")
def generate_suggestion_rag(context, response_type, crisis_flag):
    """Generate a counselor suggestion using similar past conversations.

    The three nearest stored conversations are retrieved from the vector
    store and added to the prompt as reference cases.

    Args:
        context: Patient statement.
        response_type: Predicted response type to follow.
        crisis_flag: Truthy when the statement was flagged as a crisis.

    Returns:
        dict parsed from the model's reply; the prompt asks for the keys
        ``suggested_response`` and ``risk_level``.

    Raises:
        HTTPException: status 500 when the chain call fails or the reply
            cannot be parsed into a dict.
    """
    import ast  # local import: only needed for the safe-literal fallback below

    results = vector_store.similarity_search_with_score(context, k=3)
    retrieved_contexts = [
        f"Patient: {res[0].page_content}\nCounselor: {res[0].metadata['response']} (Type: {res[0].metadata['response_type']}, Crisis: {res[0].metadata['crisis_flag']}, Score: {res[1]:.2f})"
        for res in results
    ]

    prompt_template = ChatPromptTemplate.from_template(
        """
    You are an expert mental health counseling assistant. A counselor has provided the following patient situation:

    Patient Situation: {context}

    Predicted Response Type: {response_type}
    Crisis Flag: {crisis_flag}

    Based on the predicted response type and crisis flag, provide a suggested response for the counselor to use with the patient. The response should align with the response type ({response_type}) and be sensitive to the crisis level.

    For reference, here are similar cases from past conversations:
    {retrieved_contexts}

    Guidelines:
    - If Crisis Flag is True, prioritize safety, empathy, and suggest immediate resources (e.g., National Suicide Prevention Lifeline at 988).
    - For 'Empathetic Listening', focus on validating feelings without giving direct advice or questions.
    - For 'Advice', provide practical, actionable suggestions.
    - For 'Question', pose an open-ended question to encourage further discussion.
    - For 'Validation', affirm the patient's efforts or feelings.

    Output in the following format:
    ```json
    {{
        "suggested_response": "Your suggested response here",
        "risk_level": "Low/Moderate/High"
    }}
    ```
    """
    )

    rag_chain = (
        {
            "context": RunnablePassthrough(),
            "response_type": lambda x: response_type,
            "crisis_flag": lambda x: "Crisis" if crisis_flag else "No Crisis",
            "retrieved_contexts": lambda x: "\n".join(retrieved_contexts)
        }
        | prompt_template
        | llm
    )

    try:
        response = rag_chain.invoke(context)

        # The reply is model output shaped by patient text, so it is parsed
        # as data and never executed. Take the fenced payload when present.
        raw = response.content.strip()
        fenced = re.search(r"```(?:json)?\s*(.*)```", raw, re.DOTALL)
        payload = (fenced.group(1) if fenced else raw).strip()
        try:
            suggestion = json.loads(payload)
        except json.JSONDecodeError:
            # Tolerate a Python-literal dict (single quotes) without eval().
            suggestion = ast.literal_eval(payload)
        if not isinstance(suggestion, dict):
            raise ValueError("LLM reply is not a JSON object")
        return suggestion
    except Exception as e:
        logger.error(f"Error generating RAG suggestion: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating RAG suggestion: {str(e)}") from e
345
+
346
+ # Direct suggestion function
347
@traceable(run_type="chain", name="Direct Suggestion")
def generate_suggestion_direct(context, response_type, crisis_flag):
    """Generate a counselor suggestion from the prompt alone (no retrieval).

    Args:
        context: Patient statement.
        response_type: Predicted response type to follow.
        crisis_flag: Truthy when the statement was flagged as a crisis.

    Returns:
        dict parsed from the model's reply; the prompt asks for the keys
        ``suggested_response`` and ``risk_level``.

    Raises:
        HTTPException: status 500 when the chain call fails or the reply
            cannot be parsed into a dict.
    """
    import ast  # local import: only needed for the safe-literal fallback below

    prompt_template = ChatPromptTemplate.from_template(
        """
    You are an expert mental health counseling assistant. A counselor has provided the following patient situation:

    Patient Situation: {context}

    Predicted Response Type: {response_type}
    Crisis Flag: {crisis_flag}

    Provide a suggested response for the counselor to use with the patient, aligned with the response type ({response_type}) and sensitive to the crisis level.

    Guidelines:
    - If Crisis Flag is True, prioritize safety, empathy, and suggest immediate resources (e.g., National Suicide Prevention Lifeline at 988).
    - For 'Empathetic Listening', focus on validating feelings without giving direct advice or questions.
    - For 'Advice', provide practical, actionable suggestions.
    - For 'Question', pose an open-ended question to encourage further discussion.
    - For 'Validation', affirm the patient's efforts or feelings.
    - Strictly adhere to the response type. For 'Empathetic Listening', do not include questions or advice.

    Output in the following format:
    ```json
    {{
        "suggested_response": "Your suggested response here",
        "risk_level": "Low/Moderate/High"
    }}
    ```
    """
    )

    direct_chain = (
        {
            "context": RunnablePassthrough(),
            "response_type": lambda x: response_type,
            "crisis_flag": lambda x: "Crisis" if crisis_flag else "No Crisis"
        }
        | prompt_template
        | llm
    )

    try:
        response = direct_chain.invoke(context)

        # The reply is model output shaped by patient text, so it is parsed
        # as data and never executed. Take the fenced payload when present.
        raw = response.content.strip()
        fenced = re.search(r"```(?:json)?\s*(.*)```", raw, re.DOTALL)
        payload = (fenced.group(1) if fenced else raw).strip()
        try:
            suggestion = json.loads(payload)
        except json.JSONDecodeError:
            # Tolerate a Python-literal dict (single quotes) without eval().
            suggestion = ast.literal_eval(payload)
        if not isinstance(suggestion, dict):
            raise ValueError("LLM reply is not a JSON object")
        return suggestion
    except Exception as e:
        logger.error(f"Error generating direct suggestion: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating direct suggestion: {str(e)}") from e
394
+
395
+ # User Profile Endpoints
396
@app.post("/users/create", response_model=UserProfile)
async def create_user(profile: UserProfile):
    """Create a new counselor profile with preferences and specializations."""
    # Any failure while saving (or logging the new id) is reported as a 500.
    try:
        stored = save_user_profile(profile)
        logger.info(f"Created user profile: {stored.user_id}")
    except Exception as exc:
        failure = f"Error creating user profile: {exc}"
        logger.error(failure)
        raise HTTPException(status_code=500, detail=failure)
    else:
        return stored
406
+
407
@app.get("/users/{user_id}", response_model=UserProfile)
async def get_user(user_id: str):
    """Get a counselor profile by user ID."""
    found = get_user_profile(user_id)
    if found:
        return found
    # A falsy lookup result is reported as "not found".
    raise HTTPException(status_code=404, detail=f"User profile not found: {user_id}")
414
+
415
@app.put("/users/{user_id}", response_model=UserProfile)
async def update_user(user_id: str, profile_update: UserProfile):
    """Update a counselor profile."""
    current = get_user_profile(user_id)
    if not current:
        raise HTTPException(status_code=404, detail=f"User profile not found: {user_id}")

    # The path parameter and the stored creation time override whatever the
    # request body carried for these two fields.
    profile_update.user_id = user_id
    profile_update.created_at = current.created_at

    try:
        saved = save_user_profile(profile_update)
        logger.info(f"Updated user profile: {user_id}")
    except Exception as exc:
        failure = f"Error updating user profile: {exc}"
        logger.error(failure)
        raise HTTPException(status_code=500, detail=failure)
    else:
        return saved
434
+
435
+ # Session Management Endpoints
436
@app.post("/sessions/create", response_model=SessionData)
async def create_session(session_data: SessionData):
    """Create a new session with patient identifier (anonymized)."""
    try:
        # The referenced counselor must exist.
        owner = get_user_profile(session_data.counselor_id)
        if not owner:
            raise HTTPException(status_code=404, detail=f"Counselor not found: {session_data.counselor_id}")

        # Counselor-level crisis keywords are appended to the session's own list.
        inherited_keywords = owner.custom_crisis_keywords
        if inherited_keywords:
            session_data.crisis_keywords.extend(inherited_keywords)

        created = save_session(session_data)
        logger.info(f"Created session: {created.session_id}")
        return created
    except HTTPException:
        # Pass the 404 above through untouched.
        raise
    except Exception as exc:
        failure = f"Error creating session: {exc}"
        logger.error(failure)
        raise HTTPException(status_code=500, detail=failure)
457
+
458
@app.get("/sessions/{session_id}", response_model=SessionData)
async def get_session_by_id(session_id: str):
    """Get a session by ID."""
    found = get_session(session_id)
    if found:
        return found
    # A falsy lookup result is reported as "not found".
    raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
465
+
466
@app.get("/sessions/counselor/{counselor_id}", response_model=List[SessionData])
async def get_counselor_sessions(counselor_id: str):
    """Get all sessions for a counselor."""
    # No existence check on the counselor: the lookup result is returned as-is.
    return get_user_sessions(counselor_id)
471
+
472
@app.put("/sessions/{session_id}", response_model=SessionData)
async def update_session(session_id: str, session_update: SessionData):
    """Update a session."""
    current = get_session(session_id)
    if not current:
        raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

    # The path parameter and the stored creation time override whatever the
    # request body carried for these two fields.
    session_update.session_id = session_id
    session_update.created_at = current.created_at

    try:
        saved = save_session(session_update)
        logger.info(f"Updated session: {session_id}")
    except Exception as exc:
        failure = f"Error updating session: {exc}"
        logger.error(failure)
        raise HTTPException(status_code=500, detail=failure)
    else:
        return saved
490
+
491
+ # Conversation History Endpoints
492
@app.post("/conversations/add", response_model=str)
async def add_conversation_entry(entry: ConversationEntry):
    """Add a new entry to a conversation."""
    try:
        # Entries may only be attached to an existing session.
        if not get_session(entry.session_id):
            raise HTTPException(status_code=404, detail=f"Session not found: {entry.session_id}")

        new_entry_id = save_conversation_entry(entry)
        logger.info(f"Added conversation entry: {new_entry_id}")
        return new_entry_id
    except HTTPException:
        # Pass the 404 above through untouched.
        raise
    except Exception as exc:
        failure = f"Error adding conversation entry: {exc}"
        logger.error(failure)
        raise HTTPException(status_code=500, detail=failure)
509
+
510
@app.get("/conversations/{session_id}", response_model=List[ConversationEntry])
async def get_conversation(session_id: str):
    """Get conversation history for a session."""
    try:
        # History is only served for an existing session.
        if not get_session(session_id):
            raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

        return get_conversation_history(session_id)
    except HTTPException:
        # Pass the 404 above through untouched.
        raise
    except Exception as exc:
        failure = f"Error retrieving conversation history: {exc}"
        logger.error(failure)
        raise HTTPException(status_code=500, detail=failure)
526
+
527
+ # API Endpoints
528
@app.post("/suggest")
async def get_suggestion(context: PatientContext):
    # Stateless suggestion endpoint: classify the statement, then produce
    # both a retrieval-augmented and a prompt-only suggestion.
    #
    # Only the length of the statement is logged: the text itself is a
    # patient's mental-health disclosure and should not land in application logs.
    logger.info(f"Received suggestion request (context length: {len(context.context)} characters)")

    statement = context.context
    prediction = predict_response_type(statement)
    response_type = prediction["response_type"]
    crisis_flag = prediction["crisis_flag"]

    suggestion_rag = generate_suggestion_rag(statement, response_type, crisis_flag)
    suggestion_direct = generate_suggestion_direct(statement, response_type, crisis_flag)

    return {
        "context": statement,
        "response_type": response_type,
        "crisis_flag": crisis_flag,
        "confidence": prediction["confidence"],
        "rag_suggestion": suggestion_rag["suggested_response"],
        "rag_risk_level": suggestion_rag["risk_level"],
        "direct_suggestion": suggestion_direct["suggested_response"],
        "direct_risk_level": suggestion_direct["risk_level"]
    }
545
+
546
@app.post("/session/suggest")
async def get_session_suggestion(request: dict):
    """Get suggestion within a session context, with enhanced crisis detection based on session keywords."""
    # Expected body keys: "session_id" and "context" (the patient statement).
    # Responses: 400 for a missing or non-string field, 404 for an unknown
    # session, 500 for any other failure.
    try:
        session_id = request.get("session_id")
        if not session_id:
            raise HTTPException(status_code=400, detail="session_id is required")

        context = request.get("context")
        if not context:
            raise HTTPException(status_code=400, detail="context is required")
        # The body is an untyped dict, so reject non-text here rather than
        # failing later inside the text pipeline with a 500.
        if not isinstance(context, str):
            raise HTTPException(status_code=400, detail="context must be a string")

        # Get session for custom crisis keywords
        session = get_session(session_id)
        if not session:
            raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")

        # Regular prediction
        prediction = predict_response_type(context)
        crisis_flag = prediction["crisis_flag"]

        # Enhanced crisis detection: a case-insensitive substring match on
        # any session keyword raises the flag when the model did not.
        if not crisis_flag and session.crisis_keywords:
            lowered = context.lower()
            matched = next((kw for kw in session.crisis_keywords if kw.lower() in lowered), None)
            if matched is not None:
                crisis_flag = True
                logger.info(f"Crisis flag triggered by custom keyword: {matched}")

        # Generate suggestions
        suggestion_rag = generate_suggestion_rag(context, prediction["response_type"], crisis_flag)
        suggestion_direct = generate_suggestion_direct(context, prediction["response_type"], crisis_flag)

        # Create response
        response = {
            "context": context,
            "response_type": prediction["response_type"],
            "crisis_flag": crisis_flag,
            "confidence": prediction["confidence"],
            "rag_suggestion": suggestion_rag["suggested_response"],
            "rag_risk_level": suggestion_rag["risk_level"],
            "direct_suggestion": suggestion_direct["suggested_response"],
            "direct_risk_level": suggestion_direct["risk_level"],
            "session_id": session_id
        }

        # Save the conversation entry; the RAG suggestion is the one persisted.
        entry = ConversationEntry(
            session_id=session_id,
            message=context,
            sender="patient",
            suggested_response=suggestion_rag["suggested_response"],
            response_type=prediction["response_type"],
            crisis_flag=crisis_flag,
            risk_level=suggestion_rag["risk_level"]
        )
        save_conversation_entry(entry)

        return response
    except HTTPException:
        # Pass the 400/404 responses above through untouched.
        raise
    except Exception as e:
        logger.error(f"Error getting session suggestion: {e}")
        raise HTTPException(status_code=500, detail=f"Error getting session suggestion: {str(e)}")
613
+
614
+ # Feedback Endpoints
615
@app.post("/feedback")
async def add_feedback(feedback: FeedbackData):
    """Add feedback about a suggestion's effectiveness."""
    # Any failure while saving is reported as a 500.
    try:
        new_feedback_id = save_feedback(feedback)
        logger.info(f"Added feedback: {new_feedback_id}")
    except Exception as exc:
        failure = f"Error adding feedback: {exc}"
        logger.error(failure)
        raise HTTPException(status_code=500, detail=failure)
    else:
        return {"feedback_id": new_feedback_id}
625
+
626
+ # Tone & Cultural Sensitivity Analysis
627
@traceable(run_type="chain", name="Cultural Sensitivity Analysis")
def analyze_cultural_sensitivity(text: str, cultural_context: Optional[str] = None):
    """Ask the LLM to assess *text* for cultural appropriateness.

    Args:
        text: The text to analyze.
        cultural_context: Optional cultural setting. When omitted or empty,
            "General" is sent to the model instead.

    Returns:
        The JSON object from the model's reply, decoded into a dict. The
        prompt requests the keys ``cultural_appropriateness_score``,
        ``issues_detected``, ``suggestions`` and ``explanation``; the model's
        adherence to that shape is not validated here.

    Raises:
        HTTPException: 500 when the chain call fails or the reply does not
            contain valid JSON.
    """
    prompt_template = ChatPromptTemplate.from_template(
        """
    You are a cultural sensitivity expert. Analyze the following text for cultural appropriateness:

    Text: {text}

    Cultural Context: {cultural_context}

    Provide an analysis of:
    1. Cultural appropriateness
    2. Potential bias or insensitivity
    3. Suggestions for improvement

    Output in the following format:
    ```json
    {{
        "cultural_appropriateness_score": 0-10,
        "issues_detected": ["issue1", "issue2"],
        "suggestions": ["suggestion1", "suggestion2"],
        "explanation": "Brief explanation of analysis"
    }}
    ```
    """
    )

    analysis_chain = (
        {
            "text": RunnablePassthrough(),
            "cultural_context": lambda x: cultural_context if cultural_context else "General"
        }
        | prompt_template
        | llm
    )

    try:
        response = analysis_chain.invoke(text)
        # The reply is model output, so it is parsed as data with json.loads
        # instead of being executed with eval(). Slicing from the first "{"
        # to the last "}" drops the Markdown code fence and any surrounding
        # prose; an empty slice makes json.loads raise, which is reported
        # as a 500 below.
        reply = response.content
        return json.loads(reply[reply.find("{"):reply.rfind("}") + 1])
    except Exception as e:
        logger.error(f"Error analyzing cultural sensitivity: {e}")
        raise HTTPException(status_code=500, detail=f"Error analyzing cultural sensitivity: {str(e)}")
670
+
671
@traceable(run_type="chain", name="Age Appropriate Analysis")
def analyze_age_appropriateness(text: str, age: Optional[int] = None):
    """Ask the LLM to assess whether *text* suits the target age.

    Args:
        text: The text to analyze.
        age: Optional target age. When omitted (or falsy, e.g. 0), "Adult"
            is sent to the model instead.

    Returns:
        The JSON object from the model's reply, decoded into a dict. The
        prompt requests the keys ``age_appropriateness_score``,
        ``complexity_level``, ``issues_detected``, ``suggestions`` and
        ``explanation``; the shape is not validated here.

    Raises:
        HTTPException: 500 when the chain call fails or the reply does not
            contain valid JSON.
    """
    prompt_template = ChatPromptTemplate.from_template(
        """
    You are an expert in age-appropriate communication. Analyze the following text for age appropriateness:

    Text: {text}

    Target Age: {age}

    Provide an analysis of:
    1. Age appropriateness
    2. Complexity level
    3. Suggestions for improvement

    Output in the following format:
    ```json
    {{
        "age_appropriateness_score": 0-10,
        "complexity_level": "Simple/Moderate/Complex",
        "issues_detected": ["issue1", "issue2"],
        "suggestions": ["suggestion1", "suggestion2"],
        "explanation": "Brief explanation of analysis"
    }}
    ```
    """
    )

    analysis_chain = (
        {
            "text": RunnablePassthrough(),
            "age": lambda x: str(age) if age else "Adult"
        }
        | prompt_template
        | llm
    )

    try:
        response = analysis_chain.invoke(text)
        # Parse the model output as JSON data rather than executing it with
        # eval(). The slice keeps only the outermost {...} span, discarding
        # the Markdown fence; malformed replies raise and become a 500.
        reply = response.content
        return json.loads(reply[reply.find("{"):reply.rfind("}") + 1])
    except Exception as e:
        logger.error(f"Error analyzing age appropriateness: {e}")
        raise HTTPException(status_code=500, detail=f"Error analyzing age appropriateness: {str(e)}")
715
+
716
@app.post("/analyze/sensitivity")
async def analyze_text_sensitivity(request: AnalysisRequest):
    """Analyze text for cultural sensitivity and age appropriateness.

    Runs both analyses on ``request.text`` and returns them side by side
    under ``cultural_analysis`` and ``age_analysis``.

    Raises:
        HTTPException: propagated unchanged from either analysis helper, or
            500 for any other failure.
    """
    try:
        cultural_analysis = analyze_cultural_sensitivity(request.text, request.cultural_context)
        age_analysis = analyze_age_appropriateness(request.text, request.patient_age)

        return {
            "text": request.text,
            "cultural_analysis": cultural_analysis,
            "age_analysis": age_analysis
        }
    except HTTPException:
        # The helpers already log and raise a well-formed HTTPException;
        # re-wrapping it would only nest the error message.
        raise
    except Exception as e:
        logger.error(f"Error analyzing text sensitivity: {e}")
        raise HTTPException(status_code=500, detail=f"Error analyzing text sensitivity: {str(e)}")
731
+
732
+ # Guided Intervention Workflows
733
@traceable(run_type="chain", name="Generate Intervention")
def generate_intervention_workflow(issue: str, intervention_type: Optional[str] = None, background: Optional[Dict] = None):
    """Ask the LLM for a structured intervention workflow for a patient issue.

    Args:
        issue: Description of the patient issue.
        intervention_type: Optional requested approach. "Best fit" is sent to
            the model when omitted or empty.
        background: Optional patient background; its ``str()`` form is sent
            to the model, or "Not provided" when omitted or empty.

    Returns:
        The JSON object from the model's reply, decoded into a dict. The
        prompt requests the keys ``intervention_type``,
        ``assessment_questions``, ``techniques``, ``exercises``,
        ``follow_up`` and ``resources``; the shape is not validated here.

    Raises:
        HTTPException: 500 when the chain call fails or the reply does not
            contain valid JSON.
    """
    prompt_template = ChatPromptTemplate.from_template(
        """
    You are an expert mental health counselor. Generate a structured intervention workflow for the following patient issue:

    Patient Issue: {issue}

    Intervention Type: {intervention_type}
    Patient Background: {background}

    Provide a step-by-step intervention plan based on evidence-based practices. Include:
    1. Initial assessment questions
    2. Specific techniques to apply
    3. Homework or practice exercises
    4. Follow-up guidance

    Output in the following format:
    ```json
    {{
        "intervention_type": "CBT/DBT/ACT/Mindfulness/etc.",
        "assessment_questions": ["question1", "question2", "question3"],
        "techniques": [
            {{
                "name": "technique name",
                "description": "brief description",
                "instructions": "step-by-step instructions"
            }}
        ],
        "exercises": [
            {{
                "name": "exercise name",
                "description": "brief description",
                "instructions": "step-by-step instructions"
            }}
        ],
        "follow_up": ["follow-up step 1", "follow-up step 2"],
        "resources": ["resource1", "resource2"]
    }}
    ```
    """
    )

    intervention_chain = (
        {
            "issue": RunnablePassthrough(),
            "intervention_type": lambda x: intervention_type if intervention_type else "Best fit",
            "background": lambda x: str(background) if background else "Not provided"
        }
        | prompt_template
        | llm
    )

    try:
        response = intervention_chain.invoke(issue)
        # Parse the model output as JSON data rather than executing it with
        # eval(). The slice keeps only the outermost {...} span, discarding
        # the Markdown fence; malformed replies raise and become a 500.
        reply = response.content
        return json.loads(reply[reply.find("{"):reply.rfind("}") + 1])
    except Exception as e:
        logger.error(f"Error generating intervention workflow: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating intervention workflow: {str(e)}")
793
+
794
@app.post("/interventions/generate")
async def get_intervention_workflow(request: InterventionRequest):
    """Get a structured intervention workflow for a specific patient issue.

    Returns a dict echoing ``patient_issue`` together with the generated
    ``intervention``.

    Raises:
        HTTPException: propagated unchanged from the generator, or 500 for
            any other failure.
    """
    try:
        intervention = generate_intervention_workflow(
            request.patient_issue,
            request.intervention_type,
            request.patient_background
        )

        return {
            "patient_issue": request.patient_issue,
            "intervention": intervention
        }
    except HTTPException:
        # Already logged and formatted by generate_intervention_workflow.
        raise
    except Exception as e:
        logger.error(f"Error generating intervention workflow: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating intervention workflow: {str(e)}")
811
+
812
@app.get("/health")
async def health_check():
    """Report healthy only when every model component, the vector store and the LLM are truthy."""
    required_components = [response_clf, crisis_clf, vectorizer, le, selector, lda, vector_store, llm]
    if not all(required_components):
        logger.error("Health check failed: One or more components not loaded")
        raise HTTPException(status_code=500, detail="One or more components failed to load")
    return {"status": "healthy", "message": "All models and vector store loaded successfully"}
818
 
 
819
@app.get("/metadata")
async def get_metadata():
    """Return the name and document count of the conversations collection.

    Raises:
        HTTPException: 500 when the collection cannot be read.
    """
    try:
        conversations = vector_store._client.get_collection("mental_health_conversations")
        document_total = conversations.count()
    except Exception as e:
        logger.error(f"Error retrieving metadata: {e}")
        raise HTTPException(status_code=500, detail=f"Error retrieving metadata: {str(e)}")
    return {"collection_name": "mental_health_conversations", "document_count": document_total}
 
 
 
 
 
828
 
829
+ # Database utility functions
830
def save_user_profile(profile: UserProfile):
    """Write a user profile to ``<DATA_DIR>/users/<user_id>.json``.

    Assigns a fresh UUID when ``user_id`` is empty, sets ``created_at`` on
    first save, and always refreshes ``updated_at``. The passed-in object is
    mutated in place.

    Args:
        profile: The profile to persist.

    Returns:
        The same ``profile`` object, with id and timestamps filled in.
    """
    if not profile.user_id:
        profile.user_id = str(uuid4())

    if not profile.created_at:
        profile.created_at = datetime.now()

    profile.updated_at = datetime.now()

    # Serialize BEFORE opening the file: open(..., "w") truncates
    # immediately, so a serialization error after it would leave an empty
    # file that later breaks get_user_profile.
    profile_dict = profile.dict()
    for key in ["created_at", "updated_at"]:
        # datetime is not JSON serializable; store it as an ISO-8601 string.
        if profile_dict[key]:
            profile_dict[key] = profile_dict[key].isoformat()
    serialized = json.dumps(profile_dict, indent=2)

    file_path = os.path.join(DATA_DIR, "users", f"{profile.user_id}.json")
    with open(file_path, "w") as f:
        f.write(serialized)

    return profile
849
+
850
def get_user_profile(user_id: str) -> Optional[UserProfile]:
    """Load the stored profile for *user_id*, or return None when no file exists."""
    profile_path = os.path.join(DATA_DIR, "users", f"{user_id}.json")
    if not os.path.exists(profile_path):
        return None

    with open(profile_path, "r") as handle:
        record = json.load(handle)

    # Timestamps are stored as ISO strings; turn them back into datetimes.
    for field_name in ("created_at", "updated_at"):
        stored = record[field_name]
        if stored:
            record[field_name] = datetime.fromisoformat(stored)
    return UserProfile(**record)
862
+
863
def save_session(session: SessionData):
    """Write a session to ``<DATA_DIR>/sessions/<session_id>.json``.

    Assigns a fresh UUID when ``session_id`` is empty, sets ``created_at`` on
    first save, and always refreshes ``updated_at``. The passed-in object is
    mutated in place.

    Args:
        session: The session to persist.

    Returns:
        The same ``session`` object, with id and timestamps filled in.
    """
    if not session.session_id:
        session.session_id = str(uuid4())

    if not session.created_at:
        session.created_at = datetime.now()

    session.updated_at = datetime.now()

    # Serialize BEFORE opening the file: open(..., "w") truncates
    # immediately, so a serialization error after it would leave an empty
    # file that later breaks get_session / get_user_sessions.
    session_dict = session.dict()
    for key in ["created_at", "updated_at"]:
        # datetime is not JSON serializable; store it as an ISO-8601 string.
        if session_dict[key]:
            session_dict[key] = session_dict[key].isoformat()
    serialized = json.dumps(session_dict, indent=2)

    file_path = os.path.join(DATA_DIR, "sessions", f"{session.session_id}.json")
    with open(file_path, "w") as f:
        f.write(serialized)

    return session
882
+
883
def get_session(session_id: str) -> Optional[SessionData]:
    """Load the stored session for *session_id*, or return None when no file exists."""
    session_path = os.path.join(DATA_DIR, "sessions", f"{session_id}.json")
    if not os.path.exists(session_path):
        return None

    with open(session_path, "r") as handle:
        record = json.load(handle)

    # Timestamps are stored as ISO strings; turn them back into datetimes.
    for field_name in ("created_at", "updated_at"):
        stored = record[field_name]
        if stored:
            record[field_name] = datetime.fromisoformat(stored)
    return SessionData(**record)
895
+
896
def get_user_sessions(counselor_id: str) -> List[SessionData]:
    """Return every stored session whose ``counselor_id`` equals *counselor_id*.

    Scans all ``.json`` files in the sessions directory; order follows
    ``os.listdir``.
    """
    sessions_dir = os.path.join(DATA_DIR, "sessions")
    matching = []
    for entry_name in os.listdir(sessions_dir):
        if entry_name.endswith(".json"):
            with open(os.path.join(sessions_dir, entry_name), "r") as handle:
                record = json.load(handle)
            if record["counselor_id"] != counselor_id:
                continue
            # Timestamps are stored as ISO strings; restore datetimes.
            for field_name in ("created_at", "updated_at"):
                stored = record[field_name]
                if stored:
                    record[field_name] = datetime.fromisoformat(stored)
            matching.append(SessionData(**record))

    return matching
913
+
914
def save_conversation_entry(entry: ConversationEntry):
    """Write one conversation entry under ``<DATA_DIR>/conversations/<session_id>/``.

    The per-session directory is created on demand. ``entry.timestamp`` is
    set to the current time when empty (mutating *entry*), and each entry is
    stored in its own file named after a freshly generated UUID.

    Args:
        entry: The conversation entry to persist.

    Returns:
        The generated entry id (also stored in the file as ``entry_id``).
    """
    # NOTE(review): entry.session_id is used as a path component as-is;
    # confirm callers only pass ids of existing sessions.
    conversation_dir = os.path.join(DATA_DIR, "conversations", entry.session_id)
    os.makedirs(conversation_dir, exist_ok=True)

    if not entry.timestamp:
        entry.timestamp = datetime.now()

    entry_id = str(uuid4())

    # Serialize BEFORE opening the file so a serialization error cannot
    # leave behind an empty .json file that get_conversation_history would
    # later fail to parse.
    entry_dict = entry.dict()
    entry_dict["entry_id"] = entry_id
    if entry_dict["timestamp"]:
        # datetime is not JSON serializable; store it as an ISO-8601 string.
        entry_dict["timestamp"] = entry_dict["timestamp"].isoformat()
    serialized = json.dumps(entry_dict, indent=2)

    file_path = os.path.join(conversation_dir, f"{entry_id}.json")
    with open(file_path, "w") as f:
        f.write(serialized)

    return entry_id
933
+
934
def get_conversation_history(session_id: str) -> List[ConversationEntry]:
    """Return all stored entries of a session ordered by timestamp.

    An empty list is returned when the session has no conversation directory.
    """
    history_dir = os.path.join(DATA_DIR, "conversations", session_id)
    if not os.path.exists(history_dir):
        return []

    loaded = []
    for entry_name in os.listdir(history_dir):
        if entry_name.endswith(".json"):
            with open(os.path.join(history_dir, entry_name), "r") as handle:
                record = json.load(handle)
            # Timestamps are stored as ISO strings; restore the datetime.
            if record["timestamp"]:
                record["timestamp"] = datetime.fromisoformat(record["timestamp"])
            loaded.append(ConversationEntry(**record))

    # Sort by timestamp
    return sorted(loaded, key=lambda item: item.timestamp)
954
+
955
def save_feedback(feedback: FeedbackData):
    """Write a feedback record to ``<DATA_DIR>/feedback/<feedback_id>.json``.

    A new UUID is generated for every call, and the record is stamped with
    the current time.

    Args:
        feedback: The feedback to persist.

    Returns:
        The generated feedback id (also stored in the file as ``feedback_id``).
    """
    feedback_id = str(uuid4())

    # Serialize BEFORE opening the file so a serialization error cannot
    # leave behind an empty, unparseable .json file.
    feedback_dict = feedback.dict()
    feedback_dict["feedback_id"] = feedback_id
    feedback_dict["timestamp"] = datetime.now().isoformat()
    serialized = json.dumps(feedback_dict, indent=2)

    file_path = os.path.join(DATA_DIR, "feedback", f"{feedback_id}.json")
    with open(file_path, "w") as f:
        f.write(serialized)

    return feedback_id
966
+
967
+ # Multi-modal Input Support
968
@app.post("/input/process")
async def process_multimodal_input(input_data: MultiModalInput):
    """Process multi-modal input (text, audio, video).

    Text is classified directly. For "audio" and "video" the content is
    treated as a description: the LLM is asked to simulate a transcription
    plus emotion / non-verbal analysis, and the simulated transcription is
    then classified. No real media processing happens here.

    Returns:
        A dict with ``input_type``, ``processed_content`` and ``analysis``
        (``response_type``, ``crisis_flag``, ``confidence``); audio adds
        ``emotion_analysis`` and video adds ``nonverbal_analysis``.

    Raises:
        HTTPException: 400 for an unsupported ``input_type``; 500 for any
            processing failure, including a non-JSON model reply.
    """
    try:
        if input_data.input_type not in ["text", "audio", "video"]:
            raise HTTPException(status_code=400, detail=f"Unsupported input type: {input_data.input_type}")

        # For now, handle text directly and simulate processing for audio/video
        if input_data.input_type == "text":
            # Process text normally
            prediction = predict_response_type(input_data.content)

            return {
                "input_type": "text",
                "processed_content": input_data.content,
                "analysis": {
                    "response_type": prediction["response_type"],
                    "crisis_flag": prediction["crisis_flag"],
                    "confidence": prediction["confidence"]
                }
            }
        elif input_data.input_type == "audio":
            # Simulate audio transcription and emotion detection
            # In a production system, this would use a speech-to-text API and emotion analysis
            prompt_template = ChatPromptTemplate.from_template(
                """
            Simulate audio processing for this description: {content}

            Generate a simulated transcription and emotion detection as if this were real audio.

            Output in the following format:
            ```json
            {{
                "transcription": "Simulated transcription of the audio",
                "emotion_detected": "primary emotion",
                "secondary_emotions": ["emotion1", "emotion2"],
                "confidence": 0.85
            }}
            ```
            """
            )

            process_chain = prompt_template | llm
            response = process_chain.invoke({"content": input_data.content})
            # Model output is parsed as JSON data, never executed with eval().
            # The slice keeps the outermost {...} span and drops the fence.
            reply = response.content
            audio_result = json.loads(reply[reply.find("{"):reply.rfind("}") + 1])

            # Now process the transcription
            prediction = predict_response_type(audio_result["transcription"])

            return {
                "input_type": "audio",
                "processed_content": audio_result["transcription"],
                "emotion_analysis": {
                    "primary_emotion": audio_result["emotion_detected"],
                    "secondary_emotions": audio_result["secondary_emotions"],
                    "confidence": audio_result["confidence"]
                },
                "analysis": {
                    "response_type": prediction["response_type"],
                    "crisis_flag": prediction["crisis_flag"],
                    "confidence": prediction["confidence"]
                }
            }
        elif input_data.input_type == "video":
            # Simulate video analysis
            # In a production system, this would use video analytics API
            prompt_template = ChatPromptTemplate.from_template(
                """
            Simulate video processing for this description: {content}

            Generate a simulated analysis as if this were real video with facial expressions and body language.

            Output in the following format:
            ```json
            {{
                "transcription": "Simulated transcription of speech in the video",
                "facial_expressions": ["expression1", "expression2"],
                "body_language": ["posture observation", "gesture observation"],
                "primary_emotion": "primary emotion",
                "confidence": 0.80
            }}
            ```
            """
            )

            process_chain = prompt_template | llm
            response = process_chain.invoke({"content": input_data.content})
            # Model output is parsed as JSON data, never executed with eval().
            reply = response.content
            video_result = json.loads(reply[reply.find("{"):reply.rfind("}") + 1])

            # Process the transcription
            prediction = predict_response_type(video_result["transcription"])

            return {
                "input_type": "video",
                "processed_content": video_result["transcription"],
                "nonverbal_analysis": {
                    "facial_expressions": video_result["facial_expressions"],
                    "body_language": video_result["body_language"],
                    "primary_emotion": video_result["primary_emotion"],
                    "confidence": video_result["confidence"]
                },
                "analysis": {
                    "response_type": prediction["response_type"],
                    "crisis_flag": prediction["crisis_flag"],
                    "confidence": prediction["confidence"]
                }
            }
    except HTTPException:
        # Keep deliberate status codes (e.g. the 400 above) instead of
        # letting the generic handler turn them into a 500.
        raise
    except Exception as e:
        logger.error(f"Error processing multimodal input: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing multimodal input: {str(e)}")
1078
+
1079
+ # Therapeutic Technique Suggestions
1080
@traceable(run_type="chain", name="Therapeutic Techniques")
def suggest_therapeutic_techniques(context: str, technique_type: Optional[str] = None):
    """Ask the LLM to suggest therapeutic techniques for a patient context.

    Args:
        context: Description of the patient's situation.
        technique_type: Optional approach to restrict suggestions to. When
            omitted or empty, "Any appropriate" is sent to the model.

    Returns:
        The JSON object from the model's reply, decoded into a dict. The
        prompt requests the keys ``primary_approach``, ``techniques`` and
        ``rationale``; the shape is not validated here.

    Raises:
        HTTPException: 500 when the chain call fails or the reply does not
            contain valid JSON.
    """
    prompt_template = ChatPromptTemplate.from_template(
        """
    You are an expert mental health professional with extensive knowledge of therapeutic techniques. Based on the following patient context, suggest therapeutic techniques that would be appropriate:

    Patient Context: {context}

    Technique Type (if specified): {technique_type}

    Suggest specific therapeutic techniques, exercises, or interventions that would be helpful for this patient. Include:
    1. Name of technique
    2. Brief description
    3. How to apply it in this specific case
    4. Expected benefits

    Provide a range of options from different therapeutic approaches (CBT, DBT, ACT, mindfulness, motivational interviewing, etc.) unless a specific technique type was requested.

    Output in the following format:
    ```json
    {{
        "primary_approach": "The most appropriate therapeutic approach",
        "techniques": [
            {{
                "name": "Technique name",
                "approach": "CBT/DBT/ACT/etc.",
                "description": "Brief description",
                "application": "How to apply to this specific case",
                "benefits": "Expected benefits"
            }}
        ],
        "rationale": "Brief explanation of why these techniques were selected"
    }}
    ```
    """
    )

    technique_chain = (
        {
            "context": RunnablePassthrough(),
            "technique_type": lambda x: technique_type if technique_type else "Any appropriate"
        }
        | prompt_template
        | llm
    )

    try:
        response = technique_chain.invoke(context)
        # Parse the model output as JSON data rather than executing it with
        # eval(). The slice keeps only the outermost {...} span, discarding
        # the Markdown fence; malformed replies raise and become a 500.
        reply = response.content
        return json.loads(reply[reply.find("{"):reply.rfind("}") + 1])
    except Exception as e:
        logger.error(f"Error suggesting therapeutic techniques: {e}")
        raise HTTPException(status_code=500, detail=f"Error suggesting therapeutic techniques: {str(e)}")
1133
+
1134
@app.post("/techniques/suggest")
async def get_therapeutic_techniques(request: dict):
    """Get suggested therapeutic techniques for a patient context.

    Expects a JSON body with a required ``context`` string and an optional
    ``technique_type``.

    Returns:
        A dict echoing ``context`` together with the suggested ``techniques``.

    Raises:
        HTTPException: 400 when ``context`` is missing or empty; 500 for any
            failure while generating suggestions.
    """
    try:
        context = request.get("context")
        if not context:
            raise HTTPException(status_code=400, detail="context is required")

        technique_type = request.get("technique_type")

        techniques = suggest_therapeutic_techniques(context, technique_type)

        return {
            "context": context,
            "techniques": techniques
        }
    except HTTPException:
        # Without this clause the 400 above is caught by the generic handler
        # below and reported to the client as a 500.
        raise
    except Exception as e:
        logger.error(f"Error getting therapeutic techniques: {e}")
        raise HTTPException(status_code=500, detail=f"Error getting therapeutic techniques: {str(e)}")
1153
+
1154
+ # Ethical AI Guardrails - Confidence Indicator
1155
@app.post("/suggest/with_confidence")
async def get_suggestion_with_confidence(context: PatientContext):
    """Get suggestion with detailed confidence indicators and uncertainty flags.

    Classifies the patient context, buckets the classifier confidence into
    High (>= 0.8) / Medium (>= 0.6) / Low, asks the LLM for a bias analysis
    of the predicted response type, and generates both the RAG and the
    direct suggestion.

    Returns:
        A dict with the prediction, a ``confidence`` object (``value``,
        ``level``, ``uncertainty_flag``), the ``bias_analysis`` decoded from
        the model's JSON reply, both suggestions with their risk levels, and
        a static ``attribution`` block.

    Raises:
        HTTPException: propagated unchanged from helpers, or 500 for any
            other failure, including a non-JSON bias-analysis reply.
    """
    try:
        # Get standard prediction
        prediction = predict_response_type(context.context)

        # Set confidence thresholds
        high_confidence = 0.8
        medium_confidence = 0.6

        # Determine confidence level
        confidence_value = prediction["confidence"]
        if confidence_value >= high_confidence:
            confidence_level = "High"
        elif confidence_value >= medium_confidence:
            confidence_level = "Medium"
        else:
            confidence_level = "Low"

        # Analyze for potential biases
        bias_prompt = ChatPromptTemplate.from_template(
            """
        You are an AI ethics expert. Analyze the following patient context and proposed response type for potential biases:

        Patient Context: {context}
        Predicted Response Type: {response_type}

        Identify any potential biases in interpretation or response. Consider gender, cultural, socioeconomic, and other potential biases.

        Output in the following format:
        ```json
        {{
            "bias_detected": true/false,
            "bias_types": ["bias type 1", "bias type 2"],
            "explanation": "Brief explanation of potential biases"
        }}
        ```
        """
        )

        bias_chain = (
            {
                "context": lambda x: context.context,
                "response_type": lambda x: prediction["response_type"]
            }
            | bias_prompt
            | llm
        )

        # The prompt asks for JSON booleans (true/false), which eval() cannot
        # evaluate (NameError) and which must never be executed as code
        # anyway. Parse the outermost {...} span as JSON instead.
        bias_reply = bias_chain.invoke({}).content
        bias_analysis = json.loads(bias_reply[bias_reply.find("{"):bias_reply.rfind("}") + 1])

        # Generate suggestions
        suggestion_rag = generate_suggestion_rag(context.context, prediction["response_type"], prediction["crisis_flag"])
        suggestion_direct = generate_suggestion_direct(context.context, prediction["response_type"], prediction["crisis_flag"])

        return {
            "context": context.context,
            "response_type": prediction["response_type"],
            "crisis_flag": prediction["crisis_flag"],
            "confidence": {
                "value": prediction["confidence"],
                "level": confidence_level,
                "uncertainty_flag": confidence_level == "Low"
            },
            "bias_analysis": bias_analysis,
            "rag_suggestion": suggestion_rag["suggested_response"],
            "rag_risk_level": suggestion_rag["risk_level"],
            "direct_suggestion": suggestion_direct["suggested_response"],
            "direct_risk_level": suggestion_direct["risk_level"],
            "attribution": {
                "ai_generated": True,
                "model_version": "Mental Health Counselor API v2.0",
                "human_reviewed": False
            }
        }
    except HTTPException:
        # Preserve status codes raised by helpers instead of re-wrapping.
        raise
    except Exception as e:
        logger.error(f"Error getting suggestion with confidence: {e}")
        raise HTTPException(status_code=500, detail=f"Error getting suggestion with confidence: {str(e)}")
1234
+
1235
+ # Text to Speech with Eleven Labs API
1236
@app.post("/api/text-to-speech")
async def text_to_speech(request: dict):
    """Convert text to speech using Eleven Labs API.

    Expects a JSON body with a required ``text`` and an optional
    ``voice_id`` (a default voice is used when omitted). The API key is read
    from the ``ELEVEN_API_KEY`` environment variable.

    Returns:
        A ``StreamingResponse`` carrying the returned audio as ``audio/mpeg``.

    Raises:
        HTTPException: 400 when ``text`` is missing; 500 when the API key is
            not configured or an unexpected error (including a timeout)
            occurs; the upstream status code when Eleven Labs returns a
            non-200 response.
    """
    # Function-scope import so the block does not depend on the file header.
    from urllib.parse import quote

    try:
        text = request.get("text")
        voice_id = request.get("voice_id", "pNInz6obpgDQGcFmaJgB")  # Default to "Adam" voice

        if not text:
            raise HTTPException(status_code=400, detail="Text is required")

        # Get API key from environment
        api_key = os.getenv("ELEVEN_API_KEY")
        if not api_key:
            raise HTTPException(status_code=500, detail="Eleven Labs API key not configured")

        # Prepare the request to Eleven Labs. voice_id comes from the request
        # body, so it is percent-encoded to keep it a single path segment
        # (no "/", "?" or "#" can redirect the authenticated call).
        url = f"https://api.elevenlabs.io/v1/text-to-speech/{quote(str(voice_id), safe='')}"

        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "xi-api-key": api_key
        }

        payload = {
            "text": text,
            "model_id": "eleven_multilingual_v2",
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": 0.75
            }
        }

        # Make the request to Eleven Labs. requests has no default timeout,
        # so without one a stalled upstream would hang this handler forever.
        # NOTE(review): requests.post is blocking inside an async handler;
        # consider running it in a threadpool or using an async client.
        response = requests.post(url, json=payload, headers=headers, timeout=60)

        if response.status_code != 200:
            logger.error(f"Eleven Labs API error: {response.text}")
            raise HTTPException(status_code=response.status_code,
                                detail=f"Eleven Labs API error: {response.text}")

        # Return audio as streaming response
        return StreamingResponse(
            BytesIO(response.content),
            media_type="audio/mpeg"
        )

    except Exception as e:
        logger.error(f"Error in text-to-speech: {str(e)}")
        # HTTPExceptions raised above already carry the intended status.
        if not isinstance(e, HTTPException):
            raise HTTPException(status_code=500, detail=f"Text-to-speech error: {str(e)}")
        raise e
1288
+
1289
+ # Multimedia file processing (speech to text)
1290
@app.post("/api/input/process")
async def process_audio_input(
    audio: UploadFile = File(...),
    session_id: str = Form(...)
):
    """Process audio input for speech-to-text using Eleven Labs.

    Uploads the received audio to the Eleven Labs speech-to-text endpoint
    and returns the transcribed text. The API key is read from the
    ``ELEVEN_API_KEY`` environment variable.

    Args:
        audio: The uploaded audio file (multipart form field ``audio``).
        session_id: Session identifier, echoed back in the response.

    Returns:
        A dict with the transcribed ``text`` (empty string when the upstream
        reply has no ``text`` key) and the ``session_id``.

    Raises:
        HTTPException: 500 when the API key is not configured or an
            unexpected error (including a timeout) occurs; the upstream
            status code when Eleven Labs returns a non-200 response.
    """
    try:
        # Get API key from environment
        api_key = os.getenv("ELEVEN_API_KEY")
        if not api_key:
            raise HTTPException(status_code=500, detail="Eleven Labs API key not configured")

        # Read the audio file content
        audio_content = await audio.read()

        # Call Eleven Labs Speech-to-Text API
        url = "https://api.elevenlabs.io/v1/speech-to-text"

        headers = {
            "xi-api-key": api_key
        }

        # Create form data with the audio file
        # NOTE(review): the upload is always labelled audio.webm / audio/webm
        # regardless of what the client sent. Also verify the multipart field
        # name ('audio') and the model id below against the Eleven Labs
        # speech-to-text reference before relying on this endpoint.
        files = {
            'audio': ('audio.webm', audio_content, 'audio/webm')
        }

        data = {
            'model_id': 'whisper-1'  # Using Whisper model
        }

        # Make the request to Eleven Labs. requests has no default timeout,
        # so without one a stalled upstream would hang this handler forever.
        response = requests.post(url, headers=headers, files=files, data=data, timeout=60)

        if response.status_code != 200:
            logger.error(f"Eleven Labs API error: {response.text}")
            raise HTTPException(status_code=response.status_code,
                                detail=f"Eleven Labs API error: {response.text}")

        result = response.json()

        # Extract the transcribed text
        text = result.get('text', '')

        # Return the transcribed text
        return {
            "text": text,
            "session_id": session_id
        }

    except Exception as e:
        logger.error(f"Error processing audio: {str(e)}")
        # HTTPExceptions raised above already carry the intended status.
        if not isinstance(e, HTTPException):
            raise HTTPException(status_code=500, detail=f"Audio processing error: {str(e)}")
        raise e
1345
+
1346
+ # Add a custom encoder for bytes objects to prevent UTF-8 decode errors
1347
def custom_encoder(obj):
    """Turn a bytes object into a JSON-safe string.

    UTF-8 decodable bytes are returned as text; anything else is returned
    base64-encoded (ASCII). Non-bytes input is rejected with TypeError.
    """
    if not isinstance(obj, bytes):
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

    try:
        decoded = obj.decode('utf-8')
    except UnicodeDecodeError:
        # Not valid UTF-8: fall back to a lossless base64 representation.
        decoded = base64.b64encode(obj).decode('ascii')
    return decoded
1354
+
1355
+ # Override the jsonable_encoder function to handle bytes properly
1356
# Keep a reference to FastAPI's own encoder so the wrapper can delegate to it.
from fastapi.encoders import jsonable_encoder as original_jsonable_encoder

def safe_jsonable_encoder(*args, **kwargs):
    """Call FastAPI's jsonable_encoder, falling back to custom_encoder for raw bytes.

    All arguments are forwarded unchanged. Only a UnicodeDecodeError is
    intercepted, and only when the first positional argument is itself a
    bytes object; in every other case the error is re-raised.
    """
    try:
        return original_jsonable_encoder(*args, **kwargs)
    except UnicodeDecodeError:
        # If the standard encoder fails with a decode error,
        # ensure all bytes are properly handled
        if args and isinstance(args[0], bytes):
            return custom_encoder(args[0])
        raise

# Monkey patch the jsonable_encoder in FastAPI
# NOTE(review): this rebinds only the attribute on the fastapi.encoders
# module. Code that did `from fastapi.encoders import jsonable_encoder`
# before this line ran keeps its reference to the original function --
# confirm the patch actually reaches the code paths it is meant to fix.
import fastapi.encoders
fastapi.encoders.jsonable_encoder = safe_jsonable_encoder
requirements.txt CHANGED
@@ -1,37 +1,38 @@
1
  # Core dependencies for data processing and ML
2
- pandas==2.0.3
3
- numpy==1.24.3
4
- scikit-learn==1.2.2
5
- joblib==1.3.1
6
 
7
  # NLP and sentiment analysis
8
- nltk==3.8.1
9
- vaderSentiment==3.3.2
10
 
11
  # Dataset downloading
12
- kagglehub
13
 
14
  # Vector database and embeddings
15
- chromadb
16
- openai
17
- langchain
18
- langchain-openai
19
- langchain-chroma
20
- httpx==0.24.1
21
 
22
  # API and tracing
23
- fastapi==0.95.2
24
- uvicorn[standard]==0.22.0
25
- pydantic==1.10.8
26
- langsmith
27
- python-dotenv==1.0.0
28
- lightgbm
 
29
 
30
  # New dependencies for additional features
31
- python-multipart==0.0.6
32
  fastapi-cors # For CORS support
33
- aiofiles==23.1.0
34
- jinja2==3.1.2
35
  python-jose[cryptography] # For JWT tokens (authentication)
36
  passlib[bcrypt] # For password hashing
37
  pydub # For audio processing
 
1
  # Core dependencies for data processing and ML
2
+ pandas>=1.3.0,<2.1.0
3
+ numpy>=1.20.0,<1.25.0
4
+ scikit-learn>=1.0.0,<1.3.0
5
+ joblib>=1.0.0,<1.4.0
6
 
7
  # NLP and sentiment analysis
8
+ nltk>=3.6.0,<3.9.0
9
+ vaderSentiment>=3.3.0,<3.4.0
10
 
11
  # Dataset downloading
12
+ kagglehub>=0.3.0
13
 
14
  # Vector database and embeddings
15
+ # Using a compatible version of chromadb
16
+ chromadb>=0.4.18,<0.5.0
17
+ openai>=0.27.0
18
+ langchain>=0.0.267
19
+ langchain-openai>=0.0.1
20
+ langchain-chroma>=0.0.1
21
 
22
  # API and tracing
23
+ fastapi>=0.95.0,<0.96.0
24
+ uvicorn[standard]>=0.22.0,<0.23.0
25
+ pydantic>=1.10.0,<1.11.0
26
+ langsmith>=0.0.52
27
+ python-dotenv>=1.0.0
28
+ httpx>=0.24.0,<0.25.0
29
+ lightgbm>=4.0.0
30
 
31
  # New dependencies for additional features
32
+ python-multipart>=0.0.6
33
  fastapi-cors # For CORS support
34
+ aiofiles>=23.1.0
35
+ jinja2>=3.1.2
36
  python-jose[cryptography] # For JWT tokens (authentication)
37
  passlib[bcrypt] # For password hashing
38
  pydub # For audio processing