add endpoints
- app.py +19 -226
- classifiers/llm.py +39 -35
- process.py +103 -1
- test_server.py +2 -1
- utils.py +47 -1
app.py
CHANGED
@@ -9,14 +9,15 @@ import matplotlib.pyplot as plt
 
 import logging
 from dotenv import load_dotenv
-from process import update_api_key, process_file_async, export_results
+from process import update_api_key, process_file_async, export_results, improve_classification
 from client import get_client, initialize_client
+from utils import load_data, visualize_results, analyze_text_columns, get_sample_texts
+from classifiers.llm import LLMClassifier
 
 # Load environment variables from .env file
 load_dotenv()
 
 # Import local modules
-from utils import load_data, visualize_results
 from prompts import (
     CATEGORY_SUGGESTION_PROMPT,
     ADDITIONAL_CATEGORY_PROMPT,
@@ -147,7 +148,7 @@ with gr.Blocks(title="Text Classification System") as demo:
     )
 
     # Function to load file and suggest categories
-    def load_file_and_suggest_categories(file):
+    async def load_file_and_suggest_categories(file):
         if not file:
             return (
                 [],
@@ -167,67 +168,17 @@ with gr.Blocks(title="Text Classification System") as demo:
             columns = list(df.columns)
 
             # Analyze columns to suggest text columns
-            suggested_text_columns = []
-
-            for col in columns:
-                if df[col].dtype == "object":  # String type
-                    # Check if column contains mostly text (not just numbers or dates)
-                    sample = df[col].head(100).dropna()
-                    if len(sample) > 0:
-                        # Check if most values contain spaces (indicating text)
-                        text_ratio = sum(" " in str(val) for val in sample) / len(
-                            sample
-                        )
-                        if (
-                            text_ratio > 0.3
-                        ):  # If more than 30% of values contain spaces
-                            suggested_text_columns.append(col)
-
-            # If no columns were suggested, use all object columns
-            if not suggested_text_columns:
-                suggested_text_columns = [
-                    col for col in columns if df[col].dtype == "object"
-                ]
-
-            # Get a sample of text for category suggestion
-            sample_texts = []
-            for col in suggested_text_columns:
-                sample_texts.extend(df[col].head(5).tolist())
+            suggested_text_columns = analyze_text_columns(df)
+
+            # Get sample texts for category suggestion
+            sample_texts = get_sample_texts(df, suggested_text_columns)
 
             # Use LLM to suggest categories
             if client:
-                prompt = CATEGORY_SUGGESTION_PROMPT.format(
-                    "\n---\n".join(sample_texts)
-                )
-                try:
-                    response = client.chat.completions.create(
-                        model="gpt-3.5-turbo",
-                        messages=[{"role": "user", "content": prompt}],
-                        temperature=0,
-                        max_tokens=100,
-                    )
-                    suggested_cats = [
-                        cat.strip()
-                        for cat in response.choices[0]
-                        .message.content.strip()
-                        .split(",")
-                    ]
-                except:
-                    suggested_cats = [
-                        "Positive",
-                        "Negative",
-                        "Neutral",
-                        "Mixed",
-                        "Other",
-                    ]
+                classifier = LLMClassifier(client=client)
+                suggested_cats = await classifier.suggest_categories_from_texts(sample_texts)
             else:
-                suggested_cats = [
-                    "Positive",
-                    "Negative",
-                    "Neutral",
-                    "Mixed",
-                    "Other",
-                ]
+                suggested_cats = ["Positive", "Negative", "Neutral", "Mixed", "Other"]
 
             return (
                 columns,
@@ -295,7 +246,7 @@ with gr.Blocks(title="Text Classification System") as demo:
     )
 
     # Function to suggest a new category
-    def suggest_new_category(file, current_categories, text_columns):
+    async def suggest_new_category(file, current_categories, text_columns):
        if not file or not text_columns:
            return gr.CheckboxGroup(
                choices=current_categories, value=current_categories
@@ -303,29 +254,16 @@ with gr.Blocks(title="Text Classification System") as demo:
 
         try:
             df = load_data(file.name)
-
-            # Get sample texts from selected columns
-            sample_texts = []
-            for col in text_columns:
-                sample_texts.extend(df[col].head(5).tolist())
+            sample_texts = get_sample_texts(df, text_columns)
 
             if client:
-                prompt = ADDITIONAL_CATEGORY_PROMPT.format(
-                    current_categories,
-                    sample_texts
-                )
-                try:
-                    response = client.chat.completions.create(
-                        model="gpt-3.5-turbo",
-                        messages=[{"role": "user", "content": prompt}],
-                        temperature=0,
-                        max_tokens=50,
-                    )
-                    new_cat = response.choices[0].message.content.strip()
-                    if new_cat and new_cat not in current_categories:
-                        current_categories.append(new_cat)
-                except:
-                    pass
+                classifier = LLMClassifier(client=client)
+                new_categories = await classifier.suggest_categories_from_texts(
+                    sample_texts, current_categories
+                )
+                return gr.CheckboxGroup(
+                    choices=new_categories, value=new_categories
+                )
 
         return gr.CheckboxGroup(
             choices=current_categories, value=current_categories
@@ -342,151 +280,6 @@ with gr.Blocks(title="Text Classification System") as demo:
         file_path = export_results(df, format_type)
         return gr.File(value=file_path, visible=True)
 
-    # Function to improve classification based on validation report
-    async def improve_classification_async(
-        df,
-        validation_report,
-        text_columns,
-        categories,
-        classifier_type,
-        show_explanations,
-        file,
-    ):
-        """Async version of improve_classification"""
-        if df is None or not validation_report:
-            return (
-                df,
-                validation_report,
-                gr.Button(visible=False),
-                gr.CheckboxGroup(choices=[], value=[]),
-            )
-
-        try:
-            # Extract insights from validation report
-            if client:
-                prompt = VALIDATION_ANALYSIS_PROMPT.format(
-                    validation_report=validation_report,
-                    current_categories=categories,
-                )
-                try:
-                    response = client.chat.completions.create(
-                        model="gpt-4",
-                        messages=[{"role": "user", "content": prompt}],
-                        temperature=0,
-                        max_tokens=300,
-                    )
-                    improvements = json.loads(
-                        response.choices[0].message.content.strip()
-                    )
-
-                    # Get current categories
-                    current_categories = [
-                        cat.strip() for cat in categories.split(",")
-                    ]
-
-                    # If new categories are needed, suggest them based on the data
-                    if improvements.get("new_categories_needed", False):
-                        # Get sample texts for category suggestion
-                        sample_texts = []
-                        for col in text_columns:
-                            if isinstance(file, str):
-                                temp_df = load_data(file)
-                            else:
-                                temp_df = load_data(file.name)
-                            sample_texts.extend(temp_df[col].head(10).tolist())
-
-                        category_prompt = CATEGORY_IMPROVEMENT_PROMPT.format(
-                            current_categories=", ".join(current_categories),
-                            analysis=improvements.get("analysis", ""),
-                            sample_texts="\n---\n".join(sample_texts[:10]),
-                        )
-
-                        category_response = client.chat.completions.create(
-                            model="gpt-4",
-                            messages=[{"role": "user", "content": category_prompt}],
-                            temperature=0,
-                            max_tokens=100,
-                        )
-
-                        new_categories = [
-                            cat.strip()
-                            for cat in category_response.choices[0]
-                            .message.content.strip()
-                            .split(",")
-                        ]
-                        # Combine current and new categories
-                        all_categories = current_categories + new_categories
-                        categories = ",".join(all_categories)
-
-                    # Process with improved parameters
-                    improved_df, new_validation = await process_file_async(
-                        file,
-                        text_columns,
-                        categories,
-                        classifier_type,
-                        show_explanations,
-                    )
-
-                    return (
-                        improved_df,
-                        new_validation,
-                        gr.Button(visible=True),
-                        gr.CheckboxGroup(
-                            choices=all_categories, value=all_categories
-                        ),
-                    )
-                except Exception as e:
-                    print(f"Error in improvement process: {str(e)}")
-                    return (
-                        df,
-                        validation_report,
-                        gr.Button(visible=True),
-                        gr.CheckboxGroup(
-                            choices=current_categories, value=current_categories
-                        ),
-                    )
-            else:
-                return (
-                    df,
-                    validation_report,
-                    gr.Button(visible=True),
-                    gr.CheckboxGroup(
-                        choices=current_categories, value=current_categories
-                    ),
-                )
-        except Exception as e:
-            print(f"Error in improvement process: {str(e)}")
-            return (
-                df,
-                validation_report,
-                gr.Button(visible=True),
-                gr.CheckboxGroup(
-                    choices=current_categories, value=current_categories
-                ),
-            )
-
-    def improve_classification(
-        df,
-        validation_report,
-        text_columns,
-        categories,
-        classifier_type,
-        show_explanations,
-        file,
-    ):
-        """Synchronous wrapper for improve_classification_async"""
-        return asyncio.run(
-            improve_classification_async(
-                df,
-                validation_report,
-                text_columns,
-                categories,
-                classifier_type,
-                show_explanations,
-                file,
-            )
-        )
-
     # Connect functions
     load_categories_button.click(
         load_file_and_suggest_categories,
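Note on the async conversion above: Gradio awaits coroutine event handlers natively, so the callbacks that became async def stay wired through .click() unchanged. A minimal sketch of that pattern (handler and component names below are illustrative, not from this Space):

import asyncio
import gradio as gr

async def suggest(text: str) -> str:
    # Stand-in for an awaited LLM call; Gradio awaits the coroutine itself
    await asyncio.sleep(0.1)
    return f"Suggested categories for: {text}"

with gr.Blocks() as demo:
    box = gr.Textbox(label="Sample text")
    out = gr.Textbox(label="Categories")
    btn = gr.Button("Suggest")
    btn.click(suggest, inputs=box, outputs=out)  # async handler, no sync wrapper needed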
classifiers/llm.py
CHANGED
@@ -6,14 +6,14 @@ from sklearn.metrics.pairwise import cosine_similarity
 import random
 import json
 import asyncio
-from typing import List, Dict, Any, Optional, Union
+from typing import List, Dict, Any, Optional, Union, Tuple
 import sys
 import os
 from litellm import OpenAI
 
 # Add the project root to the Python path
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT
+from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT, ADDITIONAL_CATEGORY_PROMPT
 
 from .base import BaseClassifier
 
@@ -26,6 +26,43 @@ class LLMClassifier(BaseClassifier):
         self.client: OpenAI = client
         self.model: str = model
 
+    async def _suggest_categories_async(self, texts: List[str], sample_size: int = 20) -> List[str]:
+        """Async version of category suggestion"""
+        # Take a sample of texts to avoid token limitations
+        if len(texts) > sample_size:
+            sample_texts: List[str] = random.sample(texts, sample_size)
+        else:
+            sample_texts: List[str] = texts
+
+        prompt: str = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))
+
+        try:
+            # Use the synchronous client method but run it in a thread pool
+            loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
+            response: Any = await loop.run_in_executor(
+                None,
+                lambda: self.client.chat.completions.create(
+                    model=self.model,
+                    messages=[{"role": "user", "content": prompt}],
+                    temperature=0.2,
+                    max_tokens=100,
+                ),
+            )
+
+            # Parse response to get categories
+            categories_text: str = response.choices[0].message.content.strip()
+            categories: List[str] = [cat.strip() for cat in categories_text.split(",")]
+
+            return categories
+        except Exception as e:
+            # Fallback to default categories on error
+            print(f"Error suggesting categories: {str(e)}")
+            return self._generate_default_categories(texts)
+
+    def _generate_default_categories(self, texts: List[str]) -> List[str]:
+        """Generate default categories if LLM suggestion fails"""
+        return ["Positive", "Negative", "Neutral", "Mixed", "Other"]
+
     async def _classify_text_async(self, text: str, categories: List[str]) -> Dict[str, Any]:
         """Async version of text classification"""
         prompt: str = TEXT_CLASSIFICATION_PROMPT.format(
@@ -87,39 +124,6 @@ class LLMClassifier(BaseClassifier):
             "explanation": f"Error during classification: {str(e)}",
         }
 
-    async def _suggest_categories_async(self, texts: List[str], sample_size: int = 20) -> List[str]:
-        """Async version of category suggestion"""
-        # Take a sample of texts to avoid token limitations
-        if len(texts) > sample_size:
-            sample_texts: List[str] = random.sample(texts, sample_size)
-        else:
-            sample_texts: List[str] = texts
-
-        prompt: str = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))
-
-        try:
-            # Use the synchronous client method but run it in a thread pool
-            loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
-            response: Any = await loop.run_in_executor(
-                None,
-                lambda: self.client.chat.completions.create(
-                    model=self.model,
-                    messages=[{"role": "user", "content": prompt}],
-                    temperature=0.2,
-                    max_tokens=100,
-                ),
-            )
-
-            # Parse response to get categories
-            categories_text: str = response.choices[0].message.content.strip()
-            categories: List[str] = [cat.strip() for cat in categories_text.split(",")]
-
-            return categories
-        except Exception as e:
-            # Fallback to default categories on error
-            print(f"Error suggesting categories: {str(e)}")
-            return self._generate_default_categories(texts)
-
     async def classify_async(
         self, texts: List[str], categories: Optional[List[str]] = None
     ) -> List[Dict[str, Any]]:
process.py
CHANGED
@@ -6,10 +6,12 @@ from sklearn.feature_extraction.text import TfidfVectorizer
 from typing import Optional, List, Dict, Any, Tuple, Union
 import pandas as pd
 from pathlib import Path
+import json
 
 from classifiers import TFIDFClassifier, LLMClassifier
-from utils import load_data, validate_results
+from utils import load_data, validate_results, get_sample_texts
 from client import get_client
+from prompts import VALIDATION_ANALYSIS_PROMPT, CATEGORY_IMPROVEMENT_PROMPT
 
 
 def update_api_key(api_key: str) -> Tuple[bool, str]:
@@ -174,3 +176,103 @@ def export_results(df: pd.DataFrame, format_type: str) -> Optional[str]:
         df.to_csv(file_path, index=False)
 
     return file_path
+
+
+async def improve_classification(
+    df: pd.DataFrame,
+    validation_report: str,
+    text_columns: List[str],
+    categories: str,
+    classifier_type: str,
+    show_explanations: bool,
+    file: Union[str, Path]
+) -> Tuple[Optional[pd.DataFrame], Optional[str], bool, List[str]]:
+    """
+    Improve classification based on validation report
+
+    Args:
+        df (pd.DataFrame): Current classification results
+        validation_report (str): Validation report from previous classification
+        text_columns (List[str]): List of text column names
+        categories (str): Comma-separated list of categories
+        classifier_type (str): Type of classifier to use
+        show_explanations (bool): Whether to show explanations
+        file (Union[str, Path]): Path to the input file
+
+    Returns:
+        Tuple[Optional[pd.DataFrame], Optional[str], bool, List[str]]:
+            - Improved dataframe
+            - New validation report
+            - Whether improvement was successful
+            - Updated categories
+    """
+    if df is None or not validation_report:
+        return None, validation_report, False, []
+
+    try:
+        client = get_client()
+        if not client:
+            return None, "Error: API client not initialized", False, []
+
+        # Extract insights from validation report
+        prompt = VALIDATION_ANALYSIS_PROMPT.format(
+            validation_report=validation_report,
+            current_categories=categories,
+        )
+
+        response = await asyncio.get_event_loop().run_in_executor(
+            None,
+            lambda: client.chat.completions.create(
+                model="gpt-4",
+                messages=[{"role": "user", "content": prompt}],
+                temperature=0,
+                max_tokens=300,
+            )
+        )
+
+        improvements = json.loads(response.choices[0].message.content.strip())
+        current_categories = [cat.strip() for cat in categories.split(",")]
+
+        # If new categories are needed, suggest them based on the data
+        if improvements.get("new_categories_needed", False):
+            # Get sample texts for category suggestion
+            sample_texts = get_sample_texts(df, text_columns, sample_size=10)
+
+            category_prompt = CATEGORY_IMPROVEMENT_PROMPT.format(
+                current_categories=", ".join(current_categories),
+                analysis=improvements.get("analysis", ""),
+                sample_texts="\n---\n".join(sample_texts)
+            )
+
+            category_response = await asyncio.get_event_loop().run_in_executor(
+                None,
+                lambda: client.chat.completions.create(
+                    model="gpt-4",
+                    messages=[{"role": "user", "content": category_prompt}],
+                    temperature=0,
+                    max_tokens=100,
+                )
+            )
+
+            new_categories = [
+                cat.strip()
+                for cat in category_response.choices[0].message.content.strip().split(",")
+            ]
+            # Combine current and new categories
+            all_categories = current_categories + new_categories
+            categories = ",".join(all_categories)
+
+        # Process with improved parameters
+        improved_df, new_validation = await process_file_async(
+            file,
+            text_columns,
+            categories,
+            classifier_type,
+            show_explanations,
+        )
+
+        return improved_df, new_validation, True, all_categories if improvements.get("new_categories_needed", False) else current_categories
+
+    except Exception as e:
+        print(f"Error in improvement process: {str(e)}")
+        return df, validation_report, False, current_categories
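Since improve_classification is now a coroutine exported by process.py (and imported in app.py above), a synchronous caller would drive it with asyncio.run. A usage sketch; the dataframe, report text, column name, and file path are illustrative only:

import asyncio
import pandas as pd
from process import improve_classification  # as added in this commit

async def main() -> None:
    df = pd.DataFrame({"contenu": ["Great product", "Terrible support"]})
    report = "Several rows look misclassified."  # illustrative validation report
    improved_df, new_report, ok, cats = await improve_classification(
        df,
        report,
        text_columns=["contenu"],       # illustrative column name
        categories="Positive,Negative",
        classifier_type="llm",
        show_explanations=True,
        file="emails.xlsx",             # illustrative path
    )
    print(ok, cats)

asyncio.run(main())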
test_server.py
CHANGED
@@ -107,7 +107,8 @@ def test_validate_classifications() -> None:
         f"{BASE_URL}/suggest-categories",
         json=[email["contenu"] for email in emails[:5]]
     )
-
+    response_data: Dict[str, Any] = categories_response.json()
+    current_categories: List[str] = response_data["categories"]  # Extract categories from the response
 
     # Send validation request
     validation_request: Dict[str, Any] = {
utils.py
CHANGED
@@ -6,7 +6,7 @@ from sklearn.decomposition import PCA
 from sklearn.feature_extraction.text import TfidfVectorizer
 import tempfile
 from prompts import VALIDATION_PROMPT
-from typing import List, Optional, Any, Union
+from typing import List, Optional, Any, Union, Tuple
 from pathlib import Path
 from matplotlib.figure import Figure
 
@@ -33,6 +33,52 @@ def load_data(file_path: Union[str, Path]) -> pd.DataFrame:
     )
 
 
+def analyze_text_columns(df: pd.DataFrame) -> List[str]:
+    """
+    Analyze columns to suggest text columns based on content analysis
+
+    Args:
+        df (pd.DataFrame): Input dataframe
+
+    Returns:
+        List[str]: List of suggested text columns
+    """
+    suggested_text_columns: List[str] = []
+    for col in df.columns:
+        if df[col].dtype == "object":  # String type
+            # Check if column contains mostly text (not just numbers or dates)
+            sample = df[col].head(100).dropna()
+            if len(sample) > 0:
+                # Check if most values contain spaces (indicating text)
+                text_ratio = sum(" " in str(val) for val in sample) / len(sample)
+                if text_ratio > 0.3:  # If more than 30% of values contain spaces
+                    suggested_text_columns.append(col)
+
+    # If no columns were suggested, use all object columns
+    if not suggested_text_columns:
+        suggested_text_columns = [col for col in df.columns if df[col].dtype == "object"]
+
+    return suggested_text_columns
+
+
+def get_sample_texts(df: pd.DataFrame, text_columns: List[str], sample_size: int = 5) -> List[str]:
+    """
+    Get sample texts from specified columns
+
+    Args:
+        df (pd.DataFrame): Input dataframe
+        text_columns (List[str]): List of text column names
+        sample_size (int): Number of samples to take from each column
+
+    Returns:
+        List[str]: List of sample texts
+    """
+    sample_texts: List[str] = []
+    for col in text_columns:
+        sample_texts.extend(df[col].head(sample_size).tolist())
+    return sample_texts
+
+
 def export_data(df: pd.DataFrame, file_name: str, format_type: str = "excel") -> str:
     """
     Export dataframe to file