add files
- Makefile +8 -0
- app/.DS_Store +0 -0
- app/__init__.py +0 -0
- app/config/__pycache__/model_params.cpython-310.pyc +0 -0
- app/config/model_params.py +6 -0
- app/main.py +201 -0
- app/utils/__init__.py +0 -0
- app/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- app/utils/__pycache__/api.cpython-310.pyc +0 -0
- app/utils/__pycache__/classification.cpython-310.pyc +0 -0
- app/utils/__pycache__/evaluation.cpython-310.pyc +0 -0
- app/utils/__pycache__/prompt.cpython-310.pyc +0 -0
- app/utils/__pycache__/tokens.cpython-310.pyc +0 -0
- app/utils/__pycache__/validation.cpython-310.pyc +0 -0
- app/utils/api.py +26 -0
- app/utils/classification.py +26 -0
- app/utils/evaluation.py +21 -0
- app/utils/prompt.py +54 -0
- app/utils/tokens.py +15 -0
- app/utils/validation.py +18 -0
- requirements.txt +8 -0
- tests/__init__.py +0 -0
- tests/__pycache__/__init__.cpython-310.pyc +0 -0
- tests/__pycache__/test_api.cpython-310-pytest-8.3.4.pyc +0 -0
- tests/__pycache__/test_evaluation.cpython-310-pytest-8.3.4.pyc +0 -0
- tests/__pycache__/test_prompt.cpython-310-pytest-8.3.4.pyc +0 -0
- tests/__pycache__/test_validation.cpython-310-pytest-8.3.4.pyc +0 -0
- tests/test_api.py +18 -0
- tests/test_evaluation.py +14 -0
- tests/test_prompt.py +23 -0
- tests/test_validation.py +17 -0
Makefile
ADDED
@@ -0,0 +1,8 @@
+setup:
+	pip install -r requirements.txt
+
+run:
+	streamlit run app/main.py
+
+test:
+	PYTHONPATH=./app pytest

app/.DS_Store
ADDED
Binary file (6.15 kB).

app/__init__.py
ADDED
File without changes
app/config/__pycache__/model_params.cpython-310.pyc
ADDED
Binary file (303 Bytes).

app/config/model_params.py
ADDED
@@ -0,0 +1,6 @@
+DEFAULT_PARAMS = {
+    "model": "gpt-4o-mini-2024-07-18",
+    "max_tokens": 60,
+    "temperature": 0.0,
+    "available_models": ["gpt-4o-mini-2024-07-18", "gpt-4o-2024-08-06"],  # Structured-output-compatible models
+}

app/main.py
ADDED
@@ -0,0 +1,201 @@
+import streamlit as st
+import pandas as pd
+from utils.prompt import generate_prompts
+from utils.classification import apply_classification
+from utils.validation import generate_classification_model
+from utils.api import get_openai_client
+from utils.tokens import estimate_token_count
+from config.model_params import DEFAULT_PARAMS
+
+st.set_page_config(layout="wide")
+
+# Streamlit App Title
+st.title("LLM-based Classifier")
+
+# Upload Dataset
+uploaded_file = st.sidebar.file_uploader("Upload a CSV file", type=["csv"])
+if uploaded_file:
+    df = pd.read_csv(uploaded_file)
+    st.write("### Data Preview", df.head())
+
+    # Select Target Column
+    label_column = st.selectbox("Select target column (if available):", df.columns.tolist())
+
+    # Exclude Target Column from Feature Selection
+    if label_column:  # Ensure the label column is defined
+        filtered_columns = [col for col in df.columns if col != label_column]
+    else:
+        filtered_columns = df.columns.tolist()
+
+    # Feature Selection
+    features = st.multiselect("Select features:", filtered_columns, default=filtered_columns)
+
+    # Validate Features
+    if label_column in features:
+        st.error(f"Target column '{label_column}' cannot be included in features. Please remove it.")
+        st.stop()
+
+    # Specify Prediction Column Name
+    prediction_column = st.text_input(
+        "Enter the name of the column to store predictions:", "Predicted Label"
+    )
+
+    # Define Labels and Descriptions
+    st.write(f"### Describe the values {prediction_column} can take")
+    num_labels = st.number_input("Number of unique labels:", min_value=2, step=1)
+
+    # Create columns for labels and descriptions
+    col1, col2 = st.columns(2)
+
+    label_descriptions = {}
+    for i in range(int(num_labels)):
+        with col1:
+            label = st.text_input(f"Label {i+1} name:", key=f"label_name_{i}")
+        with col2:
+            description = st.text_input(f"Label {i+1} description:", key=f"label_desc_{i}")
+        label_descriptions[label] = description
+
+    # Compare user-provided labels with unique target values
+    if label_column:
+        # Get unique values in the target column
+        unique_target_values = set(df[label_column].unique())
+        n_unique_target_values = len(unique_target_values)
+
+        if n_unique_target_values > 20:
+            st.warning(
+                f"The selected column '{label_column}' has {n_unique_target_values} unique values, "
+                f"which may not be ideal as a target for classification."
+            )
+            proceed = st.checkbox(
+                f"I understand and still want to use '{label_column}' as the target column."
+            )
+            if not proceed:
+                st.stop()
+
+        # Get user-provided labels
+        user_provided_labels = set(label_descriptions.keys())
+
+        # Identify missing and extra labels
+        missing_labels = unique_target_values - user_provided_labels
+        extra_labels = user_provided_labels - unique_target_values
+
+        # Display warnings for discrepancies
+        if missing_labels:
+            st.warning(
+                f"The following values in the target column are not accounted for in the labels: {', '.join(map(str, missing_labels))}."
+            )
+        if extra_labels:
+            st.warning(
+                f"The following user-provided labels do not match any values in the target column: {', '.join(map(str, extra_labels))}."
+            )
+
+    # Few-Shot Prompting
+    use_few_shot = st.checkbox("Use few-shot prompting with examples from the target column", value=False)
+
+    if use_few_shot and label_column:
+        st.info("Few-shot prompting is enabled. Examples will be selected from the dataset.")
+
+        # Group by target column and select 2 examples per class
+        few_shot_examples = (
+            df.groupby(label_column, group_keys=False)
+            .apply(lambda group: group.sample(min(2, len(group)), random_state=42))
+        )
+
+        # Show the few-shot examples for reference
+        st.write("### Few-Shot Examples")
+        st.write(few_shot_examples[[*features, label_column]])
+
+        # Remove few-shot examples from the dataset
+        remaining_data = df.drop(few_shot_examples.index)
+    else:
+        few_shot_examples = None
+        remaining_data = df
+
+    # Limit rows to 20 to control costs
+    if len(remaining_data) > 20:
+        st.warning("Only the first 20 rows of the remaining dataset will be sent to OpenAI to save costs.")
+
+    limited_data = remaining_data.head(20)
+
+    # Prepare Few-Shot Examples for Prompting
+    example_rows = []
+    if use_few_shot and few_shot_examples is not None:
+        for _, example in few_shot_examples.iterrows():
+            example_rows.append({
+                "features": {feature: example[feature] for feature in features},
+                "label": example[label_column],
+            })
+
+    # API Key and Model Parameters
+    openai_api_key = st.sidebar.text_input("Enter your OpenAI API Key:", type="password")
+    model_params = {
+        "model": st.selectbox(
+            "Model:",
+            DEFAULT_PARAMS["available_models"],
+            index=DEFAULT_PARAMS["available_models"].index(DEFAULT_PARAMS["model"])
+        ),
+        "temperature": st.slider("Temperature:", min_value=0.0, max_value=1.0, value=DEFAULT_PARAMS["temperature"]),
+        "max_tokens": DEFAULT_PARAMS["max_tokens"],
+    }
+
+    st.sidebar.write('**Model Config**')
+    st.sidebar.json(DEFAULT_PARAMS)
+
+    verbose = st.checkbox("Verbose", value=False)
+
+    # Classification Button
+    if st.button("Run Classification"):
+        if not openai_api_key:
+            st.error("Please provide a valid OpenAI API Key.")
+        else:
+            # Initialize OpenAI client
+            client = get_openai_client(api_key=openai_api_key)
+
+            # Dynamically create the Pydantic model for validation
+            ClassificationOutput = generate_classification_model(list(label_descriptions.keys()))
+
+            # Function to classify a single row
+            def classify_row(row):
+                # Generate system and user prompts
+                system_prompt, user_prompt = generate_prompts(
+                    row=row.to_dict(),
+                    label_descriptions=label_descriptions,
+                    features=features,
+                    example_rows=example_rows,
+                )
+
+                # Show the prompts in an expander for transparency
+                if verbose:
+                    with st.expander(f"OpenAI Call Input for Row Index {row.name}"):
+                        st.write("**System Prompt:**")
+                        st.code(system_prompt)
+                        st.write(f"Token Count (System Prompt): {estimate_token_count(system_prompt, model_params['model'])}")
+                        st.write("**User Prompt:**")
+                        st.code(user_prompt)
+                        st.write(f"Token Count (User Prompt): {estimate_token_count(user_prompt, model_params['model'])}")
+
+                # Make the OpenAI call and validate the output
+                return apply_classification(
+                    client=client,
+                    model_params=model_params,
+                    ClassificationOutput=ClassificationOutput,
+                    system_prompt=system_prompt,
+                    user_prompt=user_prompt,
+                    verbose=verbose,
+                    st=st
+                )
+
+            # Apply the classification to each row in the limited data
+            limited_data[prediction_column] = limited_data.apply(classify_row, axis=1)
+
+            # Display Predictions
+            st.write(f"### Predictions ({prediction_column})", limited_data)
+
+            # Evaluate if ground truth is available
+            if label_column in limited_data.columns:
+                from utils.evaluation import evaluate_predictions
+                report = evaluate_predictions(limited_data[label_column], limited_data[prediction_column])
+                st.write("### Evaluation Metrics")
+                st.json(report)
+            else:
+                st.warning(f"Target column '{label_column}' or prediction column '{prediction_column}' is missing in the data.")

app/utils/__init__.py
ADDED
File without changes
app/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (150 Bytes).

app/utils/__pycache__/api.cpython-310.pyc
ADDED
Binary file (1.03 kB).

app/utils/__pycache__/classification.cpython-310.pyc
ADDED
Binary file (880 Bytes).

app/utils/__pycache__/evaluation.cpython-310.pyc
ADDED
Binary file (779 Bytes).

app/utils/__pycache__/prompt.cpython-310.pyc
ADDED
Binary file (2 kB).

app/utils/__pycache__/tokens.cpython-310.pyc
ADDED
Binary file (656 Bytes).

app/utils/__pycache__/validation.cpython-310.pyc
ADDED
Binary file (750 Bytes).

app/utils/api.py
ADDED
@@ -0,0 +1,26 @@
+from openai import OpenAI
+
+# Initialize OpenAI client
+def get_openai_client(api_key):
+    """
+    Returns an OpenAI client instance with the provided API key.
+    """
+    return OpenAI(api_key=api_key)
+
+def classify_row_chat(prompt, client, model="gpt-3.5-turbo"):
+    """
+    Sends a classification prompt to the OpenAI Chat API and returns the predicted label.
+
+    Args:
+        prompt (str): The user prompt to classify data.
+        client (OpenAI): The OpenAI client instance.
+        model (str): The model to use for chat completion.
+
+    Returns:
+        str: The predicted label.
+    """
+    response = client.chat.completions.create(
+        model=model,
+        messages=[{"role": "user", "content": prompt}]
+    )
+    return response.choices[0].message.content.strip()

app/utils/classification.py
ADDED
@@ -0,0 +1,26 @@
+def apply_classification(client, model_params, ClassificationOutput, system_prompt, user_prompt, verbose=False, st=None):
+    response = client.chat.completions.create(
+        model=model_params["model"],
+        messages=[
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ],
+        max_tokens=model_params["max_tokens"],
+        temperature=model_params["temperature"],
+    )
+    raw_prediction = response.choices[0].message.content.strip()
+
+    # Log raw prediction for debugging
+    if verbose and st:
+        st.info(f"Raw Prediction: {raw_prediction}")
+
+    # Validate and process the prediction
+    try:
+        validated_prediction = ClassificationOutput.parse_obj({"label": raw_prediction}).label
+    except Exception as e:
+        if verbose and st:
+            st.error(f"Invalid prediction: {raw_prediction}. Error: {e}")
+        return "INVALID"
+
+    return validated_prediction
+

app/utils/evaluation.py
ADDED
@@ -0,0 +1,21 @@
+import pandas as pd
+from sklearn.metrics import classification_report
+
+def evaluate_predictions(y_true, y_pred):
+    """
+    Evaluates predictions by converting labels to strings and generating a classification report.
+
+    Args:
+        y_true (pd.Series or list): True labels.
+        y_pred (pd.Series or list): Predicted labels.
+
+    Returns:
+        dict: Classification report as a dictionary.
+    """
+    # Ensure both true and predicted labels are strings
+    y_true_str = pd.Series(y_true).astype(str)
+    y_pred_str = pd.Series(y_pred).astype(str)
+
+    # Generate classification report
+    report = classification_report(y_true_str, y_pred_str, output_dict=True)
+    return report

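
For reference, a minimal sketch of the dictionary shape evaluate_predictions returns (a nested dict from scikit-learn's classification_report with output_dict=True); the labels and rough values below are illustrative only, and the import assumes PYTHONPATH includes ./app as in the Makefile's test target:

from utils.evaluation import evaluate_predictions  # assumes PYTHONPATH=./app

report = evaluate_predictions(
    y_true=["Positive", "Negative", "Positive"],
    y_pred=["Positive", "Negative", "Negative"],
)
# Roughly:
# {
#     "Positive": {"precision": 1.0, "recall": 0.5, "f1-score": ..., "support": 2.0},
#     "Negative": {"precision": 0.5, "recall": 1.0, "f1-score": ..., "support": 1.0},
#     "accuracy": 0.667,
#     "macro avg": {...},
#     "weighted avg": {...},
# }
print(report["accuracy"])
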
app/utils/prompt.py
ADDED
@@ -0,0 +1,54 @@
+def create_classification_prompt(row, label_descriptions, features, example_rows):
+    """
+    Generates system and user prompts for classification.
+
+    Args:
+        row (dict): A single row of feature values.
+        label_descriptions (dict): Mapping of labels to their descriptions.
+        features (list): List of features to include in the prompt.
+        example_rows (list): Few-shot examples for the prompt.
+
+    Returns:
+        tuple: (system_prompt, user_prompt)
+    """
+    # System prompt
+    system_prompt = "You are a classifier. Assign one of the following labels based on the input data:\n"
+    for label, desc in label_descriptions.items():
+        system_prompt += f"- {label}: {desc}\n"
+
+    # Few-shot examples
+    if example_rows:
+        system_prompt += "\nExamples:\n"
+        for example in example_rows:
+            example_features = "; ".join(
+                f"{feature}: {example['features'][feature]}" for feature in features
+                # f"{feature}: {example.get('features', {}).get(feature, 'MISSING')}" for feature in features
+            )
+            system_prompt += f"Input: {example_features}\nLabel: {example['label']}\n"
+
+    # User prompt for the current row
+    user_features = "; ".join(f"{feature}: {row[feature]}" for feature in features)
+    user_prompt = f"Input: {user_features}\nLabel:"
+
+    return system_prompt, user_prompt
+
+
+def generate_prompts(row, label_descriptions, features, example_rows):
+    """
+    Wrapper for create_classification_prompt to generate prompts for a row.
+
+    Args:
+        row (dict): Row of the dataset.
+        label_descriptions (dict): Mapping of labels to their descriptions.
+        features (list): List of features to include in the prompt.
+        example_rows (list): Few-shot examples for the prompt.
+
+    Returns:
+        tuple: (system_prompt, user_prompt)
+    """
+    return create_classification_prompt(
+        row=row,
+        label_descriptions=label_descriptions,
+        features=features,
+        example_rows=example_rows,
+    )

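
For illustration, a small sketch of the prompts generate_prompts produces; the feature names, labels, and row values are hypothetical (they mirror the unit test in tests/test_prompt.py rather than any real dataset), and the import assumes PYTHONPATH includes ./app:

from utils.prompt import generate_prompts  # assumes PYTHONPATH=./app

system_prompt, user_prompt = generate_prompts(
    row={"Age": 30, "Weight": 65},
    label_descriptions={"Positive": "The sentiment is positive.", "Negative": "The sentiment is negative."},
    features=["Age", "Weight"],
    example_rows=[{"features": {"Age": 34, "Weight": 70}, "label": "Positive"}],
)
# system_prompt lists each label with its description, then:
#   "Examples:\nInput: Age: 34; Weight: 70\nLabel: Positive\n"
# user_prompt:
#   "Input: Age: 30; Weight: 65\nLabel:"
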
app/utils/tokens.py
ADDED
@@ -0,0 +1,15 @@
+import tiktoken
+
+def estimate_token_count(prompt: str, model: str) -> int:
+    """
+    Estimate the token count for a given prompt and model.
+
+    Args:
+        prompt (str): The input prompt to tokenize.
+        model (str): The name of the model to use for token encoding.
+
+    Returns:
+        int: The estimated token count.
+    """
+    encoding = tiktoken.encoding_for_model(model)
+    return len(encoding.encode(prompt))

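
A quick usage sketch. One caveat worth hedging: tiktoken.encoding_for_model raises KeyError for model names the installed tiktoken version does not recognize, so the explicit-encoding fallback shown below is an assumed defensive variant, not something the committed estimate_token_count does:

import tiktoken
from utils.tokens import estimate_token_count  # assumes PYTHONPATH=./app

print(estimate_token_count("Input: Age: 30; Weight: 65\nLabel:", "gpt-4o-mini-2024-07-18"))

# Hypothetical defensive variant for unrecognized model names:
def safe_token_count(prompt: str, model: str) -> int:
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding("o200k_base")  # assumed fallback encoding
    return len(encoding.encode(prompt))
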
app/utils/validation.py
ADDED
@@ -0,0 +1,18 @@
+from pydantic import BaseModel, create_model
+from typing import Literal, List
+
+
+def generate_classification_model(labels: List[str]) -> BaseModel:
+    """
+    Dynamically generates a Pydantic model for classification based on user-provided labels.
+
+    Args:
+        labels (List[str]): List of valid label strings.
+
+    Returns:
+        BaseModel: A dynamically generated Pydantic model.
+    """
+    return create_model(
+        "DynamicClassificationOutput",
+        label=(Literal[tuple(labels)], ...),  # Enforce that 'label' matches one of the valid labels
+    )

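
To make the dynamic model concrete, a minimal sketch (assuming Pydantic v2 and illustrative labels) of how the generated Literal field accepts only the configured values:

from pydantic import ValidationError
from utils.validation import generate_classification_model  # assumes PYTHONPATH=./app

ClassificationOutput = generate_classification_model(["Positive", "Negative"])

print(ClassificationOutput(label="Positive").label)  # "Positive"

try:
    ClassificationOutput(label="Neutral")  # not an allowed label
except ValidationError as err:
    print(err)  # e.g. "Input should be 'Positive' or 'Negative'"
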
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+kagglehub
+pytest
+pytest-mock
+sentencepiece
+sentence_transformers
+streamlit
+tiktoken
+transformers

tests/__init__.py
ADDED
File without changes
tests/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (146 Bytes).

tests/__pycache__/test_api.cpython-310-pytest-8.3.4.pyc
ADDED
Binary file (1.18 kB).

tests/__pycache__/test_evaluation.cpython-310-pytest-8.3.4.pyc
ADDED
Binary file (1.21 kB).

tests/__pycache__/test_prompt.cpython-310-pytest-8.3.4.pyc
ADDED
Binary file (1.49 kB).

tests/__pycache__/test_validation.cpython-310-pytest-8.3.4.pyc
ADDED
Binary file (1.53 kB).

tests/test_api.py
ADDED
@@ -0,0 +1,18 @@
+from unittest.mock import Mock
+from utils.api import classify_row_chat
+
+def test_classify_row_chat():
+    # Mock the OpenAI client and its response
+    client_mock = Mock()
+    client_mock.chat.completions.create.return_value = Mock(
+        choices=[Mock(message=Mock(content="Positive"))]
+    )
+
+    # Define the prompt
+    prompt = "Classify the following observation: Age: 25, Weight: 70\nLabel:"
+
+    # Call the classify_row_chat function with the mocked client
+    prediction = classify_row_chat(prompt=prompt, client=client_mock, model="gpt-3.5-turbo")
+
+    # Assert the response matches the expected label
+    assert prediction == "Positive", "The classification should return 'Positive'"

tests/test_evaluation.py
ADDED
@@ -0,0 +1,14 @@
+from utils.evaluation import evaluate_predictions
+
+def test_evaluate_predictions():
+    y_true = ["Positive", "Negative", "Positive"]
+    y_pred = ["Positive", "Negative", "Positive"]
+
+    # Test perfect match
+    report = evaluate_predictions(y_true, y_pred)
+    assert report["accuracy"] == 1.0, "Accuracy should be 100% for perfect predictions"
+
+    # Test mismatched predictions
+    y_pred_mismatch = ["Negative", "Negative", "Positive"]
+    report_mismatch = evaluate_predictions(y_true, y_pred_mismatch)
+    assert report_mismatch["accuracy"] < 1.0, "Accuracy should be less than 100% for mismatched predictions"

tests/test_prompt.py
ADDED
@@ -0,0 +1,23 @@
+import pytest
+from utils.prompt import generate_prompts
+
+def test_generate_prompts():
+    example_rows = [  # Update to match the function's parameter name
+        {"features": {"Age": 34, "Weight": 70, "Location": "Urban"}, "label": "Positive"},
+        {"features": {"Age": 25, "Weight": 60, "Location": "Rural"}, "label": "Negative"},
+    ]
+    features = ["Age", "Weight", "Location"]
+    label_descriptions = {
+        "Positive": "The sentiment is positive.",
+        "Negative": "The sentiment is negative.",
+    }
+
+    row = {"Age": 30, "Weight": 65, "Location": "Suburban"}
+
+    system_prompt, user_prompt = generate_prompts(
+        row=row, example_rows=example_rows, features=features, label_descriptions=label_descriptions
+    )
+    assert "Age: 34; Weight: 70; Location: Urban" in system_prompt
+    assert "Label: Positive" in system_prompt
+    assert "Label:" in user_prompt
+

tests/test_validation.py
ADDED
@@ -0,0 +1,17 @@
+from pydantic import ValidationError
+from utils.validation import generate_classification_model
+
+def test_classification_output_validation():
+    # Dynamically generate classification model
+    ClassificationOutput = generate_classification_model(["Positive", "Negative"])
+
+    # Test valid input
+    valid_output = ClassificationOutput(label="Positive")
+    assert valid_output.label == "Positive", "Label should be 'Positive'"
+
+    # Test invalid input
+    try:
+        ClassificationOutput(label="InvalidLabel")
+    except ValidationError as e:
+        error_message = str(e)
+        assert "Input should be 'Positive' or 'Negative'" in error_message, "Should raise validation error with correct message"