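"""Build a ChromaDB vector database for the Mental Health Counselor POC.

Loads the Kaggle "NLP Mental Health Conversations" dataset, tags each counselor
response with a heuristic response type and a crisis flag, embeds the patient
contexts with OpenAI embeddings, and persists everything to a local ChromaDB
collection. LangSmith tracing is enabled for the main steps.
"""
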
import pandas as pd
import chromadb
from chromadb.config import Settings
from langchain_openai import OpenAIEmbeddings
import os
import sys
import getpass
import shutil
import re
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
import kagglehub
from pathlib import Path
from langsmith import Client, traceable

# Download required NLTK data
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('wordnet')
nltk.download('stopwords')

# Initialize NLTK and VADER
STOPWORDS = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()
analyzer = SentimentIntensityAnalyzer()

# Set OpenAI API key
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter OpenAI API Key: ")

# Set LangSmith environment variables
if not os.environ.get("LANGCHAIN_API_KEY"):
    os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("Enter LangSmith API Key: ")
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "MentalHealthCounselorPOC"

# Initialize LangSmith client
langsmith_client = Client()

# Define default paths
DEFAULT_OUTPUT_DIR = os.environ.get("MH_OUTPUT_DIR", "mental_health_model_artifacts")
DEFAULT_DATASET_PATH = os.environ.get("MH_DATASET_PATH", None)

# Parse command-line arguments (ignore unknown args for Jupyter/Colab)
import argparse
parser = argparse.ArgumentParser(description="Create ChromaDB vector database for Mental Health Counselor POC")
parser.add_argument('--output-dir', default=DEFAULT_OUTPUT_DIR, help="Directory for model artifacts")
parser.add_argument('--dataset-path', default=DEFAULT_DATASET_PATH, help="Path to train.csv (if already downloaded)")
args, unknown = parser.parse_known_args()  # Ignore unknown args like -f
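# Example invocation (script name and paths are illustrative):
#   python create_vector_db.py --output-dir mental_health_model_artifacts --dataset-path data/train.csv
# The same defaults can also be set via the MH_OUTPUT_DIR and MH_DATASET_PATH environment variables.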

# Set paths
output_dir = args.output_dir
chroma_db_path = os.path.join(output_dir, "chroma_db")
dataset_path = args.dataset_path

# Text preprocessing function
@traceable(run_type="tool", name="Clean Text")
def clean_text(text):
    if pd.isna(text):
        return ""
    text = str(text).lower()
    text = re.sub(r"[^a-zA-Z']", " ", text)
    tokens = word_tokenize(text)
    tokens = [lemmatizer.lemmatize(tok) for tok in tokens if tok not in STOPWORDS and len(tok) > 2]
    return " ".join(tokens)

# Response categorization function
@traceable(run_type="tool", name="Categorize Response")
def categorize_response(text):
    text = str(text).lower()
    labels = []
    if re.search(r"\?$", text.strip()):
        return "Question"
    if any(phrase in text for phrase in ["i understand", "that sounds", "i hear"]):
        labels.append("Validation")
    if any(phrase in text for phrase in ["should", "could", "try", "recommend"]):
        labels.append("Advice")
    if not labels:
        sentiment = analyzer.polarity_scores(text)
        if sentiment['compound'] > 0.3:
            labels.append("Empathetic Listening")
        else:
            labels.append("Advice")
    return "|".join(labels)

# Load dataset
@traceable(run_type="tool", name="Load Dataset")
def load_dataset():
    try:
        if dataset_path and os.path.exists(dataset_path):
            df = pd.read_csv(dataset_path)
        else:
            # Download just train.csv from Kaggle via kagglehub (returns the local path to the file)
            dataset_file = kagglehub.dataset_download("thedevastator/nlp-mental-health-conversations", path="train.csv")
            df = pd.read_csv(dataset_file)
        print("First 5 records:\n", df.head())
        return df
    except Exception as e:
        print(f"Error loading dataset: {e}")
        sys.exit(1)

# Main vector database creation
@traceable(run_type="chain", name="Create Vector Database")
def create_vector_db():
    df = load_dataset()

    # Validate and clean dataset
    if not all(col in df.columns for col in ['Context', 'Response']):
        print("Error: Dataset missing required columns ('Context', 'Response')")
        sys.exit(1)

    df = df.dropna(subset=['Context', 'Response']).drop_duplicates()
    print(f"Cleaned Dataset Shape: {df.shape}")

    # Compute response type and crisis flag
    crisis_keywords = ['suicide', 'hopeless', 'worthless', 'kill', 'harm', 'desperate', 'overwhelmed', 'alone']
    df["response_type"] = df["Response"].apply(categorize_response)
    df["response_type_single"] = df["response_type"].apply(lambda x: x.split("|")[0])
    df["crisis_flag"] = df["Context"].apply(
        lambda x: sum(1 for word in crisis_keywords if word in str(x).lower()) > 0
    )

    # Initialize ChromaDB client
    try:
        if os.path.exists(chroma_db_path):
            print(f"Clearing existing ChromaDB at {chroma_db_path}")
            shutil.rmtree(chroma_db_path)
        os.makedirs(chroma_db_path, exist_ok=True)
        chroma_client = chromadb.PersistentClient(
            path=chroma_db_path,
            settings=Settings(anonymized_telemetry=False)
        )
    except Exception as e:
        print(f"Error initializing ChromaDB: {e}")
        print("Ensure ChromaDB version is compatible (e.g., 0.5.x) and no other processes are accessing the database.")
        sys.exit(1)

    # Initialize OpenAI embeddings
    try:
        embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")
    except Exception as e:
        print(f"Error initializing OpenAI embeddings: {e}")
        sys.exit(1)

    # Create or reset collection
    collection_name = "mental_health_conversations"
    try:
        chroma_client.delete_collection(collection_name)
        print(f"Deleted existing collection '{collection_name}' if it existed")
    except:
        print(f"No existing collection '{collection_name}' to delete")
    try:
        collection = chroma_client.create_collection(name=collection_name)
        print(f"Created new collection '{collection_name}'")
    except Exception as e:
        print(f"Error creating Chroma collection: {e}")
        sys.exit(1)

    # Prepare documents
    documents = df["Context"].tolist()
    metadatas = [
        {
            "response": row["Response"],
            "response_type": row["response_type_single"],
            "crisis_flag": bool(row["crisis_flag"])
        }
        for _, row in df.iterrows()
    ]
    ids = [f"doc_{i}" for i in range(len(documents))]

    # Generate embeddings and add to collection
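    # Note: everything is embedded and inserted in a single batch here. For large datasets this
    # may hit OpenAI rate limits or Chroma's per-call batch limit, so chunked inserts may be needed.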
    try:
        embeddings_vectors = embeddings.embed_documents(documents)
        collection.add(
            documents=documents,
            embeddings=embeddings_vectors,
            metadatas=metadatas,
            ids=ids
        )
        print(f"Vector database created in {chroma_db_path} with {len(documents)} documents")
    except Exception as e:
        print(f"Error generating embeddings or adding to collection: {e}")
        sys.exit(1)

if __name__ == "__main__":
    create_vector_db()