Commit 1ae4864 · Said Lfagrouche committed · Parent(s): 727b222

Fix API: Improve error handling and add fallbacks for NLTK and model components

api_mental_health.py CHANGED (+242 -118)

@@ -261,46 +261,94 @@ def clean_text(text):
         return ""
     text = str(text).lower()
     text = re.sub(r"[^a-zA-Z']", " ", text)
+
+    # Simple tokenization by splitting on whitespace instead of using word_tokenize
+    # This avoids the dependency on punkt_tab
+    tokens = text.split()
+
+    # Filter out stopwords and short tokens
     tokens = [lemmatizer.lemmatize(tok) for tok in tokens if tok not in STOPWORDS and len(tok) > 2]
     return " ".join(tokens)

 # Feature engineering function
 @traceable(run_type="tool", name="Engineer Features")
 def engineer_features(context, response=""):
+    try:
+        context_clean = clean_text(context)
+        context_len = len(context_clean.split())
+        context_vader = analyzer.polarity_scores(context)['compound']
+        context_questions = context.count('?')
+        crisis_keywords = ['suicide', 'hopeless', 'worthless', 'kill', 'harm', 'desperate', 'overwhelmed', 'alone']
+        context_crisis_score = sum(1 for word in crisis_keywords if word in context.lower())
+
+        # Check if vectorizer is properly initialized
+        if vectorizer is None or not hasattr(vectorizer, 'transform'):
+            logger.warning("Vectorizer not properly initialized, using placeholder")
+            # Create a simple placeholder for features
+            features = pd.DataFrame({
+                "context_len": [context_len],
+                "context_vader": [context_vader],
+                "context_questions": [context_questions],
+                "crisis_flag": [1 if context_crisis_score > 0 else 0]
+            })
+            feature_cols = ["context_len", "context_vader", "context_questions", "crisis_flag"]
+            return features, feature_cols
+
+        # Use vectorizer if available
+        context_tfidf = vectorizer.transform([context_clean]).toarray()
+        tfidf_cols = [f"tfidf_context_{i}" for i in range(context_tfidf.shape[1])]
+        response_tfidf = np.zeros_like(context_tfidf)
+
+        # Check if LDA model is properly initialized
+        if lda is None or not hasattr(lda, 'transform'):
+            logger.warning("LDA model not properly initialized, using zeros")
+            lda_topics = np.zeros((1, 10))
+        else:
+            lda_topics = lda.transform(context_tfidf)
+
+        feature_cols = ["context_len", "context_vader", "context_questions", "crisis_flag"] + \
+                       [f"topic_{i}" for i in range(10)] + tfidf_cols + \
+                       [f"tfidf_response_{i}" for i in range(response_tfidf.shape[1])]
+
+        features = pd.DataFrame({
+            "context_len": [context_len],
+            "context_vader": [context_vader],
+            "context_questions": [context_questions],
+            **{f"topic_{i}": [lda_topics[0][i]] for i in range(10)},
+            **{f"tfidf_context_{i}": [context_tfidf[0][i]] for i in range(context_tfidf.shape[1])},
+            **{f"tfidf_response_{i}": [response_tfidf[0][i]] for i in range(response_tfidf.shape[1])},
+        })
+
+        # Check if crisis classifier is properly initialized
+        if crisis_clf is None or not hasattr(crisis_clf, 'predict'):
+            logger.warning("Crisis classifier not properly initialized, using keyword detection")
+            crisis_flag = 1 if context_crisis_score > 0 else 0
+        else:
+            crisis_features = features[["context_len", "context_vader", "context_questions"] + [f"topic_{i}" for i in range(10)]]
+            crisis_flag = crisis_clf.predict(crisis_features)[0]
+            if context_crisis_score > 0:
+                crisis_flag = 1
+
+        features["crisis_flag"] = crisis_flag
+
+        return features, feature_cols
+
+    except Exception as e:
+        # Fallback to very basic features if anything goes wrong
+        logger.error(f"Error in engineer_features: {e}")
+        context_len = len(context.split())
+        context_questions = context.count('?')
+        crisis_keywords = ['suicide', 'hopeless', 'worthless', 'kill', 'harm', 'desperate', 'overwhelmed', 'alone']
+        context_crisis_score = sum(1 for word in crisis_keywords if word in context.lower())
+
+        features = pd.DataFrame({
+            "context_len": [context_len],
+            "context_vader": [0.0],  # Default neutral sentiment
+            "context_questions": [context_questions],
+            "crisis_flag": [1 if context_crisis_score > 0 else 0]
+        })
+        feature_cols = ["context_len", "context_vader", "context_questions", "crisis_flag"]
+        return features, feature_cols

 # Prediction function
 @traceable(run_type="chain", name="Predict Response Type")
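
Aside: the punkt-free cleaning path in the hunk above is easy to exercise on its own. A minimal sketch, assuming `STOPWORDS` and `lemmatizer` are the NLTK stopword set and `WordNetLemmatizer` that `api_mental_health.py` initializes elsewhere; the setup, the empty-input guard, and the sample call below are illustrative stand-ins, not part of the commit:

```python
import re

from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

# Assumed stand-ins for the module-level objects in api_mental_health.py
# (requires the nltk 'stopwords' and 'wordnet' corpora, but not 'punkt').
STOPWORDS = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()

def clean_text(text):
    if not text:  # assumption: the guard behind the diff's `return ""` context line
        return ""
    text = str(text).lower()
    text = re.sub(r"[^a-zA-Z']", " ", text)
    # Whitespace split instead of nltk.word_tokenize, so no punkt/punkt_tab download is needed
    tokens = text.split()
    tokens = [lemmatizer.lemmatize(tok) for tok in tokens if tok not in STOPWORDS and len(tok) > 2]
    return " ".join(tokens)

print(clean_text("I feel hopeless and overwhelmed?"))  # e.g. "feel hopeless overwhelmed"
```

The trade-off is that contractions and punctuation-glued tokens are handled more crudely than with `word_tokenize`, which seems acceptable here since the regex has already stripped everything except letters and apostrophes.
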
@@ -341,109 +389,185 @@ def predict_response_type(context):
 # RAG suggestion function
 @traceable(run_type="chain", name="RAG Suggestion")
 def generate_suggestion_rag(context, response_type, crisis_flag):
+    # Check if essential components are available
+    if vector_store is None or llm is None:
+        logger.warning("Vector store or LLM not available for RAG suggestions, using fallback")
+        risk_level = "High" if crisis_flag else "Low"
+
+        # Simple fallback suggestions based on response type
+        if response_type == "Empathetic Listening":
+            suggestion = "I can hear that you're going through a difficult time. It sounds really challenging, and I appreciate you sharing this with me."
+        elif response_type == "Question":
+            suggestion = "Could you tell me more about how this has been affecting your daily life?"
+        elif response_type == "Advice":
+            suggestion = "It might be helpful to consider speaking with a mental health professional who can provide personalized support for what you're experiencing."
+        elif response_type == "Validation":
+            suggestion = "It's completely understandable to feel this way given what you're going through. Your feelings are valid."
+        else:
+            suggestion = "Thank you for sharing that with me. Let's explore this further together."
+
+        # Add crisis resources if needed
+        if crisis_flag:
+            suggestion += " If you're in crisis, please remember help is available 24/7 through the National Suicide Prevention Lifeline at 988."
+
+        return {
+            "suggested_response": suggestion,
+            "risk_level": risk_level
+        }

+    # If vector store is available, proceed with RAG
+    try:
+        results = vector_store.similarity_search_with_score(context, k=3)
+        retrieved_contexts = [
+            f"Patient: {res[0].page_content}\nCounselor: {res[0].metadata['response']} (Type: {res[0].metadata['response_type']}, Crisis: {res[0].metadata['crisis_flag']}, Score: {res[1]:.2f})"
+            for res in results
+        ]

+        prompt_template = ChatPromptTemplate.from_template(
+            """
+            You are an expert mental health counseling assistant. A counselor has provided the following patient situation:
+
+            Patient Situation: {context}
+
+            Predicted Response Type: {response_type}
+            Crisis Flag: {crisis_flag}
+
+            Based on the predicted response type and crisis flag, provide a suggested response for the counselor to use with the patient. The response should align with the response type ({response_type}) and be sensitive to the crisis level.
+
+            For reference, here are similar cases from past conversations:
+            {retrieved_contexts}
+
+            Guidelines:
+            - If Crisis Flag is True, prioritize safety, empathy, and suggest immediate resources (e.g., National Suicide Prevention Lifeline at 988).
+            - For 'Empathetic Listening', focus on validating feelings without giving direct advice or questions.
+            - For 'Advice', provide practical, actionable suggestions.
+            - For 'Question', pose an open-ended question to encourage further discussion.
+            - For 'Validation', affirm the patient's efforts or feelings.
+
+            Output in the following format:
+            ```json
+            {{
+                "suggested_response": "Your suggested response here",
+                "risk_level": "Low/Moderate/High"
+            }}
+            ```
+            """
+        )

+        rag_chain = (
+            {
+                "context": RunnablePassthrough(),
+                "response_type": lambda x: response_type,
+                "crisis_flag": lambda x: "Crisis" if crisis_flag else "No Crisis",
+                "retrieved_contexts": lambda x: "\n".join(retrieved_contexts)
+            }
+            | prompt_template
+            | llm
+        )

         response = rag_chain.invoke(context)
         return eval(response.content.strip("```json\n").strip("\n```"))
     except Exception as e:
         logger.error(f"Error generating RAG suggestion: {e}")
+        risk_level = "High" if crisis_flag else "Low"
+
+        # Fallback suggestion if RAG fails
+        if crisis_flag:
+            suggestion = "I'm hearing that you're going through a very difficult time. Your safety is the most important thing right now. Would it be helpful to talk about resources that are available to support you, like the National Suicide Prevention Lifeline at 988?"
+        else:
+            suggestion = "Thank you for sharing that with me. I want to understand more about your experience and how I can best support you right now."
+
+        return {
+            "suggested_response": suggestion,
+            "risk_level": risk_level
+        }

 # Direct suggestion function
 @traceable(run_type="chain", name="Direct Suggestion")
 def generate_suggestion_direct(context, response_type, crisis_flag):
+    # Check if essential components are available
+    if llm is None:
+        logger.warning("LLM not available for direct suggestions, using fallback")
+        risk_level = "High" if crisis_flag else "Low"
+
+        # Simple fallback suggestions based on response type
+        if response_type == "Empathetic Listening":
+            suggestion = "It sounds like this has been really difficult for you. I'm here to listen and support you."
+        elif response_type == "Question":
+            suggestion = "How have you been coping with these feelings recently?"
+        elif response_type == "Advice":
+            suggestion = "One thing that might help is establishing a simple morning routine with small, achievable steps."
+        elif response_type == "Validation":
+            suggestion = "What you're experiencing is a normal response to a difficult situation. Your feelings are valid."
+        else:
+            suggestion = "I appreciate you sharing this with me. Let's work through this together."
+
+        # Add crisis resources if needed
+        if crisis_flag:
+            suggestion += " Given what you've shared, I want to make sure you know about resources like the National Suicide Prevention Lifeline at 988, which is available 24/7."
+
+        return {
+            "suggested_response": suggestion,
+            "risk_level": risk_level
+        }

+    # If LLM is available, proceed with direct suggestion
     try:
+        prompt_template = ChatPromptTemplate.from_template(
+            """
+            You are an expert mental health counseling assistant. A counselor has provided the following patient situation:
+
+            Patient Situation: {context}
+
+            Predicted Response Type: {response_type}
+            Crisis Flag: {crisis_flag}
+
+            Provide a suggested response for the counselor to use with the patient, aligned with the response type ({response_type}) and sensitive to the crisis level.
+
+            Guidelines:
+            - If Crisis Flag is True, prioritize safety, empathy, and suggest immediate resources (e.g., National Suicide Prevention Lifeline at 988).
+            - For 'Empathetic Listening', focus on validating feelings without giving direct advice or questions.
+            - For 'Advice', provide practical, actionable suggestions.
+            - For 'Question', pose an open-ended question to encourage further discussion.
+            - For 'Validation', affirm the patient's efforts or feelings.
+            - Strictly adhere to the response type. For 'Empathetic Listening', do not include questions or advice.
+
+            Output in the following format:
+            ```json
+            {{
+                "suggested_response": "Your suggested response here",
+                "risk_level": "Low/Moderate/High"
+            }}
+            ```
+            """
+        )
+
+        direct_chain = (
+            {
+                "context": RunnablePassthrough(),
+                "response_type": lambda x: response_type,
+                "crisis_flag": lambda x: "Crisis" if crisis_flag else "No Crisis"
+            }
+            | prompt_template
+            | llm
+        )
+
         response = direct_chain.invoke(context)
         return eval(response.content.strip("```json\n").strip("\n```"))
     except Exception as e:
         logger.error(f"Error generating direct suggestion: {e}")
+        risk_level = "High" if crisis_flag else "Low"
+
+        # Fallback suggestion if direct generation fails
+        if crisis_flag:
+            suggestion = "I'm concerned about what you're sharing. Your wellbeing is important, and I want to make sure you have support. The National Suicide Prevention Lifeline (988) has trained counselors available 24/7."
+        else:
+            suggestion = "I hear you're having a difficult time. Would you like to talk more about how these feelings have been affecting you?"
+
+        return {
+            "suggested_response": suggestion,
+            "risk_level": risk_level
+        }

 # User Profile Endpoints
 @app.post("/users/create", response_model=UserProfile)
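
One caveat for anyone adapting this code: both suggestion functions still parse the model reply with `eval()` after `str.strip()`. Note that `str.strip()` removes any run of the listed characters from each end rather than a literal fence, and `eval()` executes whatever the model returns. A minimal sketch of a stricter parse; the `parse_llm_json` helper is hypothetical, not part of this commit:

```python
import json
import re

def parse_llm_json(raw: str) -> dict:
    """Extract and decode the first JSON object in an LLM reply, tolerating ```json fences."""
    match = re.search(r"\{.*\}", raw, re.DOTALL)
    if not match:
        raise ValueError(f"no JSON object in LLM output: {raw!r}")
    return json.loads(match.group(0))

# Hypothetical drop-in usage inside the try blocks above:
#   return parse_llm_json(response.content)
```

On malformed output, `json.loads` raises instead of executing arbitrary Python, which would route into the existing `except` branches and their fallback suggestions.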