argmin committed on
Commit 510a9b0 · 1 Parent(s): f23351c
Makefile ADDED
@@ -0,0 +1,8 @@
+ setup:
+ 	pip install -r requirements.txt
+
+ run:
+ 	streamlit run app.py
+
+ test:
+ 	PYTHONPATH=./app pytest
app/.DS_Store ADDED
Binary file (6.15 kB).
app/__init__.py ADDED
File without changes
app/config/__pycache__/model_params.cpython-310.pyc ADDED
Binary file (303 Bytes).
app/config/model_params.py ADDED
@@ -0,0 +1,6 @@
+ DEFAULT_PARAMS = {
+     "model": "gpt-4o-mini-2024-07-18",
+     "max_tokens": 60,
+     "temperature": 0.0,
+     "available_models": ["gpt-4o-mini-2024-07-18", "gpt-4o-2024-08-06"],  # Structured-output-compatible models
+ }
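A note on how these defaults are consumed: app/main.py builds its model picker with `DEFAULT_PARAMS["available_models"].index(DEFAULT_PARAMS["model"])`, so the default model name must appear in the list of available models. A minimal sanity check (not part of the commit; run with PYTHONPATH=./app) would be:

    from config.model_params import DEFAULT_PARAMS

    # main.py calls available_models.index(DEFAULT_PARAMS["model"]); if the default
    # model is ever dropped from the list, that call raises ValueError at app start.
    assert DEFAULT_PARAMS["model"] in DEFAULT_PARAMS["available_models"]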
app/main.py ADDED
@@ -0,0 +1,201 @@
+ import streamlit as st
+ import pandas as pd
+ from utils.prompt import generate_prompts
+ from utils.classification import apply_classification
+ from utils.validation import generate_classification_model
+ from utils.api import get_openai_client
+ from utils.tokens import estimate_token_count
+ from config.model_params import DEFAULT_PARAMS
+
+ st.set_page_config(layout="wide")
+
+ # Streamlit App Title
+ st.title("LLM-based Classifier")
+
+ # Upload Dataset
+ uploaded_file = st.sidebar.file_uploader("Upload a CSV file", type=["csv"])
+ if uploaded_file:
+     df = pd.read_csv(uploaded_file)
+     st.write("### Data Preview", df.head())
+
+     # Select Target Column
+     label_column = st.selectbox("Select target column (if available):", df.columns.tolist())
+
+     # Exclude Target Column from Feature Selection
+     if label_column: # Ensure the label column is defined
+         filtered_columns = [col for col in df.columns if col != label_column]
+     else:
+         filtered_columns = df.columns.tolist()
+
+     # Feature Selection
+     features = st.multiselect("Select features:", filtered_columns, default=filtered_columns)
+
+     # Validate Features
+     if label_column in features:
+         st.error(f"Target column '{label_column}' cannot be included in features. Please remove it.")
+         st.stop()
+
+     # Specify Prediction Column Name
+     prediction_column = st.text_input(
+         "Enter the name of the column to store predictions:", "Predicted Label"
+     )
+
+     # Define Labels and Descriptions
+     st.write(f"### Describe the values {prediction_column} can take")
+     num_labels = st.number_input("Number of unique labels:", min_value=2, step=1)
+
+     # Create columns for labels and descriptions
+     col1, col2 = st.columns(2)
+
+     label_descriptions = {}
+     for i in range(int(num_labels)):
+         with col1:
+             label = st.text_input(f"Label {i+1} name:", key=f"label_name_{i}")
+         with col2:
+             description = st.text_input(f"Label {i+1} description:", key=f"label_desc_{i}")
+         label_descriptions[label] = description
+
+     # Compare user-provided labels with unique target values
+     if label_column:
+         # Get unique values in the target column
+         unique_target_values = set(df[label_column].unique())
+         n_unique_target_values = len(unique_target_values)
+
+         if n_unique_target_values > 20:
+             st.warning(
+                 f"The selected column '{label_column}' has {n_unique_target_values} unique values, "
+                 f"which may not be ideal as a target for classification."
+             )
+             proceed = st.checkbox(
+                 f"I understand and still want to use '{label_column}' as the target column."
+             )
+             if not proceed:
+                 st.stop()
+
+         # Get user-provided labels
+         user_provided_labels = set(label_descriptions.keys())
+
+         # Identify missing and extra labels
+         missing_labels = unique_target_values - user_provided_labels
+         extra_labels = user_provided_labels - unique_target_values
+
+         # Display warnings for discrepancies
+         if missing_labels:
+             st.warning(
+                 f"The following values in the target column are not accounted for in the labels: {', '.join(map(str, missing_labels))}."
+             )
+         if extra_labels:
+             st.warning(
+                 f"The following user-provided labels do not match any values in the target column: {', '.join(map(str, extra_labels))}."
+             )
+
+     # Few-Shot Prompting
+     use_few_shot = st.checkbox("Use few-shot prompting with examples from the target column", value=False)
+
+     if use_few_shot and label_column:
+         st.info("Few-shot prompting is enabled. Examples will be selected from the dataset.")
+
+         # Group by target column and select 2 examples per class
+         few_shot_examples = (
+             df.groupby(label_column, group_keys=False)
+             .apply(lambda group: group.sample(min(2, len(group)), random_state=42))
+         )
+
+         # Show the few-shot examples for reference
+         st.write("### Few-Shot Examples")
+         st.write(few_shot_examples[[*features, label_column]])
+
+         # Remove few-shot examples from the dataset
+         remaining_data = df.drop(few_shot_examples.index)
+     else:
+         few_shot_examples = None
+         remaining_data = df
+
+     # Limit rows to 20 to control costs
+     if len(remaining_data) > 20:
+         st.warning("Only the first 20 rows of the remaining dataset will be sent to OpenAI to save costs.")
+
+     limited_data = remaining_data.head(20).copy()  # copy so the prediction column can be added without a SettingWithCopyWarning
+
+     # Prepare Few-Shot Examples for Prompting
+     example_rows = []
+     if use_few_shot and few_shot_examples is not None:
+         for _, example in few_shot_examples.iterrows():
+             example_rows.append({
+                 "features": {feature: example[feature] for feature in features},
+                 "label": example[label_column],
+             })
+
+     # API Key and Model Parameters
+     openai_api_key = st.sidebar.text_input("Enter your OpenAI API Key:", type="password")
+     model_params = {
+         "model": st.selectbox(
+             "Model:",
+             DEFAULT_PARAMS["available_models"],
+             index=DEFAULT_PARAMS["available_models"].index(DEFAULT_PARAMS["model"])
+         ),
+         "temperature": st.slider("Temperature:", min_value=0.0, max_value=1.0, value=DEFAULT_PARAMS["temperature"]),
+         "max_tokens": DEFAULT_PARAMS["max_tokens"],
+     }
+
+     st.sidebar.write('**Model Config**')
+     st.sidebar.json(DEFAULT_PARAMS)
+
+     verbose = st.checkbox("Verbose", value=False)
+
+     # Classification Button
+     if st.button("Run Classification"):
+         if not openai_api_key:
+             st.error("Please provide a valid OpenAI API Key.")
+         else:
+             # Initialize OpenAI client
+             client = get_openai_client(api_key=openai_api_key)
+
+             # Dynamically create the Pydantic model for validation
+             ClassificationOutput = generate_classification_model(list(label_descriptions.keys()))
+
+             # Function to classify a single row
+             def classify_row(row):
+                 # Generate system and user prompts
+                 system_prompt, user_prompt = generate_prompts(
+                     row=row.to_dict(),
+                     label_descriptions=label_descriptions,
+                     features=features,
+                     example_rows=example_rows,
+                 )
+
+                 # Show the prompts in an expander for transparency
+                 if verbose:
+                     with st.expander(f"OpenAI Call Input for Row Index {row.name}"):
+                         st.write("**System Prompt:**")
+                         st.code(system_prompt)
+                         st.write(f"Token Count (System Prompt): {estimate_token_count(system_prompt, model_params['model'])}")
+                         st.write("**User Prompt:**")
+                         st.code(user_prompt)
+                         st.write(f"Token Count (User Prompt): {estimate_token_count(user_prompt, model_params['model'])}")
+
+                 # Make the OpenAI call and validate the output
+                 return apply_classification(
+                     client=client,
+                     model_params=model_params,
+                     ClassificationOutput=ClassificationOutput,
+                     system_prompt=system_prompt,
+                     user_prompt=user_prompt,
+                     verbose=verbose,
+                     st=st
+                 )
+
+             # Apply the classification to each row in the limited data
+             limited_data[prediction_column] = limited_data.apply(classify_row, axis=1)
+
+             # Display Predictions
+             st.write(f"### Predictions ({prediction_column})", limited_data)
+
+             # Evaluate if ground truth is available
+             if label_column in limited_data.columns:
+                 from utils.evaluation import evaluate_predictions
+                 report = evaluate_predictions(limited_data[label_column], limited_data[prediction_column])
+                 st.write("### Evaluation Metrics")
+                 st.json(report)
+             else:
+                 st.warning(f"Target column '{label_column}' or prediction column '{prediction_column}' is missing in the data.")
app/utils/__init__.py ADDED
File without changes
app/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (150 Bytes).
app/utils/__pycache__/api.cpython-310.pyc ADDED
Binary file (1.03 kB).
app/utils/__pycache__/classification.cpython-310.pyc ADDED
Binary file (880 Bytes).
app/utils/__pycache__/evaluation.cpython-310.pyc ADDED
Binary file (779 Bytes).
app/utils/__pycache__/prompt.cpython-310.pyc ADDED
Binary file (2 kB).
app/utils/__pycache__/tokens.cpython-310.pyc ADDED
Binary file (656 Bytes).
app/utils/__pycache__/validation.cpython-310.pyc ADDED
Binary file (750 Bytes).
app/utils/api.py ADDED
@@ -0,0 +1,26 @@
+ from openai import OpenAI
+
+ # Initialize OpenAI client
+ def get_openai_client(api_key):
+     """
+     Returns an OpenAI client instance with the provided API key.
+     """
+     return OpenAI(api_key=api_key)
+
+ def classify_row_chat(prompt, client, model="gpt-3.5-turbo"):
+     """
+     Sends a classification prompt to the OpenAI Chat API and returns the predicted label.
+
+     Args:
+         prompt (str): The user prompt to classify data.
+         client (OpenAI): The OpenAI client instance.
+         model (str): The model to use for chat completion.
+
+     Returns:
+         str: The predicted label.
+     """
+     response = client.chat.completions.create(
+         model=model,
+         messages=[{"role": "user", "content": prompt}]
+     )
+     return response.choices[0].message.content.strip()
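Within this commit, classify_row_chat is exercised only by tests/test_api.py. A minimal direct call would look like the hedged sketch below; the prompt is hypothetical, the key is a placeholder, and the call is billable against a real OpenAI account.

    from utils.api import get_openai_client, classify_row_chat

    client = get_openai_client(api_key="sk-...")  # placeholder, not a real credential
    prompt = (
        "You are a classifier. Assign one of the following labels: Positive, Negative.\n"
        "Input: Age: 25; Weight: 70\nLabel:"
    )
    label = classify_row_chat(prompt=prompt, client=client, model="gpt-3.5-turbo")
    print(label)  # e.g. "Positive"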
app/utils/classification.py ADDED
@@ -0,0 +1,26 @@
+ def apply_classification(client, model_params, ClassificationOutput, system_prompt, user_prompt, verbose=False, st=None):
+     response = client.chat.completions.create(
+         model=model_params["model"],
+         messages=[
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": user_prompt},
+         ],
+         max_tokens=model_params["max_tokens"],
+         temperature=model_params["temperature"],
+     )
+     raw_prediction = response.choices[0].message.content.strip()
+
+     # Log raw prediction for debugging
+     if verbose and st:
+         st.info(f"Raw Prediction: {raw_prediction}")
+
+     # Validate and process the prediction
+     try:
+         validated_prediction = ClassificationOutput.parse_obj({"label": raw_prediction}).label
+     except Exception as e:
+         if verbose and st:
+             st.error(f"Invalid prediction: {raw_prediction}. Error: {e}")
+         return "INVALID"
+
+     return validated_prediction
+
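The validation step is what turns a free-text completion into one of the allowed labels: anything that does not exactly match a Literal value falls through to the "INVALID" sentinel. A small offline check (mocked client, as in the tests; run with PYTHONPATH=./app) illustrates both branches:

    from unittest.mock import Mock
    from utils.classification import apply_classification
    from utils.validation import generate_classification_model

    ClassificationOutput = generate_classification_model(["Positive", "Negative"])
    model_params = {"model": "gpt-4o-mini-2024-07-18", "max_tokens": 60, "temperature": 0.0}

    def fake_client(reply):
        # Build a mock whose chat.completions.create always returns `reply`.
        client = Mock()
        client.chat.completions.create.return_value = Mock(choices=[Mock(message=Mock(content=reply))])
        return client

    # Exact label -> validated and returned as-is.
    print(apply_classification(fake_client("Positive"), model_params, ClassificationOutput, "sys", "user"))        # Positive
    # Anything else (extra words, unknown label) -> "INVALID".
    print(apply_classification(fake_client("Label: Positive"), model_params, ClassificationOutput, "sys", "user"))  # INVALID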
app/utils/evaluation.py ADDED
@@ -0,0 +1,21 @@
+ import pandas as pd
+ from sklearn.metrics import classification_report
+
+ def evaluate_predictions(y_true, y_pred):
+     """
+     Evaluates predictions by converting labels to strings and generating a classification report.
+
+     Args:
+         y_true (pd.Series or list): True labels.
+         y_pred (pd.Series or list): Predicted labels.
+
+     Returns:
+         dict: Classification report as a dictionary.
+     """
+     # Ensure both true and predicted labels are strings
+     y_true_str = pd.Series(y_true).astype(str)
+     y_pred_str = pd.Series(y_pred).astype(str)
+
+     # Generate classification report
+     report = classification_report(y_true_str, y_pred_str, output_dict=True)
+     return report
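With output_dict=True, scikit-learn returns a nested dict: one entry per label plus overall accuracy and macro/weighted averages, which is what main.py renders with st.json. For example (run with PYTHONPATH=./app):

    from utils.evaluation import evaluate_predictions

    report = evaluate_predictions(["Positive", "Negative", "Positive"], ["Positive", "Negative", "Negative"])
    print(report["accuracy"])               # ~0.667 for 2 of 3 correct
    print(report["Positive"]["precision"])  # per-label precision/recall/f1-score/support
    print(report["macro avg"]["f1-score"])  # aggregate rows: "macro avg", "weighted avg"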
app/utils/prompt.py ADDED
@@ -0,0 +1,54 @@
+ def create_classification_prompt(row, label_descriptions, features, example_rows):
+     """
+     Generates system and user prompts for classification.
+
+     Args:
+         row (dict): A single row of feature values.
+         label_descriptions (dict): Mapping of labels to their descriptions.
+         features (list): List of features to include in the prompt.
+         example_rows (list): Few-shot examples for the prompt.
+
+     Returns:
+         tuple: (system_prompt, user_prompt)
+     """
+     # System prompt
+     system_prompt = "You are a classifier. Assign one of the following labels based on the input data:\n"
+     for label, desc in label_descriptions.items():
+         system_prompt += f"- {label}: {desc}\n"
+
+     # Few-shot examples
+     if example_rows:
+         system_prompt += "\nExamples:\n"
+         for example in example_rows:
+             example_features = "; ".join(
+                 f"{feature}: {example['features'][feature]}" for feature in features
+                 #f"{feature}: {example.get('features', {}).get(feature, 'MISSING')}" for feature in features
+             )
+             system_prompt += f"Input: {example_features}\nLabel: {example['label']}\n"
+
+     # User prompt for the current row
+     user_features = "; ".join(f"{feature}: {row[feature]}" for feature in features)
+     user_prompt = f"Input: {user_features}\nLabel:"
+
+     return system_prompt, user_prompt
+
+
+ def generate_prompts(row, label_descriptions, features, example_rows):
+     """
+     Wrapper for create_classification_prompt to generate prompts for a row.
+
+     Args:
+         row (dict): Row of the dataset.
+         label_descriptions (dict): Mapping of labels to their descriptions.
+         features (list): List of features to include in the prompt.
+         example_rows (list): Few-shot examples for the prompt.
+
+     Returns:
+         tuple: (system_prompt, user_prompt)
+     """
+     return create_classification_prompt(
+         row=row,
+         label_descriptions=label_descriptions,
+         features=features,
+         example_rows=example_rows,
+     )
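Concretely, for a toy row and one few-shot example, the helper produces prompts of the following shape (illustrative values; run with PYTHONPATH=./app):

    from utils.prompt import generate_prompts

    system_prompt, user_prompt = generate_prompts(
        row={"Age": 30, "Weight": 65},
        label_descriptions={"Positive": "The sentiment is positive.", "Negative": "The sentiment is negative."},
        features=["Age", "Weight"],
        example_rows=[{"features": {"Age": 34, "Weight": 70}, "label": "Positive"}],
    )
    print(system_prompt)
    # You are a classifier. Assign one of the following labels based on the input data:
    # - Positive: The sentiment is positive.
    # - Negative: The sentiment is negative.
    #
    # Examples:
    # Input: Age: 34; Weight: 70
    # Label: Positive
    print(user_prompt)
    # Input: Age: 30; Weight: 65
    # Label: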
app/utils/tokens.py ADDED
@@ -0,0 +1,15 @@
+ import tiktoken
+
+ def estimate_token_count(prompt: str, model: str) -> int:
+     """
+     Estimate the token count for a given prompt and model.
+
+     Args:
+         prompt (str): The input prompt to tokenize.
+         model (str): The name of the model to use for token encoding.
+
+     Returns:
+         int: The estimated token count.
+     """
+     encoding = tiktoken.encoding_for_model(model)
+     return len(encoding.encode(prompt))
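One caveat, as an observation about tiktoken rather than something this commit handles: tiktoken.encoding_for_model raises KeyError for model names it does not recognize, which can happen with dated model strings such as "gpt-4o-mini-2024-07-18" on older tiktoken releases. A hedged fallback sketch (not part of the commit):

    import tiktoken

    def estimate_token_count_safe(prompt: str, model: str) -> int:
        """Like estimate_token_count, but falls back to a generic encoding for unknown model names."""
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # cl100k_base is a reasonable general-purpose fallback; counts become approximate.
            encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(prompt))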
app/utils/validation.py ADDED
@@ -0,0 +1,18 @@
+ from pydantic import BaseModel, create_model
+ from typing import Literal, List, Type
+
+
+ def generate_classification_model(labels: List[str]) -> Type[BaseModel]:
+     """
+     Dynamically generates a Pydantic model for classification based on user-provided labels.
+
+     Args:
+         labels (List[str]): List of valid label strings.
+
+     Returns:
+         Type[BaseModel]: A dynamically generated Pydantic model class.
+     """
+     return create_model(
+         "DynamicClassificationOutput",
+         label=(Literal[tuple(labels)], ...),  # Enforce that 'label' matches one of the valid labels
+     )
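The create_model call with a Literal field is equivalent to writing the model by hand; for two labels the generated class behaves like the explicit version below (shown only to make the dynamic construction concrete; run with PYTHONPATH=./app):

    from typing import Literal
    from pydantic import BaseModel, ValidationError

    from utils.validation import generate_classification_model

    class HandWrittenOutput(BaseModel):
        label: Literal["Positive", "Negative"]

    Generated = generate_classification_model(["Positive", "Negative"])

    print(Generated(label="Positive").label)          # Positive
    print(HandWrittenOutput(label="Positive").label)  # Positive
    try:
        Generated(label="Maybe")
    except ValidationError as err:
        print(err)  # Input should be 'Positive' or 'Negative'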
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ kagglehub
+ openai
+ pandas
+ pydantic
+ pytest
+ pytest-mock
+ scikit-learn
+ sentencepiece
+ sentence_transformers
+ streamlit
+ tiktoken
+ transformers
tests/__init__.py ADDED
File without changes
tests/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (146 Bytes).
tests/__pycache__/test_api.cpython-310-pytest-8.3.4.pyc ADDED
Binary file (1.18 kB).
tests/__pycache__/test_evaluation.cpython-310-pytest-8.3.4.pyc ADDED
Binary file (1.21 kB).
tests/__pycache__/test_prompt.cpython-310-pytest-8.3.4.pyc ADDED
Binary file (1.49 kB).
tests/__pycache__/test_validation.cpython-310-pytest-8.3.4.pyc ADDED
Binary file (1.53 kB).
tests/test_api.py ADDED
@@ -0,0 +1,18 @@
+ from unittest.mock import Mock
+ from utils.api import classify_row_chat
+
+ def test_classify_row_chat():
+     # Mock the OpenAI client and its response
+     client_mock = Mock()
+     client_mock.chat.completions.create.return_value = Mock(
+         choices=[Mock(message=Mock(content="Positive"))]
+     )
+
+     # Define the prompt
+     prompt = "Classify the following observation: Age: 25, Weight: 70\nLabel:"
+
+     # Call the classify_row_chat function with the mocked client
+     prediction = classify_row_chat(prompt=prompt, client=client_mock, model="gpt-3.5-turbo")
+
+     # Assert the response matches the expected label
+     assert prediction == "Positive", "The classification should return 'Positive'"
tests/test_evaluation.py ADDED
@@ -0,0 +1,14 @@
+ from utils.evaluation import evaluate_predictions
+
+ def test_evaluate_predictions():
+     y_true = ["Positive", "Negative", "Positive"]
+     y_pred = ["Positive", "Negative", "Positive"]
+
+     # Test perfect match
+     report = evaluate_predictions(y_true, y_pred)
+     assert report["accuracy"] == 1.0, "Accuracy should be 100% for perfect predictions"
+
+     # Test mismatched predictions
+     y_pred_mismatch = ["Negative", "Negative", "Positive"]
+     report_mismatch = evaluate_predictions(y_true, y_pred_mismatch)
+     assert report_mismatch["accuracy"] < 1.0, "Accuracy should be less than 100% for mismatched predictions"
tests/test_prompt.py ADDED
@@ -0,0 +1,23 @@
+ import pytest
+ from utils.prompt import generate_prompts
+
+ def test_generate_prompts():
+     example_rows = [  # Update to match the function's parameter name
+         {"features": {"Age": 34, "Weight": 70, "Location": "Urban"}, "label": "Positive"},
+         {"features": {"Age": 25, "Weight": 60, "Location": "Rural"}, "label": "Negative"},
+     ]
+     features = ["Age", "Weight", "Location"]
+     label_descriptions = {
+         "Positive": "The sentiment is positive.",
+         "Negative": "The sentiment is negative.",
+     }
+
+     row = {"Age": 30, "Weight": 65, "Location": "Suburban"}
+
+     system_prompt, user_prompt = generate_prompts(
+         row=row, example_rows=example_rows, features=features, label_descriptions=label_descriptions
+     )
+     assert "Age: 34; Weight: 70; Location: Urban" in system_prompt
+     assert "Label: Positive" in system_prompt
+     assert "Label:" in user_prompt
+
tests/test_validation.py ADDED
@@ -0,0 +1,16 @@
+ import pytest
+ from pydantic import ValidationError
+ from utils.validation import generate_classification_model
+
+ def test_classification_output_validation():
+     # Dynamically generate classification model
+     ClassificationOutput = generate_classification_model(["Positive", "Negative"])
+
+     # Test valid input
+     valid_output = ClassificationOutput(label="Positive")
+     assert valid_output.label == "Positive", "Label should be 'Positive'"
+
+     # Test invalid input: the model must reject labels outside the allowed set
+     with pytest.raises(ValidationError) as exc_info:
+         ClassificationOutput(label="InvalidLabel")
+     assert "Input should be 'Positive' or 'Negative'" in str(exc_info.value), "Should raise validation error with correct message"