Said Lfagrouche committed on
Commit 1ae4864 · 1 Parent(s): 727b222

Fix API: Improve error handling and add fallbacks for NLTK and model components

Files changed (1)
  1. api_mental_health.py +242 -118
api_mental_health.py CHANGED
@@ -261,46 +261,94 @@ def clean_text(text):
         return ""
     text = str(text).lower()
     text = re.sub(r"[^a-zA-Z']", " ", text)
-    tokens = word_tokenize(text)
     tokens = [lemmatizer.lemmatize(tok) for tok in tokens if tok not in STOPWORDS and len(tok) > 2]
     return " ".join(tokens)
 
 # Feature engineering function
 @traceable(run_type="tool", name="Engineer Features")
 def engineer_features(context, response=""):
-    context_clean = clean_text(context)
-    context_len = len(context_clean.split())
-    context_vader = analyzer.polarity_scores(context)['compound']
-    context_questions = context.count('?')
-    crisis_keywords = ['suicide', 'hopeless', 'worthless', 'kill', 'harm', 'desperate', 'overwhelmed', 'alone']
-    context_crisis_score = sum(1 for word in crisis_keywords if word in context.lower())
-
-    context_tfidf = vectorizer.transform([context_clean]).toarray()
-    tfidf_cols = [f"tfidf_context_{i}" for i in range(context_tfidf.shape[1])]
-    response_tfidf = np.zeros_like(context_tfidf)
-
-    lda_topics = lda.transform(context_tfidf)
-
-    feature_cols = ["context_len", "context_vader", "context_questions", "crisis_flag"] + \
-                   [f"topic_{i}" for i in range(10)] + tfidf_cols + \
-                   [f"tfidf_response_{i}" for i in range(response_tfidf.shape[1])]
-
-    features = pd.DataFrame({
-        "context_len": [context_len],
-        "context_vader": [context_vader],
-        "context_questions": [context_questions],
-        **{f"topic_{i}": [lda_topics[0][i]] for i in range(10)},
-        **{f"tfidf_context_{i}": [context_tfidf[0][i]] for i in range(context_tfidf.shape[1])},
-        **{f"tfidf_response_{i}": [response_tfidf[0][i]] for i in range(response_tfidf.shape[1])},
-    })
-
-    crisis_features = features[["context_len", "context_vader", "context_questions"] + [f"topic_{i}" for i in range(10)]]
-    crisis_flag = crisis_clf.predict(crisis_features)[0]
-    if context_crisis_score > 0:
-        crisis_flag = 1
-    features["crisis_flag"] = crisis_flag
-
-    return features, feature_cols
 
 # Prediction function
 @traceable(run_type="chain", name="Predict Response Type")
@@ -341,109 +389,185 @@ def predict_response_type(context):
 # RAG suggestion function
 @traceable(run_type="chain", name="RAG Suggestion")
 def generate_suggestion_rag(context, response_type, crisis_flag):
-    results = vector_store.similarity_search_with_score(context, k=3)
-    retrieved_contexts = [
-        f"Patient: {res[0].page_content}\nCounselor: {res[0].metadata['response']} (Type: {res[0].metadata['response_type']}, Crisis: {res[0].metadata['crisis_flag']}, Score: {res[1]:.2f})"
-        for res in results
-    ]
 
-    prompt_template = ChatPromptTemplate.from_template(
-        """
-        You are an expert mental health counseling assistant. A counselor has provided the following patient situation:
-
-        Patient Situation: {context}
-
-        Predicted Response Type: {response_type}
-        Crisis Flag: {crisis_flag}
-
-        Based on the predicted response type and crisis flag, provide a suggested response for the counselor to use with the patient. The response should align with the response type ({response_type}) and be sensitive to the crisis level.
 
-        For reference, here are similar cases from past conversations:
-        {retrieved_contexts}
 
-        Guidelines:
-        - If Crisis Flag is True, prioritize safety, empathy, and suggest immediate resources (e.g., National Suicide Prevention Lifeline at 988).
-        - For 'Empathetic Listening', focus on validating feelings without giving direct advice or questions.
-        - For 'Advice', provide practical, actionable suggestions.
-        - For 'Question', pose an open-ended question to encourage further discussion.
-        - For 'Validation', affirm the patient's efforts or feelings.
 
-        Output in the following format:
-        ```json
-        {{
-            "suggested_response": "Your suggested response here",
-            "risk_level": "Low/Moderate/High"
-        }}
-        ```
-        """
-    )
-
-    rag_chain = (
-        {
-            "context": RunnablePassthrough(),
-            "response_type": lambda x: response_type,
-            "crisis_flag": lambda x: "Crisis" if crisis_flag else "No Crisis",
-            "retrieved_contexts": lambda x: "\n".join(retrieved_contexts)
-        }
-        | prompt_template
-        | llm
-    )
-
-    try:
         response = rag_chain.invoke(context)
         return eval(response.content.strip("```json\n").strip("\n```"))
     except Exception as e:
         logger.error(f"Error generating RAG suggestion: {e}")
-        raise HTTPException(status_code=500, detail=f"Error generating RAG suggestion: {str(e)}")
 
 # Direct suggestion function
 @traceable(run_type="chain", name="Direct Suggestion")
 def generate_suggestion_direct(context, response_type, crisis_flag):
-    prompt_template = ChatPromptTemplate.from_template(
-        """
-        You are an expert mental health counseling assistant. A counselor has provided the following patient situation:
-
-        Patient Situation: {context}
-
-        Predicted Response Type: {response_type}
-        Crisis Flag: {crisis_flag}
-
-        Provide a suggested response for the counselor to use with the patient, aligned with the response type ({response_type}) and sensitive to the crisis level.
-
-        Guidelines:
-        - If Crisis Flag is True, prioritize safety, empathy, and suggest immediate resources (e.g., National Suicide Prevention Lifeline at 988).
-        - For 'Empathetic Listening', focus on validating feelings without giving direct advice or questions.
-        - For 'Advice', provide practical, actionable suggestions.
-        - For 'Question', pose an open-ended question to encourage further discussion.
-        - For 'Validation', affirm the patient's efforts or feelings.
-        - Strictly adhere to the response type. For 'Empathetic Listening', do not include questions or advice.
-
-        Output in the following format:
-        ```json
-        {{
-            "suggested_response": "Your suggested response here",
-            "risk_level": "Low/Moderate/High"
-        }}
-        ```
-        """
-    )
-
-    direct_chain = (
-        {
-            "context": RunnablePassthrough(),
-            "response_type": lambda x: response_type,
-            "crisis_flag": lambda x: "Crisis" if crisis_flag else "No Crisis"
         }
-        | prompt_template
-        | llm
-    )
 
     try:
         response = direct_chain.invoke(context)
         return eval(response.content.strip("```json\n").strip("\n```"))
     except Exception as e:
         logger.error(f"Error generating direct suggestion: {e}")
-        raise HTTPException(status_code=500, detail=f"Error generating direct suggestion: {str(e)}")
 
 # User Profile Endpoints
 @app.post("/users/create", response_model=UserProfile)
 
@@ -261,46 +261,94 @@ def clean_text(text):
         return ""
     text = str(text).lower()
     text = re.sub(r"[^a-zA-Z']", " ", text)
+
+    # Simple tokenization by splitting on whitespace instead of using word_tokenize
+    # This avoids the dependency on punkt_tab
+    tokens = text.split()
+
+    # Filter out stopwords and short tokens
     tokens = [lemmatizer.lemmatize(tok) for tok in tokens if tok not in STOPWORDS and len(tok) > 2]
     return " ".join(tokens)
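The new clean_text drops NLTK's word_tokenize entirely in favor of a plain whitespace split. A middle-ground variant (a sketch only, not part of this commit; tokenize_safe is a made-up helper) would keep word_tokenize when the punkt/punkt_tab data is available and degrade to a split otherwise:

```python
# Sketch: prefer NLTK's word_tokenize, fall back to a whitespace split when
# the punkt / punkt_tab resources are missing. "tokenize_safe" is hypothetical.
from nltk.tokenize import word_tokenize


def tokenize_safe(text: str) -> list:
    try:
        return word_tokenize(text)  # requires the punkt / punkt_tab data
    except LookupError:
        return text.split()  # same degradation the commit applies unconditionally
```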
 
 # Feature engineering function
 @traceable(run_type="tool", name="Engineer Features")
 def engineer_features(context, response=""):
+    try:
+        context_clean = clean_text(context)
+        context_len = len(context_clean.split())
+        context_vader = analyzer.polarity_scores(context)['compound']
+        context_questions = context.count('?')
+        crisis_keywords = ['suicide', 'hopeless', 'worthless', 'kill', 'harm', 'desperate', 'overwhelmed', 'alone']
+        context_crisis_score = sum(1 for word in crisis_keywords if word in context.lower())
+
+        # Check if vectorizer is properly initialized
+        if vectorizer is None or not hasattr(vectorizer, 'transform'):
+            logger.warning("Vectorizer not properly initialized, using placeholder")
+            # Create a simple placeholder for features
+            features = pd.DataFrame({
+                "context_len": [context_len],
+                "context_vader": [context_vader],
+                "context_questions": [context_questions],
+                "crisis_flag": [1 if context_crisis_score > 0 else 0]
+            })
+            feature_cols = ["context_len", "context_vader", "context_questions", "crisis_flag"]
+            return features, feature_cols
+
+        # Use vectorizer if available
+        context_tfidf = vectorizer.transform([context_clean]).toarray()
+        tfidf_cols = [f"tfidf_context_{i}" for i in range(context_tfidf.shape[1])]
+        response_tfidf = np.zeros_like(context_tfidf)
+
+        # Check if LDA model is properly initialized
+        if lda is None or not hasattr(lda, 'transform'):
+            logger.warning("LDA model not properly initialized, using zeros")
+            lda_topics = np.zeros((1, 10))
+        else:
+            lda_topics = lda.transform(context_tfidf)
+
+        feature_cols = ["context_len", "context_vader", "context_questions", "crisis_flag"] + \
+                       [f"topic_{i}" for i in range(10)] + tfidf_cols + \
+                       [f"tfidf_response_{i}" for i in range(response_tfidf.shape[1])]
+
+        features = pd.DataFrame({
+            "context_len": [context_len],
+            "context_vader": [context_vader],
+            "context_questions": [context_questions],
+            **{f"topic_{i}": [lda_topics[0][i]] for i in range(10)},
+            **{f"tfidf_context_{i}": [context_tfidf[0][i]] for i in range(context_tfidf.shape[1])},
+            **{f"tfidf_response_{i}": [response_tfidf[0][i]] for i in range(response_tfidf.shape[1])},
+        })
+
+        # Check if crisis classifier is properly initialized
+        if crisis_clf is None or not hasattr(crisis_clf, 'predict'):
+            logger.warning("Crisis classifier not properly initialized, using keyword detection")
+            crisis_flag = 1 if context_crisis_score > 0 else 0
+        else:
+            crisis_features = features[["context_len", "context_vader", "context_questions"] + [f"topic_{i}" for i in range(10)]]
+            crisis_flag = crisis_clf.predict(crisis_features)[0]
+            if context_crisis_score > 0:
+                crisis_flag = 1
+
+        features["crisis_flag"] = crisis_flag
+
+        return features, feature_cols
+
+    except Exception as e:
+        # Fallback to very basic features if anything goes wrong
+        logger.error(f"Error in engineer_features: {e}")
+        context_len = len(context.split())
+        context_questions = context.count('?')
+        crisis_keywords = ['suicide', 'hopeless', 'worthless', 'kill', 'harm', 'desperate', 'overwhelmed', 'alone']
+        context_crisis_score = sum(1 for word in crisis_keywords if word in context.lower())
+
+        features = pd.DataFrame({
+            "context_len": [context_len],
+            "context_vader": [0.0],  # Default neutral sentiment
+            "context_questions": [context_questions],
+            "crisis_flag": [1 if context_crisis_score > 0 else 0]
+        })
+        feature_cols = ["context_len", "context_vader", "context_questions", "crisis_flag"]
+        return features, feature_cols
 
 # Prediction function
 @traceable(run_type="chain", name="Predict Response Type")
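The vectorizer is None, lda is None, and crisis_clf is None checks above only take effect if the module-level loading code leaves those globals as None when a model fails to load. A minimal sketch of that loading pattern, assuming joblib artifacts with illustrative file names (the actual loading code is outside this diff):

```python
# Sketch of the guarded loading the None checks assume; paths are illustrative.
import logging
import joblib

logger = logging.getLogger(__name__)

vectorizer = lda = crisis_clf = None
try:
    vectorizer = joblib.load("models/tfidf_vectorizer.joblib")
    lda = joblib.load("models/lda_model.joblib")
    crisis_clf = joblib.load("models/crisis_classifier.joblib")
except Exception as e:
    # Leave the globals as None so engineer_features degrades to keyword features
    logger.error(f"Failed to load model components: {e}")
```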
 
@@ -341,109 +389,185 @@ def predict_response_type(context):
 # RAG suggestion function
 @traceable(run_type="chain", name="RAG Suggestion")
 def generate_suggestion_rag(context, response_type, crisis_flag):
+    # Check if essential components are available
+    if vector_store is None or llm is None:
+        logger.warning("Vector store or LLM not available for RAG suggestions, using fallback")
+        risk_level = "High" if crisis_flag else "Low"
+
+        # Simple fallback suggestions based on response type
+        if response_type == "Empathetic Listening":
+            suggestion = "I can hear that you're going through a difficult time. It sounds really challenging, and I appreciate you sharing this with me."
+        elif response_type == "Question":
+            suggestion = "Could you tell me more about how this has been affecting your daily life?"
+        elif response_type == "Advice":
+            suggestion = "It might be helpful to consider speaking with a mental health professional who can provide personalized support for what you're experiencing."
+        elif response_type == "Validation":
+            suggestion = "It's completely understandable to feel this way given what you're going through. Your feelings are valid."
+        else:
+            suggestion = "Thank you for sharing that with me. Let's explore this further together."
+
+        # Add crisis resources if needed
+        if crisis_flag:
+            suggestion += " If you're in crisis, please remember help is available 24/7 through the National Suicide Prevention Lifeline at 988."
+
+        return {
+            "suggested_response": suggestion,
+            "risk_level": risk_level
+        }
 
+    # If vector store is available, proceed with RAG
+    try:
+        results = vector_store.similarity_search_with_score(context, k=3)
+        retrieved_contexts = [
+            f"Patient: {res[0].page_content}\nCounselor: {res[0].metadata['response']} (Type: {res[0].metadata['response_type']}, Crisis: {res[0].metadata['crisis_flag']}, Score: {res[1]:.2f})"
+            for res in results
+        ]
 
+        prompt_template = ChatPromptTemplate.from_template(
+            """
+            You are an expert mental health counseling assistant. A counselor has provided the following patient situation:
+
+            Patient Situation: {context}
+
+            Predicted Response Type: {response_type}
+            Crisis Flag: {crisis_flag}
+
+            Based on the predicted response type and crisis flag, provide a suggested response for the counselor to use with the patient. The response should align with the response type ({response_type}) and be sensitive to the crisis level.
+
+            For reference, here are similar cases from past conversations:
+            {retrieved_contexts}
+
+            Guidelines:
+            - If Crisis Flag is True, prioritize safety, empathy, and suggest immediate resources (e.g., National Suicide Prevention Lifeline at 988).
+            - For 'Empathetic Listening', focus on validating feelings without giving direct advice or questions.
+            - For 'Advice', provide practical, actionable suggestions.
+            - For 'Question', pose an open-ended question to encourage further discussion.
+            - For 'Validation', affirm the patient's efforts or feelings.
+
+            Output in the following format:
+            ```json
+            {{
+                "suggested_response": "Your suggested response here",
+                "risk_level": "Low/Moderate/High"
+            }}
+            ```
+            """
+        )
 
+        rag_chain = (
+            {
+                "context": RunnablePassthrough(),
+                "response_type": lambda x: response_type,
+                "crisis_flag": lambda x: "Crisis" if crisis_flag else "No Crisis",
+                "retrieved_contexts": lambda x: "\n".join(retrieved_contexts)
+            }
+            | prompt_template
+            | llm
+        )
 
         response = rag_chain.invoke(context)
         return eval(response.content.strip("```json\n").strip("\n```"))
     except Exception as e:
         logger.error(f"Error generating RAG suggestion: {e}")
+        risk_level = "High" if crisis_flag else "Low"
+
+        # Fallback suggestion if RAG fails
+        if crisis_flag:
+            suggestion = "I'm hearing that you're going through a very difficult time. Your safety is the most important thing right now. Would it be helpful to talk about resources that are available to support you, like the National Suicide Prevention Lifeline at 988?"
+        else:
+            suggestion = "Thank you for sharing that with me. I want to understand more about your experience and how I can best support you right now."
+
+        return {
+            "suggested_response": suggestion,
+            "risk_level": risk_level
+        }
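This path (and the direct path below) still parses the model output with eval on the stripped code fence. A safer parse, shown here only as a sketch (parse_llm_json is a made-up helper, not part of this commit), would extract the fenced JSON and hand it to json.loads:

```python
# Sketch: parse the fenced JSON output of the LLM without eval.
import json
import re


def parse_llm_json(content: str) -> dict:
    # Pull the {...} payload out of an optional json code fence
    match = re.search(r"`{3}(?:json)?\s*(\{.*\})\s*`{3}", content, re.DOTALL)
    payload = match.group(1) if match else content
    return json.loads(payload)
```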
 
 # Direct suggestion function
 @traceable(run_type="chain", name="Direct Suggestion")
 def generate_suggestion_direct(context, response_type, crisis_flag):
+    # Check if essential components are available
+    if llm is None:
+        logger.warning("LLM not available for direct suggestions, using fallback")
+        risk_level = "High" if crisis_flag else "Low"
+
+        # Simple fallback suggestions based on response type
+        if response_type == "Empathetic Listening":
+            suggestion = "It sounds like this has been really difficult for you. I'm here to listen and support you."
+        elif response_type == "Question":
+            suggestion = "How have you been coping with these feelings recently?"
+        elif response_type == "Advice":
+            suggestion = "One thing that might help is establishing a simple morning routine with small, achievable steps."
+        elif response_type == "Validation":
+            suggestion = "What you're experiencing is a normal response to a difficult situation. Your feelings are valid."
+        else:
+            suggestion = "I appreciate you sharing this with me. Let's work through this together."
+
+        # Add crisis resources if needed
+        if crisis_flag:
+            suggestion += " Given what you've shared, I want to make sure you know about resources like the National Suicide Prevention Lifeline at 988, which is available 24/7."
+
+        return {
+            "suggested_response": suggestion,
+            "risk_level": risk_level
         }
 
+    # If LLM is available, proceed with direct suggestion
     try:
+        prompt_template = ChatPromptTemplate.from_template(
+            """
+            You are an expert mental health counseling assistant. A counselor has provided the following patient situation:
+
+            Patient Situation: {context}
+
+            Predicted Response Type: {response_type}
+            Crisis Flag: {crisis_flag}
+
+            Provide a suggested response for the counselor to use with the patient, aligned with the response type ({response_type}) and sensitive to the crisis level.
+
+            Guidelines:
+            - If Crisis Flag is True, prioritize safety, empathy, and suggest immediate resources (e.g., National Suicide Prevention Lifeline at 988).
+            - For 'Empathetic Listening', focus on validating feelings without giving direct advice or questions.
+            - For 'Advice', provide practical, actionable suggestions.
+            - For 'Question', pose an open-ended question to encourage further discussion.
+            - For 'Validation', affirm the patient's efforts or feelings.
+            - Strictly adhere to the response type. For 'Empathetic Listening', do not include questions or advice.
+
+            Output in the following format:
+            ```json
+            {{
+                "suggested_response": "Your suggested response here",
+                "risk_level": "Low/Moderate/High"
+            }}
+            ```
+            """
+        )
+
+        direct_chain = (
+            {
+                "context": RunnablePassthrough(),
+                "response_type": lambda x: response_type,
+                "crisis_flag": lambda x: "Crisis" if crisis_flag else "No Crisis"
+            }
+            | prompt_template
+            | llm
+        )
+
         response = direct_chain.invoke(context)
         return eval(response.content.strip("```json\n").strip("\n```"))
     except Exception as e:
         logger.error(f"Error generating direct suggestion: {e}")
+        risk_level = "High" if crisis_flag else "Low"
+
+        # Fallback suggestion if direct generation fails
+        if crisis_flag:
+            suggestion = "I'm concerned about what you're sharing. Your wellbeing is important, and I want to make sure you have support. The National Suicide Prevention Lifeline (988) has trained counselors available 24/7."
+        else:
+            suggestion = "I hear you're having a difficult time. Would you like to talk more about how these feelings have been affecting you?"
+
+        return {
+            "suggested_response": suggestion,
+            "risk_level": risk_level
+        }
 
 # User Profile Endpoints
 @app.post("/users/create", response_model=UserProfile)
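A quick way to exercise the new fallback paths without a configured LLM or vector store, sketched as a pytest test. It assumes the module imports as api_mental_health without side effects and that its llm global can be monkeypatched; neither assumption is verified by this diff:

```python
# Sketch: pytest check of the no-LLM fallback in generate_suggestion_direct.
# Assumes "api_mental_health" imports cleanly and exposes an "llm" global.
import api_mental_health as api


def test_direct_fallback_without_llm(monkeypatch):
    monkeypatch.setattr(api, "llm", None)
    result = api.generate_suggestion_direct(
        "I feel overwhelmed lately", "Empathetic Listening", crisis_flag=False
    )
    assert set(result) == {"suggested_response", "risk_level"}
    assert result["risk_level"] == "Low"
```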