Said Lfagrouche committed on
Commit
e62f11c
·
1 Parent(s): a059233

Prepare for Hugging Face Spaces deployment with simplified configuration

Browse files
Files changed (3) hide show
  1. Dockerfile +1 -1
  2. app.py +45 -1333
  3. requirements.txt +22 -23
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM python:3.12
2
 
3
  WORKDIR /app
4
 
 
1
+ FROM python:3.9
2
 
3
  WORKDIR /app
4
 
app.py CHANGED
@@ -1,34 +1,8 @@
1
- # api_mental_health.py
2
- from fastapi import FastAPI, HTTPException, UploadFile, File, Form
3
  from pydantic import BaseModel
4
- import pandas as pd
5
- import numpy as np
6
- import joblib
7
- import re
8
- import nltk
9
- from nltk.tokenize import word_tokenize
10
- from nltk.stem import WordNetLemmatizer
11
- from nltk.corpus import stopwords
12
- from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
13
- import chromadb
14
- from chromadb.config import Settings
15
- from langchain_openai import OpenAIEmbeddings, ChatOpenAI
16
- from langchain_chroma import Chroma
17
- from openai import OpenAI
18
  import os
19
  from dotenv import load_dotenv
20
- from langsmith import Client, traceable
21
- from langchain_core.runnables import RunnablePassthrough
22
- from langchain_core.prompts import ChatPromptTemplate
23
  import logging
24
- from typing import List, Dict, Optional, Any, Union, Annotated
25
- from datetime import datetime
26
- from uuid import uuid4, UUID
27
- import json
28
- import requests
29
- from fastapi.responses import StreamingResponse
30
- from io import BytesIO
31
- import base64
32
 
33
  # Set up logging
34
  logging.basicConfig(level=logging.INFO)
@@ -37,15 +11,10 @@ logger = logging.getLogger(__name__)
37
  # Load environment variables
38
  load_dotenv()
39
 
40
- # Download required NLTK data
41
- nltk.download('punkt')
42
- nltk.download('wordnet')
43
- nltk.download('stopwords')
44
-
45
  # Initialize FastAPI app
46
  app = FastAPI(title="Mental Health Counselor API")
47
 
48
- # Initialize global storage (to be replaced with proper database)
49
  DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
50
  os.makedirs(DATA_DIR, exist_ok=True)
51
  os.makedirs(os.path.join(DATA_DIR, "users"), exist_ok=True)
@@ -53,1318 +22,61 @@ os.makedirs(os.path.join(DATA_DIR, "sessions"), exist_ok=True)
53
  os.makedirs(os.path.join(DATA_DIR, "conversations"), exist_ok=True)
54
  os.makedirs(os.path.join(DATA_DIR, "feedback"), exist_ok=True)
55
 
56
- # Initialize components
57
- STOPWORDS = set(stopwords.words("english"))
58
- lemmatizer = WordNetLemmatizer()
59
- analyzer = SentimentIntensityAnalyzer()
60
- output_dir = "mental_health_model_artifacts"
61
-
62
- # Global variables for models and vector store
63
- response_clf = None
64
- crisis_clf = None
65
- vectorizer = None
66
- le = None
67
- selector = None
68
- lda = None
69
- vector_store = None
70
- llm = None
71
- openai_client = None
72
- langsmith_client = None
73
-
74
- # Load models and initialize ChromaDB at startup
75
- @app.on_event("startup")
76
- async def startup_event():
77
- global response_clf, crisis_clf, vectorizer, le, selector, lda, vector_store, llm, openai_client, langsmith_client
78
-
79
- # Check environment variables
80
- if not os.environ.get("OPENAI_API_KEY"):
81
- logger.error("OPENAI_API_KEY not set in .env file")
82
- raise HTTPException(status_code=500, detail="OPENAI_API_KEY not set in .env file")
83
- if not os.environ.get("LANGCHAIN_API_KEY"):
84
- logger.error("LANGCHAIN_API_KEY not set in .env file")
85
- raise HTTPException(status_code=500, detail="LANGCHAIN_API_KEY not set in .env file")
86
- os.environ["LANGCHAIN_TRACING_V2"] = "true"
87
- os.environ["LANGCHAIN_PROJECT"] = "MentalHealthCounselorPOC"
88
-
89
- # Initialize LangSmith client
90
- logger.info("Initializing LangSmith client")
91
- langsmith_client = Client()
92
-
93
- # Load saved components
94
- logger.info("Loading model artifacts")
95
- try:
96
- response_clf = joblib.load(f"{output_dir}/response_type_classifier.pkl")
97
- crisis_clf = joblib.load(f"{output_dir}/crisis_classifier.pkl")
98
- vectorizer = joblib.load(f"{output_dir}/tfidf_vectorizer.pkl")
99
- le = joblib.load(f"{output_dir}/label_encoder.pkl")
100
- selector = joblib.load(f"{output_dir}/feature_selector.pkl")
101
-
102
- try:
103
- lda = joblib.load(f"{output_dir}/lda_model.pkl")
104
- except Exception as lda_error:
105
- logger.warning(f"Failed to load LDA model: {lda_error}. Creating placeholder model.")
106
- from sklearn.decomposition import LatentDirichletAllocation
107
- lda = LatentDirichletAllocation(n_components=10, random_state=42)
108
- # Note: Placeholder is untrained; retrain for accurate results
109
-
110
- except FileNotFoundError as e:
111
- logger.error(f"Missing model artifact: {e}")
112
- raise HTTPException(status_code=500, detail=f"Missing model artifact: {e}")
113
-
114
- # Initialize ChromaDB
115
- chroma_db_path = f"{output_dir}/chroma_db"
116
- if not os.path.exists(chroma_db_path):
117
- logger.error(f"ChromaDB not found at {chroma_db_path}. Run create_vector_db.py first.")
118
- raise HTTPException(status_code=500, detail=f"ChromaDB not found at {chroma_db_path}. Run create_vector_db.py first.")
119
-
120
- try:
121
- logger.info("Initializing ChromaDB")
122
- chroma_client = chromadb.PersistentClient(
123
- path=chroma_db_path,
124
- settings=Settings(anonymized_telemetry=False)
125
- )
126
-
127
- embeddings = OpenAIEmbeddings(
128
- model="text-embedding-ada-002",
129
- api_key=os.environ["OPENAI_API_KEY"],
130
- disallowed_special=(),
131
- chunk_size=1000
132
- )
133
- global vector_store
134
- vector_store = Chroma(
135
- client=chroma_client,
136
- collection_name="mental_health_conversations",
137
- embedding_function=embeddings
138
- )
139
- except Exception as e:
140
- logger.error(f"Error initializing ChromaDB: {e}")
141
- raise HTTPException(status_code=500, detail=f"Error initializing ChromaDB: {e}")
142
-
143
- # Initialize OpenAI client and LLM
144
- logger.info("Initializing OpenAI client and LLM")
145
- global openai_client, llm
146
- openai_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
147
- llm = ChatOpenAI(
148
- model="gpt-4o-mini",
149
- temperature=0.7,
150
- api_key=os.environ["OPENAI_API_KEY"]
151
- )
152
-
153
- # Pydantic model for request
154
- class PatientContext(BaseModel):
155
- context: str
156
-
157
- # New Pydantic models for expanded API functionality
158
- class UserProfile(BaseModel):
159
- user_id: Optional[str] = None
160
- username: str
161
- name: str
162
- role: str = "counselor"
163
- specializations: List[str] = []
164
- years_experience: Optional[int] = None
165
- custom_crisis_keywords: List[str] = []
166
- preferences: Dict[str, Any] = {}
167
- created_at: Optional[datetime] = None
168
- updated_at: Optional[datetime] = None
169
-
170
- class SessionData(BaseModel):
171
- session_id: Optional[str] = None
172
- counselor_id: str
173
- patient_identifier: str # Anonymized ID
174
- session_notes: str = ""
175
- session_preferences: Dict[str, Any] = {}
176
- crisis_keywords: List[str] = []
177
- created_at: Optional[datetime] = None
178
- updated_at: Optional[datetime] = None
179
-
180
- class ConversationEntry(BaseModel):
181
- session_id: str
182
- message: str
183
- sender: str # 'patient' or 'counselor'
184
- timestamp: Optional[datetime] = None
185
- suggested_response: Optional[str] = None
186
- response_type: Optional[str] = None
187
- crisis_flag: bool = False
188
- risk_level: Optional[str] = None
189
-
190
- class FeedbackData(BaseModel):
191
- suggestion_id: str
192
- counselor_id: str
193
- rating: int # 1-5 scale
194
- was_effective: bool
195
- comments: Optional[str] = None
196
-
197
- class AnalysisRequest(BaseModel):
198
- text: str
199
- patient_background: Optional[Dict[str, Any]] = None
200
- patient_age: Optional[int] = None
201
- cultural_context: Optional[str] = None
202
-
203
- class MultiModalInput(BaseModel):
204
- session_id: str
205
- counselor_id: str
206
- input_type: str # 'text', 'audio', 'video'
207
- content: str # Text content or file path/url
208
- metadata: Dict[str, Any] = {}
209
-
210
- class InterventionRequest(BaseModel):
211
- patient_issue: str
212
- patient_background: Optional[Dict[str, Any]] = None
213
- intervention_type: Optional[str] = None # e.g., 'CBT', 'DBT', 'mindfulness'
214
-
215
- # Text preprocessing function
216
- @traceable(run_type="tool", name="Clean Text")
217
- def clean_text(text):
218
- if pd.isna(text):
219
- return ""
220
- text = str(text).lower()
221
- text = re.sub(r"[^a-zA-Z']", " ", text)
222
- tokens = word_tokenize(text)
223
- tokens = [lemmatizer.lemmatize(tok) for tok in tokens if tok not in STOPWORDS and len(tok) > 2]
224
- return " ".join(tokens)
225
-
226
- # Feature engineering function
227
- @traceable(run_type="tool", name="Engineer Features")
228
- def engineer_features(context, response=""):
229
- context_clean = clean_text(context)
230
- context_len = len(context_clean.split())
231
- context_vader = analyzer.polarity_scores(context)['compound']
232
- context_questions = context.count('?')
233
- crisis_keywords = ['suicide', 'hopeless', 'worthless', 'kill', 'harm', 'desperate', 'overwhelmed', 'alone']
234
- context_crisis_score = sum(1 for word in crisis_keywords if word in context.lower())
235
-
236
- context_tfidf = vectorizer.transform([context_clean]).toarray()
237
- tfidf_cols = [f"tfidf_context_{i}" for i in range(context_tfidf.shape[1])]
238
- response_tfidf = np.zeros_like(context_tfidf)
239
-
240
- lda_topics = lda.transform(context_tfidf)
241
-
242
- feature_cols = ["context_len", "context_vader", "context_questions", "crisis_flag"] + \
243
- [f"topic_{i}" for i in range(10)] + tfidf_cols + \
244
- [f"tfidf_response_{i}" for i in range(response_tfidf.shape[1])]
245
-
246
- features = pd.DataFrame({
247
- "context_len": [context_len],
248
- "context_vader": [context_vader],
249
- "context_questions": [context_questions],
250
- **{f"topic_{i}": [lda_topics[0][i]] for i in range(10)},
251
- **{f"tfidf_context_{i}": [context_tfidf[0][i]] for i in range(context_tfidf.shape[1])},
252
- **{f"tfidf_response_{i}": [response_tfidf[0][i]] for i in range(response_tfidf.shape[1])},
253
- })
254
-
255
- crisis_features = features[["context_len", "context_vader", "context_questions"] + [f"topic_{i}" for i in range(10)]]
256
- crisis_flag = crisis_clf.predict(crisis_features)[0]
257
- if context_crisis_score > 0:
258
- crisis_flag = 1
259
- features["crisis_flag"] = crisis_flag
260
-
261
- return features, feature_cols
262
-
263
- # Prediction function
264
- @traceable(run_type="chain", name="Predict Response Type")
265
- def predict_response_type(context):
266
- features, feature_cols = engineer_features(context)
267
- selected_features = selector.transform(features[feature_cols])
268
- pred_encoded = response_clf.predict(selected_features)[0]
269
- pred_label = le.inverse_transform([pred_encoded])[0]
270
- confidence = response_clf.predict_proba(selected_features)[0].max()
271
-
272
- if "?" in context and context.count("?") > 0:
273
- pred_label = "Question"
274
- if "trying" in context.lower() and "hard" in context.lower() and not any(kw in context.lower() for kw in ["how", "what", "help"]):
275
- pred_label = "Validation"
276
- if "trying" in context.lower() and "positive" in context.lower() and not any(kw in context.lower() for kw in ["how", "what", "help"]):
277
- pred_label = "Question"
278
-
279
- crisis_flag = bool(features["crisis_flag"].iloc[0])
280
-
281
- return {
282
- "response_type": pred_label,
283
- "crisis_flag": crisis_flag,
284
- "confidence": confidence,
285
- "features": features.to_dict()
286
- }
287
-
288
- # RAG suggestion function
289
- @traceable(run_type="chain", name="RAG Suggestion")
290
- def generate_suggestion_rag(context, response_type, crisis_flag):
291
- results = vector_store.similarity_search_with_score(context, k=3)
292
- retrieved_contexts = [
293
- f"Patient: {res[0].page_content}\nCounselor: {res[0].metadata['response']} (Type: {res[0].metadata['response_type']}, Crisis: {res[0].metadata['crisis_flag']}, Score: {res[1]:.2f})"
294
- for res in results
295
- ]
296
-
297
- prompt_template = ChatPromptTemplate.from_template(
298
- """
299
- You are an expert mental health counseling assistant. A counselor has provided the following patient situation:
300
-
301
- Patient Situation: {context}
302
-
303
- Predicted Response Type: {response_type}
304
- Crisis Flag: {crisis_flag}
305
-
306
- Based on the predicted response type and crisis flag, provide a suggested response for the counselor to use with the patient. The response should align with the response type ({response_type}) and be sensitive to the crisis level.
307
-
308
- For reference, here are similar cases from past conversations:
309
- {retrieved_contexts}
310
-
311
- Guidelines:
312
- - If Crisis Flag is True, prioritize safety, empathy, and suggest immediate resources (e.g., National Suicide Prevention Lifeline at 988).
313
- - For 'Empathetic Listening', focus on validating feelings without giving direct advice or questions.
314
- - For 'Advice', provide practical, actionable suggestions.
315
- - For 'Question', pose an open-ended question to encourage further discussion.
316
- - For 'Validation', affirm the patient's efforts or feelings.
317
-
318
- Output in the following format:
319
- ```json
320
- {{
321
- "suggested_response": "Your suggested response here",
322
- "risk_level": "Low/Moderate/High"
323
- }}
324
- ```
325
- """
326
- )
327
-
328
- rag_chain = (
329
- {
330
- "context": RunnablePassthrough(),
331
- "response_type": lambda x: response_type,
332
- "crisis_flag": lambda x: "Crisis" if crisis_flag else "No Crisis",
333
- "retrieved_contexts": lambda x: "\n".join(retrieved_contexts)
334
- }
335
- | prompt_template
336
- | llm
337
- )
338
-
339
- try:
340
- response = rag_chain.invoke(context)
341
- return eval(response.content.strip("```json\n").strip("\n```"))
342
- except Exception as e:
343
- logger.error(f"Error generating RAG suggestion: {e}")
344
- raise HTTPException(status_code=500, detail=f"Error generating RAG suggestion: {str(e)}")
345
-
346
- # Direct suggestion function
347
- @traceable(run_type="chain", name="Direct Suggestion")
348
- def generate_suggestion_direct(context, response_type, crisis_flag):
349
- prompt_template = ChatPromptTemplate.from_template(
350
- """
351
- You are an expert mental health counseling assistant. A counselor has provided the following patient situation:
352
-
353
- Patient Situation: {context}
354
-
355
- Predicted Response Type: {response_type}
356
- Crisis Flag: {crisis_flag}
357
-
358
- Provide a suggested response for the counselor to use with the patient, aligned with the response type ({response_type}) and sensitive to the crisis level.
359
-
360
- Guidelines:
361
- - If Crisis Flag is True, prioritize safety, empathy, and suggest immediate resources (e.g., National Suicide Prevention Lifeline at 988).
362
- - For 'Empathetic Listening', focus on validating feelings without giving direct advice or questions.
363
- - For 'Advice', provide practical, actionable suggestions.
364
- - For 'Question', pose an open-ended question to encourage further discussion.
365
- - For 'Validation', affirm the patient's efforts or feelings.
366
- - Strictly adhere to the response type. For 'Empathetic Listening', do not include questions or advice.
367
-
368
- Output in the following format:
369
- ```json
370
- {{
371
- "suggested_response": "Your suggested response here",
372
- "risk_level": "Low/Moderate/High"
373
- }}
374
- ```
375
- """
376
- )
377
-
378
- direct_chain = (
379
- {
380
- "context": RunnablePassthrough(),
381
- "response_type": lambda x: response_type,
382
- "crisis_flag": lambda x: "Crisis" if crisis_flag else "No Crisis"
383
- }
384
- | prompt_template
385
- | llm
386
- )
387
-
388
- try:
389
- response = direct_chain.invoke(context)
390
- return eval(response.content.strip("```json\n").strip("\n```"))
391
- except Exception as e:
392
- logger.error(f"Error generating direct suggestion: {e}")
393
- raise HTTPException(status_code=500, detail=f"Error generating direct suggestion: {str(e)}")
394
-
395
- # User Profile Endpoints
396
- @app.post("/users/create", response_model=UserProfile)
397
- async def create_user(profile: UserProfile):
398
- """Create a new counselor profile with preferences and specializations."""
399
- try:
400
- saved_profile = save_user_profile(profile)
401
- logger.info(f"Created user profile: {saved_profile.user_id}")
402
- return saved_profile
403
- except Exception as e:
404
- logger.error(f"Error creating user profile: {e}")
405
- raise HTTPException(status_code=500, detail=f"Error creating user profile: {str(e)}")
406
-
407
- @app.get("/users/{user_id}", response_model=UserProfile)
408
- async def get_user(user_id: str):
409
- """Get a counselor profile by user ID."""
410
- profile = get_user_profile(user_id)
411
- if not profile:
412
- raise HTTPException(status_code=404, detail=f"User profile not found: {user_id}")
413
- return profile
414
-
415
- @app.put("/users/{user_id}", response_model=UserProfile)
416
- async def update_user(user_id: str, profile_update: UserProfile):
417
- """Update a counselor profile."""
418
- existing_profile = get_user_profile(user_id)
419
- if not existing_profile:
420
- raise HTTPException(status_code=404, detail=f"User profile not found: {user_id}")
421
-
422
- # Preserve the original user_id
423
- profile_update.user_id = user_id
424
- # Preserve the original created_at timestamp
425
- profile_update.created_at = existing_profile.created_at
426
-
427
- try:
428
- updated_profile = save_user_profile(profile_update)
429
- logger.info(f"Updated user profile: {user_id}")
430
- return updated_profile
431
- except Exception as e:
432
- logger.error(f"Error updating user profile: {e}")
433
- raise HTTPException(status_code=500, detail=f"Error updating user profile: {str(e)}")
434
-
435
- # Session Management Endpoints
436
- @app.post("/sessions/create", response_model=SessionData)
437
- async def create_session(session_data: SessionData):
438
- """Create a new session with patient identifier (anonymized)."""
439
- try:
440
- # Verify counselor exists
441
- counselor = get_user_profile(session_data.counselor_id)
442
- if not counselor:
443
- raise HTTPException(status_code=404, detail=f"Counselor not found: {session_data.counselor_id}")
444
-
445
- # If counselor has custom crisis keywords, add them to the session
446
- if counselor.custom_crisis_keywords:
447
- session_data.crisis_keywords.extend(counselor.custom_crisis_keywords)
448
-
449
- saved_session = save_session(session_data)
450
- logger.info(f"Created session: {saved_session.session_id}")
451
- return saved_session
452
- except HTTPException:
453
- raise
454
- except Exception as e:
455
- logger.error(f"Error creating session: {e}")
456
- raise HTTPException(status_code=500, detail=f"Error creating session: {str(e)}")
457
-
458
- @app.get("/sessions/{session_id}", response_model=SessionData)
459
- async def get_session_by_id(session_id: str):
460
- """Get a session by ID."""
461
- session = get_session(session_id)
462
- if not session:
463
- raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
464
- return session
465
-
466
- @app.get("/sessions/counselor/{counselor_id}", response_model=List[SessionData])
467
- async def get_counselor_sessions(counselor_id: str):
468
- """Get all sessions for a counselor."""
469
- sessions = get_user_sessions(counselor_id)
470
- return sessions
471
-
472
- @app.put("/sessions/{session_id}", response_model=SessionData)
473
- async def update_session(session_id: str, session_update: SessionData):
474
- """Update a session."""
475
- existing_session = get_session(session_id)
476
- if not existing_session:
477
- raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
478
-
479
- # Preserve the original session_id and created_at
480
- session_update.session_id = session_id
481
- session_update.created_at = existing_session.created_at
482
-
483
- try:
484
- updated_session = save_session(session_update)
485
- logger.info(f"Updated session: {session_id}")
486
- return updated_session
487
- except Exception as e:
488
- logger.error(f"Error updating session: {e}")
489
- raise HTTPException(status_code=500, detail=f"Error updating session: {str(e)}")
490
-
491
- # Conversation History Endpoints
492
- @app.post("/conversations/add", response_model=str)
493
- async def add_conversation_entry(entry: ConversationEntry):
494
- """Add a new entry to a conversation."""
495
- try:
496
- # Verify session exists
497
- session = get_session(entry.session_id)
498
- if not session:
499
- raise HTTPException(status_code=404, detail=f"Session not found: {entry.session_id}")
500
-
501
- entry_id = save_conversation_entry(entry)
502
- logger.info(f"Added conversation entry: {entry_id}")
503
- return entry_id
504
- except HTTPException:
505
- raise
506
- except Exception as e:
507
- logger.error(f"Error adding conversation entry: {e}")
508
- raise HTTPException(status_code=500, detail=f"Error adding conversation entry: {str(e)}")
509
-
510
- @app.get("/conversations/{session_id}", response_model=List[ConversationEntry])
511
- async def get_conversation(session_id: str):
512
- """Get conversation history for a session."""
513
- try:
514
- # Verify session exists
515
- session = get_session(session_id)
516
- if not session:
517
- raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
518
-
519
- entries = get_conversation_history(session_id)
520
- return entries
521
- except HTTPException:
522
- raise
523
- except Exception as e:
524
- logger.error(f"Error retrieving conversation history: {e}")
525
- raise HTTPException(status_code=500, detail=f"Error retrieving conversation history: {str(e)}")
526
-
527
- # API Endpoints
528
- @app.post("/suggest")
529
- async def get_suggestion(context: PatientContext):
530
- logger.info(f"Received suggestion request for context: {context.context}")
531
- prediction = predict_response_type(context.context)
532
- suggestion_rag = generate_suggestion_rag(context.context, prediction["response_type"], prediction["crisis_flag"])
533
- suggestion_direct = generate_suggestion_direct(context.context, prediction["response_type"], prediction["crisis_flag"])
534
-
535
  return {
536
- "context": context.context,
537
- "response_type": prediction["response_type"],
538
- "crisis_flag": prediction["crisis_flag"],
539
- "confidence": prediction["confidence"],
540
- "rag_suggestion": suggestion_rag["suggested_response"],
541
- "rag_risk_level": suggestion_rag["risk_level"],
542
- "direct_suggestion": suggestion_direct["suggested_response"],
543
- "direct_risk_level": suggestion_direct["risk_level"]
544
  }
545
 
546
- @app.post("/session/suggest")
547
- async def get_session_suggestion(request: dict):
548
- """Get suggestion within a session context, with enhanced crisis detection based on session keywords."""
549
- try:
550
- session_id = request.get("session_id")
551
- if not session_id:
552
- raise HTTPException(status_code=400, detail="session_id is required")
553
-
554
- context = request.get("context")
555
- if not context:
556
- raise HTTPException(status_code=400, detail="context is required")
557
-
558
- # Get session for custom crisis keywords
559
- session = get_session(session_id)
560
- if not session:
561
- raise HTTPException(status_code=404, detail=f"Session not found: {session_id}")
562
-
563
- # Get conversation history for context
564
- conversation_history = get_conversation_history(session_id)
565
-
566
- # Regular prediction
567
- prediction = predict_response_type(context)
568
- crisis_flag = prediction["crisis_flag"]
569
-
570
- # Enhanced crisis detection with custom keywords
571
- if not crisis_flag and session.crisis_keywords:
572
- for keyword in session.crisis_keywords:
573
- if keyword.lower() in context.lower():
574
- crisis_flag = True
575
- logger.info(f"Crisis flag triggered by custom keyword: {keyword}")
576
- break
577
-
578
- # Generate suggestions
579
- suggestion_rag = generate_suggestion_rag(context, prediction["response_type"], crisis_flag)
580
- suggestion_direct = generate_suggestion_direct(context, prediction["response_type"], crisis_flag)
581
-
582
- # Create response
583
- response = {
584
- "context": context,
585
- "response_type": prediction["response_type"],
586
- "crisis_flag": crisis_flag,
587
- "confidence": prediction["confidence"],
588
- "rag_suggestion": suggestion_rag["suggested_response"],
589
- "rag_risk_level": suggestion_rag["risk_level"],
590
- "direct_suggestion": suggestion_direct["suggested_response"],
591
- "direct_risk_level": suggestion_direct["risk_level"],
592
- "session_id": session_id
593
- }
594
-
595
- # Save the conversation entry
596
- entry = ConversationEntry(
597
- session_id=session_id,
598
- message=context,
599
- sender="patient",
600
- suggested_response=suggestion_rag["suggested_response"],
601
- response_type=prediction["response_type"],
602
- crisis_flag=crisis_flag,
603
- risk_level=suggestion_rag["risk_level"]
604
- )
605
- save_conversation_entry(entry)
606
-
607
- return response
608
- except HTTPException:
609
- raise
610
- except Exception as e:
611
- logger.error(f"Error getting session suggestion: {e}")
612
- raise HTTPException(status_code=500, detail=f"Error getting session suggestion: {str(e)}")
613
-
614
- # Feedback Endpoints
615
- @app.post("/feedback")
616
- async def add_feedback(feedback: FeedbackData):
617
- """Add feedback about a suggestion's effectiveness."""
618
- try:
619
- feedback_id = save_feedback(feedback)
620
- logger.info(f"Added feedback: {feedback_id}")
621
- return {"feedback_id": feedback_id}
622
- except Exception as e:
623
- logger.error(f"Error adding feedback: {e}")
624
- raise HTTPException(status_code=500, detail=f"Error adding feedback: {str(e)}")
625
-
626
- # Tone & Cultural Sensitivity Analysis
627
- @traceable(run_type="chain", name="Cultural Sensitivity Analysis")
628
- def analyze_cultural_sensitivity(text: str, cultural_context: Optional[str] = None):
629
- """Analyze text for cultural appropriateness and sensitivity."""
630
- prompt_template = ChatPromptTemplate.from_template(
631
- """
632
- You are a cultural sensitivity expert. Analyze the following text for cultural appropriateness:
633
-
634
- Text: {text}
635
-
636
- Cultural Context: {cultural_context}
637
-
638
- Provide an analysis of:
639
- 1. Cultural appropriateness
640
- 2. Potential bias or insensitivity
641
- 3. Suggestions for improvement
642
-
643
- Output in the following format:
644
- ```json
645
- {{
646
- "cultural_appropriateness_score": 0-10,
647
- "issues_detected": ["issue1", "issue2"],
648
- "suggestions": ["suggestion1", "suggestion2"],
649
- "explanation": "Brief explanation of analysis"
650
- }}
651
- ```
652
- """
653
- )
654
-
655
- analysis_chain = (
656
- {
657
- "text": RunnablePassthrough(),
658
- "cultural_context": lambda x: cultural_context if cultural_context else "General"
659
- }
660
- | prompt_template
661
- | llm
662
- )
663
-
664
- try:
665
- response = analysis_chain.invoke(text)
666
- return eval(response.content.strip("```json\n").strip("\n```"))
667
- except Exception as e:
668
- logger.error(f"Error analyzing cultural sensitivity: {e}")
669
- raise HTTPException(status_code=500, detail=f"Error analyzing cultural sensitivity: {str(e)}")
670
-
671
- @traceable(run_type="chain", name="Age Appropriate Analysis")
672
- def analyze_age_appropriateness(text: str, age: Optional[int] = None):
673
- """Analyze text for age-appropriate language."""
674
- prompt_template = ChatPromptTemplate.from_template(
675
- """
676
- You are an expert in age-appropriate communication. Analyze the following text for age appropriateness:
677
-
678
- Text: {text}
679
-
680
- Target Age: {age}
681
-
682
- Provide an analysis of:
683
- 1. Age appropriateness
684
- 2. Complexity level
685
- 3. Suggestions for improvement
686
-
687
- Output in the following format:
688
- ```json
689
- {{
690
- "age_appropriateness_score": 0-10,
691
- "complexity_level": "Simple/Moderate/Complex",
692
- "issues_detected": ["issue1", "issue2"],
693
- "suggestions": ["suggestion1", "suggestion2"],
694
- "explanation": "Brief explanation of analysis"
695
- }}
696
- ```
697
- """
698
- )
699
-
700
- analysis_chain = (
701
- {
702
- "text": RunnablePassthrough(),
703
- "age": lambda x: str(age) if age else "Adult"
704
- }
705
- | prompt_template
706
- | llm
707
- )
708
-
709
- try:
710
- response = analysis_chain.invoke(text)
711
- return eval(response.content.strip("```json\n").strip("\n```"))
712
- except Exception as e:
713
- logger.error(f"Error analyzing age appropriateness: {e}")
714
- raise HTTPException(status_code=500, detail=f"Error analyzing age appropriateness: {str(e)}")
715
-
716
- @app.post("/analyze/sensitivity")
717
- async def analyze_text_sensitivity(request: AnalysisRequest):
718
- """Analyze text for cultural sensitivity and age appropriateness."""
719
- try:
720
- cultural_analysis = analyze_cultural_sensitivity(request.text, request.cultural_context)
721
- age_analysis = analyze_age_appropriateness(request.text, request.patient_age)
722
-
723
- return {
724
- "text": request.text,
725
- "cultural_analysis": cultural_analysis,
726
- "age_analysis": age_analysis
727
- }
728
- except Exception as e:
729
- logger.error(f"Error analyzing text sensitivity: {e}")
730
- raise HTTPException(status_code=500, detail=f"Error analyzing text sensitivity: {str(e)}")
731
-
732
- # Guided Intervention Workflows
733
- @traceable(run_type="chain", name="Generate Intervention")
734
- def generate_intervention_workflow(issue: str, intervention_type: Optional[str] = None, background: Optional[Dict] = None):
735
- """Generate a structured intervention workflow for a specific issue."""
736
- prompt_template = ChatPromptTemplate.from_template(
737
- """
738
- You are an expert mental health counselor. Generate a structured intervention workflow for the following patient issue:
739
-
740
- Patient Issue: {issue}
741
-
742
- Intervention Type: {intervention_type}
743
- Patient Background: {background}
744
-
745
- Provide a step-by-step intervention plan based on evidence-based practices. Include:
746
- 1. Initial assessment questions
747
- 2. Specific techniques to apply
748
- 3. Homework or practice exercises
749
- 4. Follow-up guidance
750
-
751
- Output in the following format:
752
- ```json
753
- {{
754
- "intervention_type": "CBT/DBT/ACT/Mindfulness/etc.",
755
- "assessment_questions": ["question1", "question2", "question3"],
756
- "techniques": [
757
- {{
758
- "name": "technique name",
759
- "description": "brief description",
760
- "instructions": "step-by-step instructions"
761
- }}
762
- ],
763
- "exercises": [
764
- {{
765
- "name": "exercise name",
766
- "description": "brief description",
767
- "instructions": "step-by-step instructions"
768
- }}
769
- ],
770
- "follow_up": ["follow-up step 1", "follow-up step 2"],
771
- "resources": ["resource1", "resource2"]
772
- }}
773
- ```
774
- """
775
- )
776
-
777
- intervention_chain = (
778
- {
779
- "issue": RunnablePassthrough(),
780
- "intervention_type": lambda x: intervention_type if intervention_type else "Best fit",
781
- "background": lambda x: str(background) if background else "Not provided"
782
- }
783
- | prompt_template
784
- | llm
785
- )
786
-
787
- try:
788
- response = intervention_chain.invoke(issue)
789
- return eval(response.content.strip("```json\n").strip("\n```"))
790
- except Exception as e:
791
- logger.error(f"Error generating intervention workflow: {e}")
792
- raise HTTPException(status_code=500, detail=f"Error generating intervention workflow: {str(e)}")
793
-
794
@app.post("/interventions/generate")
async def get_intervention_workflow(request: InterventionRequest):
    """Get a structured intervention workflow for a specific patient issue.

    Delegates to ``generate_intervention_workflow`` and returns the generated
    plan next to the patient issue it was generated for.

    Raises:
        HTTPException: passed through unchanged when the generator raised one;
            any other failure is logged and reported as a 500.
    """
    try:
        intervention = generate_intervention_workflow(
            request.patient_issue,
            request.intervention_type,
            request.patient_background
        )

        return {
            "patient_issue": request.patient_issue,
            "intervention": intervention
        }
    except HTTPException:
        # The generator logs and raises its own HTTPException; wrapping it
        # again would log the failure twice and nest one detail in another.
        raise
    except Exception as e:
        logger.error(f"Error generating intervention workflow: {e}")
        raise HTTPException(status_code=500, detail=f"Error generating intervention workflow: {str(e)}")
811
-
812
@app.get("/health")
async def health_check():
    """Report whether the models, the vector store and the LLM are all loaded.

    Returns a "healthy" payload when every component is truthy. Otherwise the
    names of the missing components are logged and an HTTP 500 is raised.
    """
    components = {
        "response_clf": response_clf,
        "crisis_clf": crisis_clf,
        "vectorizer": vectorizer,
        "le": le,
        "selector": selector,
        "lda": lda,
        "vector_store": vector_store,
        "llm": llm,
    }
    # Same truthiness test as before, but keep the names so the log says
    # which component is missing instead of just "one or more".
    missing = [name for name, component in components.items() if not component]
    if not missing:
        return {"status": "healthy", "message": "All models and vector store loaded successfully"}

    logger.error("Health check failed: components not loaded: %s", ", ".join(missing))
    raise HTTPException(status_code=500, detail="One or more components failed to load")
818
 
 
819
@app.get("/metadata")
async def get_metadata():
    """Return the name and document count of the conversations collection.

    Any failure while querying the vector store is logged and reported as an
    HTTP 500.
    """
    collection_name = "mental_health_conversations"
    try:
        document_count = vector_store._client.get_collection(collection_name).count()
        return {"collection_name": collection_name, "document_count": document_count}
    except Exception as e:
        logger.error(f"Error retrieving metadata: {e}")
        raise HTTPException(status_code=500, detail=f"Error retrieving metadata: {str(e)}")
828
-
829
- # Database utility functions
830
def save_user_profile(profile: UserProfile):
    """Persist a user profile as ``<DATA_DIR>/users/<user_id>.json``.

    Assigns a fresh UUID when ``user_id`` is empty, sets ``created_at`` on the
    first save and always refreshes ``updated_at``. The profile is mutated in
    place and returned.
    """
    now = datetime.now()
    if not profile.user_id:
        profile.user_id = str(uuid4())
    if not profile.created_at:
        profile.created_at = now
    profile.updated_at = now

    # Serialize before opening the file: mode "w" truncates immediately, so a
    # serialization error after that point would wipe the stored profile.
    profile_dict = profile.dict()
    for key in ["created_at", "updated_at"]:
        if profile_dict[key]:
            # Convert datetime to string for JSON serialization
            profile_dict[key] = profile_dict[key].isoformat()
    payload = json.dumps(profile_dict, indent=2)

    # NOTE(review): user_id goes into the path unchecked - confirm callers
    # never pass client-supplied ids containing path separators.
    file_path = os.path.join(DATA_DIR, "users", f"{profile.user_id}.json")
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(payload)

    return profile
849
-
850
def get_user_profile(user_id: str) -> Optional[UserProfile]:
    """Load the stored profile for ``user_id``.

    Returns:
        The profile, or None when no file exists for this id.
    """
    file_path = os.path.join(DATA_DIR, "users", f"{user_id}.json")
    # Open directly rather than checking existence first, so a file removed
    # between the check and the open cannot raise.
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return None

    # Convert string dates back to datetime; tolerate records without the key
    for key in ["created_at", "updated_at"]:
        if data.get(key):
            data[key] = datetime.fromisoformat(data[key])
    return UserProfile(**data)
862
-
863
def save_session(session: SessionData):
    """Persist a session as ``<DATA_DIR>/sessions/<session_id>.json``.

    Assigns a fresh UUID when ``session_id`` is empty, sets ``created_at`` on
    the first save and always refreshes ``updated_at``. The session is mutated
    in place and returned.
    """
    now = datetime.now()
    if not session.session_id:
        session.session_id = str(uuid4())
    if not session.created_at:
        session.created_at = now
    session.updated_at = now

    # Serialize before opening the file: mode "w" truncates immediately, so a
    # serialization error after that point would wipe the stored session.
    session_dict = session.dict()
    for key in ["created_at", "updated_at"]:
        if session_dict[key]:
            # Convert datetime to string for JSON serialization
            session_dict[key] = session_dict[key].isoformat()
    payload = json.dumps(session_dict, indent=2)

    file_path = os.path.join(DATA_DIR, "sessions", f"{session.session_id}.json")
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(payload)

    return session
882
-
883
def get_session(session_id: str) -> Optional[SessionData]:
    """Load the stored session for ``session_id``.

    Returns:
        The session, or None when no file exists for this id.
    """
    file_path = os.path.join(DATA_DIR, "sessions", f"{session_id}.json")
    # Open directly rather than checking existence first, so a file removed
    # between the check and the open cannot raise.
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return None

    # Convert string dates back to datetime; tolerate records without the key
    for key in ["created_at", "updated_at"]:
        if data.get(key):
            data[key] = datetime.fromisoformat(data[key])
    return SessionData(**data)
895
-
896
def get_user_sessions(counselor_id: str) -> List[SessionData]:
    """Return every stored session whose ``counselor_id`` matches.

    Scans all ``*.json`` files under ``<DATA_DIR>/sessions``; the result
    follows the order in which ``os.listdir`` yields them.
    """
    sessions = []
    sessions_dir = os.path.join(DATA_DIR, "sessions")
    for filename in os.listdir(sessions_dir):
        if not filename.endswith(".json"):
            continue

        file_path = os.path.join(sessions_dir, filename)
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # .get() so that one record lacking the key is skipped instead of
        # aborting the whole listing with a KeyError.
        if data.get("counselor_id") != counselor_id:
            continue

        # Convert string dates back to datetime
        for key in ["created_at", "updated_at"]:
            if data.get(key):
                data[key] = datetime.fromisoformat(data[key])
        sessions.append(SessionData(**data))

    return sessions
913
-
914
def save_conversation_entry(entry: ConversationEntry):
    """Store one conversation entry under its session's directory.

    The entry is written to
    ``<DATA_DIR>/conversations/<session_id>/<entry_id>.json`` with a newly
    generated ``entry_id``. A missing timestamp is set to the current time on
    the entry itself.

    Returns:
        The generated entry id.
    """
    conversation_dir = os.path.join(DATA_DIR, "conversations", entry.session_id)
    os.makedirs(conversation_dir, exist_ok=True)

    if not entry.timestamp:
        entry.timestamp = datetime.now()

    entry_id = str(uuid4())

    # Serialize before creating the file: a failure afterwards would leave an
    # empty .json behind, which the history loader cannot parse.
    entry_dict = entry.dict()
    entry_dict["entry_id"] = entry_id
    if entry_dict["timestamp"]:
        # Convert datetime to string for JSON serialization
        entry_dict["timestamp"] = entry_dict["timestamp"].isoformat()
    payload = json.dumps(entry_dict, indent=2)

    file_path = os.path.join(conversation_dir, f"{entry_id}.json")
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(payload)

    return entry_id
933
-
934
def get_conversation_history(session_id: str) -> List[ConversationEntry]:
    """Load every stored entry of a session, oldest first.

    Returns an empty list when the session has no conversation directory.
    """
    conversation_dir = os.path.join(DATA_DIR, "conversations", session_id)
    if not os.path.exists(conversation_dir):
        return []

    history = []
    json_names = [name for name in os.listdir(conversation_dir) if name.endswith(".json")]
    for name in json_names:
        with open(os.path.join(conversation_dir, name), "r") as f:
            record = json.loads(f.read())
            # Timestamps are stored as ISO strings; restore the datetime
            if record["timestamp"]:
                record["timestamp"] = datetime.fromisoformat(record["timestamp"])
            history.append(ConversationEntry(**record))

    # Sort by timestamp
    return sorted(history, key=lambda item: item.timestamp)
954
-
955
def save_feedback(feedback: FeedbackData):
    """Store a feedback record as ``<DATA_DIR>/feedback/<feedback_id>.json``.

    The record is the feedback's fields plus a generated ``feedback_id`` and
    the current time as an ISO ``timestamp``.

    Returns:
        The generated feedback id.
    """
    feedback_id = str(uuid4())

    # Serialize before creating the file so a serialization failure does not
    # leave an empty file behind.
    feedback_dict = feedback.dict()
    feedback_dict["feedback_id"] = feedback_id
    feedback_dict["timestamp"] = datetime.now().isoformat()
    payload = json.dumps(feedback_dict, indent=2)

    file_path = os.path.join(DATA_DIR, "feedback", f"{feedback_id}.json")
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(payload)

    return feedback_id
966
-
967
- # Multi-modal Input Support
968
def _parse_llm_json(raw):
    """Decode the JSON document in an LLM reply, with or without a ``` fence."""
    text = raw.strip()
    if text.startswith("```"):
        # Drop the fence markers, then the optional "json" language tag.
        text = text.strip("`").strip()
        if text.startswith("json"):
            text = text[len("json"):]
    return json.loads(text)


def _summarize_prediction(prediction):
    """Pick the fields of a response-type prediction that are returned to clients."""
    return {
        "response_type": prediction["response_type"],
        "crisis_flag": prediction["crisis_flag"],
        "confidence": prediction["confidence"]
    }


@app.post("/input/process")
async def process_multimodal_input(input_data: MultiModalInput):
    """Process multi-modal input (text, audio, video).

    Text is classified directly. For audio and video the LLM is asked to
    simulate a transcription plus emotion/non-verbal analysis from the
    supplied description, and the simulated transcription is then classified.

    Raises:
        HTTPException: 400 for an unsupported ``input_type``; 500 for any
            other failure.
    """
    try:
        if input_data.input_type not in ["text", "audio", "video"]:
            raise HTTPException(status_code=400, detail=f"Unsupported input type: {input_data.input_type}")

        # For now, handle text directly and simulate processing for audio/video
        if input_data.input_type == "text":
            # Process text normally
            prediction = predict_response_type(input_data.content)

            return {
                "input_type": "text",
                "processed_content": input_data.content,
                "analysis": _summarize_prediction(prediction)
            }
        elif input_data.input_type == "audio":
            # Simulate audio transcription and emotion detection
            # In a production system, this would use a speech-to-text API and emotion analysis
            prompt_template = ChatPromptTemplate.from_template(
                """
                Simulate audio processing for this description: {content}

                Generate a simulated transcription and emotion detection as if this were real audio.

                Output in the following format:
                ```json
                {{
                    "transcription": "Simulated transcription of the audio",
                    "emotion_detected": "primary emotion",
                    "secondary_emotions": ["emotion1", "emotion2"],
                    "confidence": 0.85
                }}
                ```
                """
            )

            process_chain = prompt_template | llm
            response = process_chain.invoke({"content": input_data.content})
            # json.loads instead of eval: model output must never be executed
            audio_result = _parse_llm_json(response.content)

            # Now process the transcription
            prediction = predict_response_type(audio_result["transcription"])

            return {
                "input_type": "audio",
                "processed_content": audio_result["transcription"],
                "emotion_analysis": {
                    "primary_emotion": audio_result["emotion_detected"],
                    "secondary_emotions": audio_result["secondary_emotions"],
                    "confidence": audio_result["confidence"]
                },
                "analysis": _summarize_prediction(prediction)
            }
        elif input_data.input_type == "video":
            # Simulate video analysis
            # In a production system, this would use video analytics API
            prompt_template = ChatPromptTemplate.from_template(
                """
                Simulate video processing for this description: {content}

                Generate a simulated analysis as if this were real video with facial expressions and body language.

                Output in the following format:
                ```json
                {{
                    "transcription": "Simulated transcription of speech in the video",
                    "facial_expressions": ["expression1", "expression2"],
                    "body_language": ["posture observation", "gesture observation"],
                    "primary_emotion": "primary emotion",
                    "confidence": 0.80
                }}
                ```
                """
            )

            process_chain = prompt_template | llm
            response = process_chain.invoke({"content": input_data.content})
            video_result = _parse_llm_json(response.content)

            # Process the transcription
            prediction = predict_response_type(video_result["transcription"])

            return {
                "input_type": "video",
                "processed_content": video_result["transcription"],
                "nonverbal_analysis": {
                    "facial_expressions": video_result["facial_expressions"],
                    "body_language": video_result["body_language"],
                    "primary_emotion": video_result["primary_emotion"],
                    "confidence": video_result["confidence"]
                },
                "analysis": _summarize_prediction(prediction)
            }
    except HTTPException:
        # Keep deliberate statuses (the 400 above) instead of turning them into 500s
        raise
    except Exception as e:
        logger.error(f"Error processing multimodal input: {e}")
        raise HTTPException(status_code=500, detail=f"Error processing multimodal input: {str(e)}")
1078
 
1079
- # Therapeutic Technique Suggestions
1080
@traceable(run_type="chain", name="Therapeutic Techniques")
def suggest_therapeutic_techniques(context: str, technique_type: Optional[str] = None):
    """Suggest specific therapeutic techniques based on the patient context.

    Args:
        context: Free-text description of the patient's situation.
        technique_type: Optional approach to restrict suggestions to; when
            empty the prompt is told "Any appropriate".

    Returns:
        The parsed JSON document produced by the LLM.

    Raises:
        HTTPException: 500 when the LLM call fails or its reply is not valid JSON.
    """
    prompt_template = ChatPromptTemplate.from_template(
        """
        You are an expert mental health professional with extensive knowledge of therapeutic techniques. Based on the following patient context, suggest therapeutic techniques that would be appropriate:

        Patient Context: {context}

        Technique Type (if specified): {technique_type}

        Suggest specific therapeutic techniques, exercises, or interventions that would be helpful for this patient. Include:
        1. Name of technique
        2. Brief description
        3. How to apply it in this specific case
        4. Expected benefits

        Provide a range of options from different therapeutic approaches (CBT, DBT, ACT, mindfulness, motivational interviewing, etc.) unless a specific technique type was requested.

        Output in the following format:
        ```json
        {{
            "primary_approach": "The most appropriate therapeutic approach",
            "techniques": [
                {{
                    "name": "Technique name",
                    "approach": "CBT/DBT/ACT/etc.",
                    "description": "Brief description",
                    "application": "How to apply to this specific case",
                    "benefits": "Expected benefits"
                }}
            ],
            "rationale": "Brief explanation of why these techniques were selected"
        }}
        ```
        """
    )

    technique_chain = (
        {
            "context": RunnablePassthrough(),
            "technique_type": lambda x: technique_type if technique_type else "Any appropriate"
        }
        | prompt_template
        | llm
    )

    try:
        response = technique_chain.invoke(context)
        payload = response.content.strip()
        if payload.startswith("```"):
            # Drop the fence markers, then the optional "json" language tag.
            payload = payload.strip("`").strip()
            if payload.startswith("json"):
                payload = payload[len("json"):]
        # json.loads instead of eval: model output must never be executed,
        # and JSON literals such as true/false/null are not valid Python.
        return json.loads(payload)
    except Exception as e:
        logger.error(f"Error suggesting therapeutic techniques: {e}")
        raise HTTPException(status_code=500, detail=f"Error suggesting therapeutic techniques: {str(e)}")
1133
-
1134
@app.post("/techniques/suggest")
async def get_therapeutic_techniques(request: dict):
    """Get suggested therapeutic techniques for a patient context.

    Expects a JSON body with a required ``context`` and an optional
    ``technique_type``.

    Raises:
        HTTPException: 400 when ``context`` is missing or empty; 500 for any
            other failure.
    """
    try:
        context = request.get("context")
        if not context:
            raise HTTPException(status_code=400, detail="context is required")

        technique_type = request.get("technique_type")

        techniques = suggest_therapeutic_techniques(context, technique_type)

        return {
            "context": context,
            "techniques": techniques
        }
    except HTTPException:
        # Without this the 400 above was caught below and reported as a 500
        raise
    except Exception as e:
        logger.error(f"Error getting therapeutic techniques: {e}")
        raise HTTPException(status_code=500, detail=f"Error getting therapeutic techniques: {str(e)}")
1153
-
1154
- # Ethical AI Guardrails - Confidence Indicator
1155
@app.post("/suggest/with_confidence")
async def get_suggestion_with_confidence(context: PatientContext):
    """Get suggestion with detailed confidence indicators and uncertainty flags.

    Classifies the patient context, buckets the classifier confidence into
    High/Medium/Low, asks the LLM for a bias analysis of the prediction, and
    returns both the RAG and the direct suggestion with attribution metadata.

    Raises:
        HTTPException: 500 when any step fails.
    """
    try:
        # Get standard prediction
        prediction = predict_response_type(context.context)

        # Set confidence thresholds
        high_confidence = 0.8
        medium_confidence = 0.6

        # Determine confidence level
        confidence_value = prediction["confidence"]
        if confidence_value >= high_confidence:
            confidence_level = "High"
        elif confidence_value >= medium_confidence:
            confidence_level = "Medium"
        else:
            confidence_level = "Low"

        # Analyze for potential biases
        bias_prompt = ChatPromptTemplate.from_template(
            """
            You are an AI ethics expert. Analyze the following patient context and proposed response type for potential biases:

            Patient Context: {context}
            Predicted Response Type: {response_type}

            Identify any potential biases in interpretation or response. Consider gender, cultural, socioeconomic, and other potential biases.

            Output in the following format:
            ```json
            {{
                "bias_detected": true/false,
                "bias_types": ["bias type 1", "bias type 2"],
                "explanation": "Brief explanation of potential biases"
            }}
            ```
            """
        )

        bias_chain = (
            {
                "context": lambda x: context.context,
                "response_type": lambda x: prediction["response_type"]
            }
            | bias_prompt
            | llm
        )

        bias_payload = bias_chain.invoke({}).content.strip()
        if bias_payload.startswith("```"):
            # Drop the fence markers, then the optional "json" language tag.
            bias_payload = bias_payload.strip("`").strip()
            if bias_payload.startswith("json"):
                bias_payload = bias_payload[len("json"):]
        # json.loads instead of eval: the prompt asks for JSON true/false,
        # which eval rejects with NameError, and model output must never be
        # executed as code.
        bias_analysis = json.loads(bias_payload)

        # Generate suggestions
        suggestion_rag = generate_suggestion_rag(context.context, prediction["response_type"], prediction["crisis_flag"])
        suggestion_direct = generate_suggestion_direct(context.context, prediction["response_type"], prediction["crisis_flag"])

        return {
            "context": context.context,
            "response_type": prediction["response_type"],
            "crisis_flag": prediction["crisis_flag"],
            "confidence": {
                "value": prediction["confidence"],
                "level": confidence_level,
                "uncertainty_flag": confidence_level == "Low"
            },
            "bias_analysis": bias_analysis,
            "rag_suggestion": suggestion_rag["suggested_response"],
            "rag_risk_level": suggestion_rag["risk_level"],
            "direct_suggestion": suggestion_direct["suggested_response"],
            "direct_risk_level": suggestion_direct["risk_level"],
            "attribution": {
                "ai_generated": True,
                "model_version": "Mental Health Counselor API v2.0",
                "human_reviewed": False
            }
        }
    except HTTPException:
        # Let errors already shaped by the helpers through unchanged
        raise
    except Exception as e:
        logger.error(f"Error getting suggestion with confidence: {e}")
        raise HTTPException(status_code=500, detail=f"Error getting suggestion with confidence: {str(e)}")
1234
-
1235
- # Text to Speech with Eleven Labs API
1236
@app.post("/api/text-to-speech")
async def text_to_speech(request: dict):
    """Convert text to speech using Eleven Labs API.

    Expects a JSON body with a required ``text`` and an optional ``voice_id``.
    The audio returned by Eleven Labs is streamed back as ``audio/mpeg``.

    Raises:
        HTTPException: 400 when ``text`` is missing; 500 when the API key is
            not configured or an unexpected error occurs; the upstream status
            code when Eleven Labs answers with anything other than 200.
    """
    try:
        text = request.get("text")
        voice_id = request.get("voice_id", "pNInz6obpgDQGcFmaJgB")  # Default to "Adam" voice

        if not text:
            raise HTTPException(status_code=400, detail="Text is required")

        # Get API key from environment
        api_key = os.getenv("ELEVEN_API_KEY")
        if not api_key:
            raise HTTPException(status_code=500, detail="Eleven Labs API key not configured")

        # Prepare the request to Eleven Labs
        url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"

        headers = {
            "Accept": "audio/mpeg",
            "Content-Type": "application/json",
            "xi-api-key": api_key
        }

        payload = {
            "text": text,
            "model_id": "eleven_multilingual_v2",
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": 0.75
            }
        }

        # requests waits indefinitely unless given a timeout, so a stalled
        # upstream would otherwise hang this request forever.
        # NOTE(review): requests.post is a blocking call inside an async
        # handler - consider moving it to a worker thread.
        response = requests.post(url, json=payload, headers=headers, timeout=30)

        if response.status_code != 200:
            logger.error(f"Eleven Labs API error: {response.text}")
            raise HTTPException(status_code=response.status_code,
                                detail=f"Eleven Labs API error: {response.text}")

        # Return audio as streaming response
        return StreamingResponse(
            BytesIO(response.content),
            media_type="audio/mpeg"
        )

    except HTTPException as e:
        # Deliberate statuses are logged and passed through unchanged
        logger.error(f"Error in text-to-speech: {str(e)}")
        raise
    except Exception as e:
        logger.error(f"Error in text-to-speech: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Text-to-speech error: {str(e)}")
1288
-
1289
- # Multimedia file processing (speech to text)
1290
@app.post("/api/input/process")
async def process_audio_input(
    audio: UploadFile = File(...),
    session_id: str = Form(...)
):
    """Process audio input for speech-to-text using Eleven Labs.

    Uploads the received audio to the Eleven Labs speech-to-text endpoint and
    returns the transcribed text together with the caller's ``session_id``.

    Raises:
        HTTPException: 500 when the API key is not configured or an unexpected
            error occurs; the upstream status code when Eleven Labs answers
            with anything other than 200.
    """
    try:
        # Get API key from environment
        api_key = os.getenv("ELEVEN_API_KEY")
        if not api_key:
            raise HTTPException(status_code=500, detail="Eleven Labs API key not configured")

        # Read the audio file content
        audio_content = await audio.read()

        # Call Eleven Labs Speech-to-Text API
        url = "https://api.elevenlabs.io/v1/speech-to-text"

        headers = {
            "xi-api-key": api_key
        }

        # NOTE(review): verify the multipart field name ('audio') and the
        # model id ('whisper-1') against the current Eleven Labs
        # speech-to-text reference; they are kept exactly as they were.
        files = {
            'audio': ('audio.webm', audio_content, 'audio/webm')
        }

        data = {
            'model_id': 'whisper-1'  # Using Whisper model
        }

        # requests waits indefinitely unless given a timeout, so a stalled
        # upstream would otherwise hang this request forever.
        response = requests.post(url, headers=headers, files=files, data=data, timeout=60)

        if response.status_code != 200:
            logger.error(f"Eleven Labs API error: {response.text}")
            raise HTTPException(status_code=response.status_code,
                                detail=f"Eleven Labs API error: {response.text}")

        result = response.json()

        # Extract the transcribed text
        text = result.get('text', '')

        # Return the transcribed text
        return {
            "text": text,
            "session_id": session_id
        }

    except HTTPException as e:
        # Deliberate statuses are logged and passed through unchanged
        logger.error(f"Error processing audio: {str(e)}")
        raise
    except Exception as e:
        logger.error(f"Error processing audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Audio processing error: {str(e)}")
1345
-
1346
- # Add a custom encoder for bytes objects to prevent UTF-8 decode errors
1347
def custom_encoder(obj):
    """Encode a bytes-like object as text for JSON output.

    Valid UTF-8 is returned as the decoded string; any other byte sequence is
    returned as base64 text so that encoding never fails on binary data.
    ``bytearray`` and ``memoryview`` are accepted as well as ``bytes``.

    Raises:
        TypeError: if ``obj`` is not bytes-like.
    """
    if isinstance(obj, (bytes, bytearray, memoryview)):
        raw = bytes(obj)
        try:
            return raw.decode('utf-8')
        except UnicodeDecodeError:
            return base64.b64encode(raw).decode('ascii')
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
1354
-
1355
# Wrap FastAPI's jsonable_encoder so that undecodable bytes are handled
import fastapi.encoders
from fastapi.encoders import jsonable_encoder as original_jsonable_encoder


def safe_jsonable_encoder(*args, **kwargs):
    """Call the stock encoder, falling back to ``custom_encoder`` for bytes.

    The fallback applies only when the stock encoder raises
    UnicodeDecodeError and the first positional argument is ``bytes``; in
    every other case the error is re-raised.
    """
    try:
        encoded = original_jsonable_encoder(*args, **kwargs)
    except UnicodeDecodeError:
        first_is_bytes = bool(args) and isinstance(args[0], bytes)
        if not first_is_bytes:
            raise
        return custom_encoder(args[0])
    return encoded


# Monkey patch the jsonable_encoder in FastAPI
# NOTE(review): this only rebinds the name inside fastapi.encoders; modules
# that imported jsonable_encoder by name beforehand presumably keep their own
# reference to the original - confirm the patch reaches the intended callers.
fastapi.encoders.jsonable_encoder = safe_jsonable_encoder
 
1
+ from fastapi import FastAPI, HTTPException
 
2
  from pydantic import BaseModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import os
4
  from dotenv import load_dotenv
 
 
 
5
  import logging
 
 
 
 
 
 
 
 
6
 
7
  # Set up logging
8
  logging.basicConfig(level=logging.INFO)
 
11
  # Load environment variables
12
  load_dotenv()
13
 
 
 
 
 
 
14
  # Initialize FastAPI app
15
  app = FastAPI(title="Mental Health Counselor API")
16
 
17
+ # Initialize global storage
18
  DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
19
  os.makedirs(DATA_DIR, exist_ok=True)
20
  os.makedirs(os.path.join(DATA_DIR, "users"), exist_ok=True)
 
22
  os.makedirs(os.path.join(DATA_DIR, "conversations"), exist_ok=True)
23
  os.makedirs(os.path.join(DATA_DIR, "feedback"), exist_ok=True)
24
 
25
# Landing route: confirms the service is up and identifies the deployment
@app.get("/")
async def root():
    """Return a static status payload describing this API instance."""
    service_info = {
        "status": "ok",
        "message": "Mental Health Counselor API is running",
        "api_version": "1.0.0",
        "backend_info": "FastAPI on Hugging Face Spaces",
    }
    return service_info
34
 
35
+ # Health check endpoint
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
@app.get("/health")
async def health_check():
    """Liveness probe: always answers with a static "ok" payload."""
    liveness = {
        "status": "ok",
        "message": "Mental Health Counselor API is running",
    }
    return liveness
 
 
 
39
 
40
+ # Metadata endpoint
41
@app.get("/metadata")
async def get_metadata():
    """Describe this deployment and list the API routes it currently serves.

    The endpoint list is read from the application's route table at request
    time, so it also covers routes registered after this handler was defined.
    """
    # A hand-written list drifts from reality (it already lacked
    # "/full-api-status"). Routes excluded from the OpenAPI schema, such as
    # the interactive docs pages, are left out; dict.fromkeys removes
    # duplicate paths while keeping registration order.
    endpoints = list(dict.fromkeys(
        route.path
        for route in app.routes
        if getattr(route, "include_in_schema", False)
    ))
    return {
        "api_version": "1.0.0",
        "endpoints": endpoints,
        "provider": "Mental Health Counselor API on Hugging Face Spaces",
        "deployment_type": "Hugging Face Spaces Docker",
        "description": "This API provides functionality for a mental health counseling application.",
        "frontend": "Deployed separately on Vercel"
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ # Try to import the full API if available
57
try:
    import api_mental_health
except ImportError as e:
    logger.warning(f"Could not import full API module: {e}")
    # Python unbinds `e` when this except block ends, so the handler below
    # must not read it lazily (that raises NameError on every request).
    # Capture the message while it is still bound.
    _full_api_import_error = str(e)

    @app.get("/full-api-status")
    async def full_api_status():
        """Report that the full API module is unavailable, with the import error."""
        return {
            "status": "unavailable",
            "message": "Full API module could not be imported",
            "error": _full_api_import_error
        }
else:
    # If the import succeeds, mount all routes from the API
    logger.info("Successfully imported full API module")

    # Include all routes from the API. include_router expects an APIRouter;
    # api_mental_health.app is a FastAPI application, whose router is exposed
    # as `.router`.
    app.include_router(api_mental_health.app.router, prefix="")

    # Add a status endpoint for the full API
    @app.get("/full-api-status")
    async def full_api_status():
        """Report that the full API module is mounted, with its route count."""
        return {
            "status": "active",
            "message": "Full API module is active and all routes are mounted",
            "routes_count": len(api_mental_health.app.routes)
        }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,38 +1,37 @@
1
  # Core dependencies for data processing and ML
2
- pandas>=1.3.0,<2.1.0
3
- numpy>=1.20.0,<1.25.0
4
- scikit-learn>=1.0.0,<1.3.0
5
- joblib>=1.0.0,<1.4.0
6
 
7
  # NLP and sentiment analysis
8
- nltk>=3.6.0,<3.9.0
9
- vaderSentiment>=3.3.0,<3.4.0
10
 
11
  # Dataset downloading
12
- kagglehub>=0.3.0
13
 
14
  # Vector database and embeddings
15
- # Using a compatible version of chromadb
16
- chromadb>=0.4.18,<0.5.0
17
- openai>=0.27.0
18
- langchain>=0.0.267
19
- langchain-openai>=0.0.1
20
- langchain-chroma>=0.0.1
21
 
22
  # API and tracing
23
- fastapi>=0.95.0,<0.96.0
24
- uvicorn[standard]>=0.22.0,<0.23.0
25
- pydantic>=1.10.0,<1.11.0
26
- langsmith>=0.0.52
27
- python-dotenv>=1.0.0
28
- httpx>=0.24.0,<0.25.0
29
- lightgbm>=4.0.0
30
 
31
  # New dependencies for additional features
32
- python-multipart>=0.0.6
33
  fastapi-cors # For CORS support
34
- aiofiles>=23.1.0
35
- jinja2>=3.1.2
36
  python-jose[cryptography] # For JWT tokens (authentication)
37
  passlib[bcrypt] # For password hashing
38
  pydub # For audio processing
 
1
  # Core dependencies for data processing and ML
2
+ pandas
3
+ numpy
4
+ scikit-learn # Used for ML models
5
+ joblib
6
 
7
  # NLP and sentiment analysis
8
+ nltk
9
+ vaderSentiment
10
 
11
  # Dataset downloading
12
+ kagglehub
13
 
14
  # Vector database and embeddings
15
+ chromadb
16
+ openai
17
+ langchain
18
+ langchain-openai
19
+ langchain-chroma
20
+ httpx # Required for API calls
21
 
22
  # API and tracing
23
+ fastapi
24
+ uvicorn
25
+ pydantic
26
+ langsmith
27
+ python-dotenv
28
+ lightgbm
 
29
 
30
  # New dependencies for additional features
31
+ python-multipart # For file uploads
32
  fastapi-cors # For CORS support
33
+ aiofiles # For async file operations
34
+ jinja2 # For template rendering
35
  python-jose[cryptography] # For JWT tokens (authentication)
36
  passlib[bcrypt] # For password hashing
37
  pydub # For audio processing